diff --git a/README.md b/README.md index 750cdb345..9a6a19f31 100644 --- a/README.md +++ b/README.md @@ -181,6 +181,20 @@ The rate limit block specifies the actual rate limit that will be used when ther Currently the service supports per second, minute, hour, and day limits. More types of limits may be added in the future based on user demand. +### Rate limit algorithm + +Ratelimit supports two algorithms: + +1. Fixed window +For a limit of 60 requests per hour, there can only be 60 requests in a single time window (e.g. 01:00 - 01:59). +The fixed window algorithm does not care when the requests arrive: all 60 can arrive at 01:01 or at 01:50, and the limit will still reset at 02:00. + +2. Rolling window +For a limit of 60 requests per hour, the rate limiter can initially take a burst of 60 requests at once; the available limit is then restored by 1 each minute. Requests are allowed as long as there is still some available limit. + +Configure the rate limit algorithm with the `RATE_LIMIT_ALGORITHM` environment variable. +Use `FIXED_WINDOW` or `ROLLING_WINDOW`, respectively. 
+ ### Examples #### Example 1 diff --git a/go.mod b/go.mod index 1c282fcd9..c1d574070 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/coocood/freecache v1.1.0 github.com/envoyproxy/go-control-plane v0.9.7 github.com/fsnotify/fsnotify v1.4.7 // indirect - github.com/golang/mock v1.4.1 + github.com/golang/mock v1.4.4 github.com/golang/protobuf v1.4.2 github.com/gorilla/mux v1.7.4-0.20191121170500-49c01487a141 github.com/kavu/go_reuseport v1.2.0 @@ -26,4 +26,5 @@ require ( google.golang.org/grpc v1.27.0 google.golang.org/protobuf v1.25.0 // indirect gopkg.in/yaml.v2 v2.3.0 + rsc.io/quote/v3 v3.1.0 // indirect ) diff --git a/go.sum b/go.sum index 071a59b37..f228bbd87 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,8 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.4.1 h1:ocYkMQY5RrXTYgXl7ICpV0IXwlEQGwKIsery4gyXa1U= github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/src/algorithm/base_window.go b/src/algorithm/base_window.go new file mode 100644 index 000000000..bd9de5545 --- /dev/null +++ b/src/algorithm/base_window.go @@ -0,0 +1,104 @@ +package algorithm + +import ( + "github.com/coocood/freecache" + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/utils" + logger "github.com/sirupsen/logrus" +) + +type WindowImpl struct { + algorithm 
RatelimitAlgorithm + cacheKeyGenerator utils.CacheKeyGenerator + localCache *freecache.Cache + timeSource utils.TimeSource +} + +func (w *WindowImpl) GetResponseDescriptorStatus(key string, limit *config.RateLimit, results int64, + isOverLimitWithLocalCache bool, hitsAddend int64) *pb.RateLimitResponse_DescriptorStatus { + if key == "" { + return &pb.RateLimitResponse_DescriptorStatus{ + Code: pb.RateLimitResponse_OK, + CurrentLimit: nil, + LimitRemaining: 0, + } + } + + if isOverLimitWithLocalCache { + PopulateStats(limit, 0, uint64(hitsAddend), uint64(hitsAddend)) + + return &pb.RateLimitResponse_DescriptorStatus{ + Code: pb.RateLimitResponse_OVER_LIMIT, + CurrentLimit: limit.Limit, + LimitRemaining: 0, + DurationUntilReset: w.algorithm.CalculateSimpleReset(limit, w.timeSource), + } + } + + isOverLimit, limitRemaining, durationUntilReset := w.algorithm.IsOverLimit(limit, int64(results), hitsAddend) + + if !isOverLimit { + duration := w.algorithm.CalculateReset(isOverLimit, limit, w.timeSource) + return &pb.RateLimitResponse_DescriptorStatus{ + Code: pb.RateLimitResponse_OK, + CurrentLimit: limit.Limit, + LimitRemaining: uint32(limitRemaining), + DurationUntilReset: duration, + } + } else { + if w.localCache != nil { + durationUntilReset = utils.MaxInt(1, durationUntilReset) + + err := w.localCache.Set([]byte(key), []byte{}, durationUntilReset) + if err != nil { + logger.Errorf("Failing to set local cache key: %s", key) + } + } + duration := w.algorithm.CalculateReset(isOverLimit, limit, w.timeSource) + return &pb.RateLimitResponse_DescriptorStatus{ + Code: pb.RateLimitResponse_OVER_LIMIT, + CurrentLimit: limit.Limit, + LimitRemaining: 0, + DurationUntilReset: duration, + } + } +} + +func (w *WindowImpl) IsOverLimitWithLocalCache(key string) bool { + if w.localCache != nil { + _, err := w.localCache.Get([]byte(key)) + if err == nil { + return true + } + } + return false +} + +func (w *WindowImpl) GenerateCacheKeys(request *pb.RateLimitRequest, + limits 
[]*config.RateLimit, hitsAddend int64, timestamp int64) []utils.CacheKey { + return w.cacheKeyGenerator.GenerateCacheKeys(request, limits, uint32(hitsAddend), timestamp) +} + +func (w *WindowImpl) GetExpirationSeconds() int64 { + return w.algorithm.GetExpirationSeconds() +} + +func (w *WindowImpl) GetResultsAfterIncrease() int64 { + return w.algorithm.GetResultsAfterIncrease() +} + +func PopulateStats(limit *config.RateLimit, nearLimit uint64, overLimit uint64, overLimitWithLocalCache uint64) { + limit.Stats.NearLimit.Add(nearLimit) + limit.Stats.OverLimit.Add(overLimit) + limit.Stats.OverLimitWithLocalCache.Add(overLimitWithLocalCache) +} + +func NewWindow(algorithm RatelimitAlgorithm, cacheKeyPrefix string, localCache *freecache.Cache, timeSource utils.TimeSource) *WindowImpl { + return &WindowImpl{ + algorithm: algorithm, + cacheKeyGenerator: utils.NewCacheKeyGenerator(cacheKeyPrefix), + localCache: localCache, + timeSource: timeSource, + } +} diff --git a/src/algorithm/fixed_window.go b/src/algorithm/fixed_window.go new file mode 100644 index 000000000..8b361f47e --- /dev/null +++ b/src/algorithm/fixed_window.go @@ -0,0 +1,72 @@ +package algorithm + +import ( + "math" + + "github.com/golang/protobuf/ptypes/duration" + + "github.com/coocood/freecache" + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/utils" +) + +var _ RatelimitAlgorithm = (*FixedWindowImpl)(nil) + +type FixedWindowImpl struct { + timeSource utils.TimeSource + cacheKeyGenerator utils.CacheKeyGenerator + localCache *freecache.Cache + nearLimitRatio float32 +} + +func (fw *FixedWindowImpl) IsOverLimit(limit *config.RateLimit, results int64, hitsAddend int64) (bool, int64, int) { + limitAfterIncrease := results + limitBeforeIncrease := limitAfterIncrease - int64(hitsAddend) + overLimitThreshold := int64(limit.Limit.RequestsPerUnit) + nearLimitThreshold := int64(math.Floor(float64(float32(overLimitThreshold) * fw.nearLimitRatio))) + + if limitAfterIncrease > 
overLimitThreshold { + if limitBeforeIncrease >= overLimitThreshold { + PopulateStats(limit, 0, uint64(hitsAddend), 0) + } else { + PopulateStats(limit, uint64(overLimitThreshold-utils.MaxInt64(nearLimitThreshold, limitBeforeIncrease)), uint64(limitAfterIncrease-overLimitThreshold), 0) + } + + return true, 0, int(utils.UnitToDivider(limit.Limit.Unit)) + } else { + if limitAfterIncrease > nearLimitThreshold { + if limitBeforeIncrease >= nearLimitThreshold { + PopulateStats(limit, uint64(hitsAddend), 0, 0) + } else { + PopulateStats(limit, uint64(limitAfterIncrease-nearLimitThreshold), 0, 0) + } + } + + return false, overLimitThreshold - limitAfterIncrease, int(utils.UnitToDivider(limit.Limit.Unit)) + } +} + +func (fw *FixedWindowImpl) GetExpirationSeconds() int64 { + return 0 +} + +func (fw *FixedWindowImpl) GetResultsAfterIncrease() int64 { + return 0 +} + +func (fw *FixedWindowImpl) CalculateSimpleReset(limit *config.RateLimit, timeSource utils.TimeSource) *duration.Duration { + return utils.CalculateFixedReset(limit.Limit, timeSource) +} + +func (fw *FixedWindowImpl) CalculateReset(isOverLimit bool, limit *config.RateLimit, timeSource utils.TimeSource) *duration.Duration { + return fw.CalculateSimpleReset(limit, timeSource) +} + +func NewFixedWindowAlgorithm(timeSource utils.TimeSource, localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string) *FixedWindowImpl { + return &FixedWindowImpl{ + timeSource: timeSource, + cacheKeyGenerator: utils.NewCacheKeyGenerator(cacheKeyPrefix), + localCache: localCache, + nearLimitRatio: nearLimitRatio, + } +} diff --git a/src/algorithm/ratelimit_algorithm.go b/src/algorithm/ratelimit_algorithm.go new file mode 100644 index 000000000..be2395981 --- /dev/null +++ b/src/algorithm/ratelimit_algorithm.go @@ -0,0 +1,15 @@ +package algorithm + +import ( + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/utils" + "github.com/golang/protobuf/ptypes/duration" +) + +type 
RatelimitAlgorithm interface { + CalculateSimpleReset(limit *config.RateLimit, timeSource utils.TimeSource) *duration.Duration + CalculateReset(isOverLimit bool, limit *config.RateLimit, timeSource utils.TimeSource) *duration.Duration + IsOverLimit(limit *config.RateLimit, results int64, hitsAddend int64) (bool, int64, int) + GetExpirationSeconds() int64 + GetResultsAfterIncrease() int64 +} diff --git a/src/algorithm/rolling_window.go b/src/algorithm/rolling_window.go new file mode 100644 index 000000000..2c0c43ac4 --- /dev/null +++ b/src/algorithm/rolling_window.go @@ -0,0 +1,100 @@ +package algorithm + +import ( + "math" + + "github.com/coocood/freecache" + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/utils" + "github.com/golang/protobuf/ptypes/duration" +) + +var _ RatelimitAlgorithm = (*RollingWindowImpl)(nil) + +type RollingWindowImpl struct { + timeSource utils.TimeSource + cacheKeyGenerator utils.CacheKeyGenerator + localCache *freecache.Cache + nearLimitRatio float32 + arrivedAt int64 + tat int64 + newTat int64 + diff int64 +} + +func (rw *RollingWindowImpl) IsOverLimit(limit *config.RateLimit, results int64, hitsAddend int64) (bool, int64, int) { + now := rw.timeSource.UnixNanoNow() + + // Time during computation should be in nanosecond + rw.arrivedAt = now + // Tat is set to current request timestamp if not set before + rw.tat = utils.MaxInt64(results, rw.arrivedAt) + totalLimit := int64(limit.Limit.RequestsPerUnit) + period := utils.SecondsToNanoseconds(utils.UnitToDivider(limit.Limit.Unit)) + quantity := int64(hitsAddend) + + // GCRA computation + // Emission interval is the cost of each request + emissionInterval := period / totalLimit + // New tat define the end of the window + rw.newTat = rw.tat + emissionInterval*quantity + // We allow the request if it's inside the window + allowAt := rw.newTat - period + rw.diff = rw.arrivedAt - allowAt + + previousAllowAt := rw.tat - period + previousLimitRemaining := 
int64(math.Ceil(float64((rw.arrivedAt - previousAllowAt) / emissionInterval))) + previousLimitRemaining = utils.MaxInt64(previousLimitRemaining, 0) + nearLimitWindow := int64(math.Ceil(float64(float32(limit.Limit.RequestsPerUnit) * (1.0 - rw.nearLimitRatio)))) + limitRemaining := int64(math.Ceil(float64(rw.diff / emissionInterval))) + hitNearLimit := quantity - (utils.MaxInt64(previousLimitRemaining, nearLimitWindow) - nearLimitWindow) + + if rw.diff < 0 { + PopulateStats(limit, uint64(utils.MinInt64(previousLimitRemaining, nearLimitWindow)), uint64(quantity-previousLimitRemaining), 0) + + return true, 0, int(utils.NanosecondsToSeconds(-rw.diff)) + } else { + if hitNearLimit > 0 { + PopulateStats(limit, uint64(hitNearLimit), 0, 0) + } + + return false, limitRemaining, 0 + } +} + +func (rw *RollingWindowImpl) GetExpirationSeconds() int64 { + if rw.diff < 0 { + return utils.NanosecondsToSeconds(rw.tat-rw.arrivedAt) + 1 + } + return utils.NanosecondsToSeconds(rw.newTat-rw.arrivedAt) + 1 +} + +func (rw *RollingWindowImpl) GetResultsAfterIncrease() int64 { + if rw.diff < 0 { + return rw.tat + } + return rw.newTat +} + +func (rw *RollingWindowImpl) CalculateSimpleReset(limit *config.RateLimit, timeSource utils.TimeSource) *duration.Duration { + secondsToReset := utils.UnitToDivider(limit.Limit.Unit) + secondsToReset -= utils.NanosecondsToSeconds(timeSource.UnixNanoNow()) % secondsToReset + return &duration.Duration{Seconds: secondsToReset} +} + +func (rw *RollingWindowImpl) CalculateReset(isOverLimit bool, limit *config.RateLimit, timeSource utils.TimeSource) *duration.Duration { + if !isOverLimit { + return utils.NanosecondsToDuration(rw.newTat - rw.arrivedAt) + } else { + return utils.NanosecondsToDuration(int64(math.Ceil(float64(rw.tat - rw.arrivedAt)))) + } +} + +func NewRollingWindowAlgorithm(timeSource utils.TimeSource, localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string) *RollingWindowImpl { + return &RollingWindowImpl{ + timeSource: 
timeSource, + cacheKeyGenerator: utils.NewCacheKeyGenerator(cacheKeyPrefix), + localCache: localCache, + nearLimitRatio: nearLimitRatio, + } +} diff --git a/src/limiter/base_limiter.go b/src/limiter/base_limiter.go deleted file mode 100644 index 4a9aec1de..000000000 --- a/src/limiter/base_limiter.go +++ /dev/null @@ -1,177 +0,0 @@ -package limiter - -import ( - "github.com/coocood/freecache" - pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - "github.com/envoyproxy/ratelimit/src/assert" - "github.com/envoyproxy/ratelimit/src/config" - "github.com/envoyproxy/ratelimit/src/utils" - logger "github.com/sirupsen/logrus" - "math" - "math/rand" -) - -type BaseRateLimiter struct { - timeSource utils.TimeSource - JitterRand *rand.Rand - ExpirationJitterMaxSeconds int64 - cacheKeyGenerator CacheKeyGenerator - localCache *freecache.Cache - nearLimitRatio float32 -} - -type LimitInfo struct { - limit *config.RateLimit - limitBeforeIncrease uint32 - limitAfterIncrease uint32 - nearLimitThreshold uint32 - overLimitThreshold uint32 -} - -func NewRateLimitInfo(limit *config.RateLimit, limitBeforeIncrease uint32, limitAfterIncrease uint32, - nearLimitThreshold uint32, overLimitThreshold uint32) *LimitInfo { - return &LimitInfo{limit: limit, limitBeforeIncrease: limitBeforeIncrease, limitAfterIncrease: limitAfterIncrease, - nearLimitThreshold: nearLimitThreshold, overLimitThreshold: overLimitThreshold} -} - -// Generates cache keys for given rate limit request. Each cache key is represented by a concatenation of -// domain, descriptor and current timestamp. 
-func (this *BaseRateLimiter) GenerateCacheKeys(request *pb.RateLimitRequest, - limits []*config.RateLimit, hitsAddend uint32) []CacheKey { - assert.Assert(len(request.Descriptors) == len(limits)) - cacheKeys := make([]CacheKey, len(request.Descriptors)) - now := this.timeSource.UnixNow() - for i := 0; i < len(request.Descriptors); i++ { - // generateCacheKey() returns an empty string in the key if there is no limit - // so that we can keep the arrays all the same size. - cacheKeys[i] = this.cacheKeyGenerator.GenerateCacheKey(request.Domain, request.Descriptors[i], limits[i], now) - // Increase statistics for limits hit by their respective requests. - if limits[i] != nil { - limits[i].Stats.TotalHits.Add(uint64(hitsAddend)) - } - } - return cacheKeys -} - -// Returns `true` in case local cache is enabled and contains value for provided cache key, `false` otherwise. -func (this *BaseRateLimiter) IsOverLimitWithLocalCache(key string) bool { - if this.localCache != nil { - // Get returns the value or not found error. - _, err := this.localCache.Get([]byte(key)) - if err == nil { - return true - } - } - return false -} - -// Generates response descriptor status based on cache key, over the limit with local cache, over the limit and -// near the limit thresholds. Thresholds are checked in order and are mutually exclusive. 
-func (this *BaseRateLimiter) GetResponseDescriptorStatus(key string, limitInfo *LimitInfo, - isOverLimitWithLocalCache bool, hitsAddend uint32) *pb.RateLimitResponse_DescriptorStatus { - if key == "" { - return this.generateResponseDescriptorStatus(pb.RateLimitResponse_OK, - nil, 0) - } - if isOverLimitWithLocalCache { - limitInfo.limit.Stats.OverLimit.Add(uint64(hitsAddend)) - limitInfo.limit.Stats.OverLimitWithLocalCache.Add(uint64(hitsAddend)) - return this.generateResponseDescriptorStatus(pb.RateLimitResponse_OVER_LIMIT, - limitInfo.limit.Limit, 0) - } - var responseDescriptorStatus *pb.RateLimitResponse_DescriptorStatus - limitInfo.overLimitThreshold = limitInfo.limit.Limit.RequestsPerUnit - // The nearLimitThreshold is the number of requests that can be made before hitting the nearLimitRatio. - // We need to know it in both the OK and OVER_LIMIT scenarios. - limitInfo.nearLimitThreshold = uint32(math.Floor(float64(float32(limitInfo.overLimitThreshold) * this.nearLimitRatio))) - logger.Debugf("cache key: %s current: %d", key, limitInfo.limitAfterIncrease) - if limitInfo.limitAfterIncrease > limitInfo.overLimitThreshold { - responseDescriptorStatus = this.generateResponseDescriptorStatus(pb.RateLimitResponse_OVER_LIMIT, - limitInfo.limit.Limit, 0) - - checkOverLimitThreshold(limitInfo, hitsAddend) - - if this.localCache != nil { - // Set the TTL of the local_cache to be the entire duration. - // Since the cache_key gets changed once the time crosses over current time slot, the over-the-limit - // cache keys in local_cache lose effectiveness. - // For example, if we have an hour limit on all mongo connections, the cache key would be - // similar to mongo_1h, mongo_2h, etc. In the hour 1 (0h0m - 0h59m), the cache key is mongo_1h, we start - // to get ratelimited in the 50th minute, the ttl of local_cache will be set as 1 hour(0h50m-1h49m). - // In the time of 1h1m, since the cache key becomes different (mongo_2h), it won't get ratelimited. 
- err := this.localCache.Set([]byte(key), []byte{}, int(utils.UnitToDivider(limitInfo.limit.Limit.Unit))) - if err != nil { - logger.Errorf("Failing to set local cache key: %s", key) - } - } - } else { - responseDescriptorStatus = this.generateResponseDescriptorStatus(pb.RateLimitResponse_OK, - limitInfo.limit.Limit, limitInfo.overLimitThreshold-limitInfo.limitAfterIncrease) - - // The limit is OK but we additionally want to know if we are near the limit. - checkNearLimitThreshold(limitInfo, hitsAddend) - } - return responseDescriptorStatus -} - -func NewBaseRateLimit(timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, - localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string) *BaseRateLimiter { - return &BaseRateLimiter{ - timeSource: timeSource, - JitterRand: jitterRand, - ExpirationJitterMaxSeconds: expirationJitterMaxSeconds, - cacheKeyGenerator: NewCacheKeyGenerator(cacheKeyPrefix), - localCache: localCache, - nearLimitRatio: nearLimitRatio, - } -} - -func checkOverLimitThreshold(limitInfo *LimitInfo, hitsAddend uint32) { - // Increase over limit statistics. Because we support += behavior for increasing the limit, we need to - // assess if the entire hitsAddend were over the limit. That is, if the limit's value before adding the - // N hits was over the limit, then all the N hits were over limit. - // Otherwise, only the difference between the current limit value and the over limit threshold - // were over limit hits. - if limitInfo.limitBeforeIncrease >= limitInfo.overLimitThreshold { - limitInfo.limit.Stats.OverLimit.Add(uint64(hitsAddend)) - } else { - limitInfo.limit.Stats.OverLimit.Add(uint64(limitInfo.limitAfterIncrease - limitInfo.overLimitThreshold)) - - // If the limit before increase was below the over limit value, then some of the hits were - // in the near limit range. 
- limitInfo.limit.Stats.NearLimit.Add(uint64(limitInfo.overLimitThreshold - - utils.Max(limitInfo.nearLimitThreshold, limitInfo.limitBeforeIncrease))) - } -} - -func checkNearLimitThreshold(limitInfo *LimitInfo, hitsAddend uint32) { - if limitInfo.limitAfterIncrease > limitInfo.nearLimitThreshold { - // Here we also need to assess which portion of the hitsAddend were in the near limit range. - // If all the hits were over the nearLimitThreshold, then all hits are near limit. Otherwise, - // only the difference between the current limit value and the near limit threshold were near - // limit hits. - if limitInfo.limitBeforeIncrease >= limitInfo.nearLimitThreshold { - limitInfo.limit.Stats.NearLimit.Add(uint64(hitsAddend)) - } else { - limitInfo.limit.Stats.NearLimit.Add(uint64(limitInfo.limitAfterIncrease - limitInfo.nearLimitThreshold)) - } - } -} - -func (this *BaseRateLimiter) generateResponseDescriptorStatus(responseCode pb.RateLimitResponse_Code, - limit *pb.RateLimitResponse_RateLimit, limitRemaining uint32) *pb.RateLimitResponse_DescriptorStatus { - if limit != nil { - return &pb.RateLimitResponse_DescriptorStatus{ - Code: responseCode, - CurrentLimit: limit, - LimitRemaining: limitRemaining, - DurationUntilReset: utils.CalculateReset(limit, this.timeSource), - } - } else { - return &pb.RateLimitResponse_DescriptorStatus{ - Code: responseCode, - CurrentLimit: limit, - LimitRemaining: limitRemaining, - } - } -} diff --git a/src/limiter/cache.go b/src/limiter/rate_limit_cache.go similarity index 100% rename from src/limiter/cache.go rename to src/limiter/rate_limit_cache.go diff --git a/src/memcached/cache_impl.go b/src/memcached/cache_impl.go index 892565e8b..de7941d93 100644 --- a/src/memcached/cache_impl.go +++ b/src/memcached/cache_impl.go @@ -16,186 +16,41 @@ package memcached import ( - "context" + "fmt" "math/rand" - "strconv" - "sync" - - "github.com/coocood/freecache" - stats "github.com/lyft/gostats" "github.com/bradfitz/gomemcache/memcache" - - 
logger "github.com/sirupsen/logrus" - - pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - - "github.com/envoyproxy/ratelimit/src/config" + "github.com/coocood/freecache" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/settings" "github.com/envoyproxy/ratelimit/src/utils" + stats "github.com/lyft/gostats" ) -type rateLimitMemcacheImpl struct { - client Client - timeSource utils.TimeSource - jitterRand *rand.Rand - expirationJitterMaxSeconds int64 - localCache *freecache.Cache - waitGroup sync.WaitGroup - nearLimitRatio float32 - baseRateLimiter *limiter.BaseRateLimiter -} - var AutoFlushForIntegrationTests bool = false -var _ limiter.RateLimitCache = (*rateLimitMemcacheImpl)(nil) - -func (this *rateLimitMemcacheImpl) DoLimit( - ctx context.Context, - request *pb.RateLimitRequest, - limits []*config.RateLimit) []*pb.RateLimitResponse_DescriptorStatus { - - logger.Debugf("starting cache lookup") - - // request.HitsAddend could be 0 (default value) if not specified by the caller in the Ratelimit request. - hitsAddend := utils.Max(1, request.HitsAddend) - - // First build a list of all cache keys that we are actually going to hit. - cacheKeys := this.baseRateLimiter.GenerateCacheKeys(request, limits, hitsAddend) - - isOverLimitWithLocalCache := make([]bool, len(request.Descriptors)) - - keysToGet := make([]string, 0, len(request.Descriptors)) - - for i, cacheKey := range cacheKeys { - if cacheKey.Key == "" { - continue - } - - // Check if key is over the limit in local cache. - if this.baseRateLimiter.IsOverLimitWithLocalCache(cacheKey.Key) { - isOverLimitWithLocalCache[i] = true - logger.Debugf("cache key is over the limit: %s", cacheKey.Key) - continue - } - - logger.Debugf("looking up cache key: %s", cacheKey.Key) - keysToGet = append(keysToGet, cacheKey.Key) - } - - // Now fetch from memcache. 
- responseDescriptorStatuses := make([]*pb.RateLimitResponse_DescriptorStatus, - len(request.Descriptors)) - - var memcacheValues map[string]*memcache.Item - var err error - - if len(keysToGet) > 0 { - memcacheValues, err = this.client.GetMulti(keysToGet) - if err != nil { - logger.Errorf("Error multi-getting memcache keys (%s): %s", keysToGet, err) - } - } - - for i, cacheKey := range cacheKeys { - - rawMemcacheValue, ok := memcacheValues[cacheKey.Key] - var limitBeforeIncrease uint32 - if ok { - decoded, err := strconv.ParseInt(string(rawMemcacheValue.Value), 10, 32) - if err != nil { - logger.Errorf("Unexpected non-numeric value in memcached: %v", rawMemcacheValue) - } else { - limitBeforeIncrease = uint32(decoded) - } - - } - - limitAfterIncrease := limitBeforeIncrease + hitsAddend - - limitInfo := limiter.NewRateLimitInfo(limits[i], limitBeforeIncrease, limitAfterIncrease, 0, 0) - - responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key, - limitInfo, isOverLimitWithLocalCache[i], hitsAddend) - } - - this.waitGroup.Add(1) - go this.increaseAsync(cacheKeys, isOverLimitWithLocalCache, limits, uint64(hitsAddend)) - if AutoFlushForIntegrationTests { - this.Flush() - } - - return responseDescriptorStatuses -} - -func (this *rateLimitMemcacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, isOverLimitWithLocalCache []bool, - limits []*config.RateLimit, hitsAddend uint64) { - defer this.waitGroup.Done() - for i, cacheKey := range cacheKeys { - if cacheKey.Key == "" || isOverLimitWithLocalCache[i] { - continue - } - - _, err := this.client.Increment(cacheKey.Key, hitsAddend) - if err == memcache.ErrCacheMiss { - expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) - if this.expirationJitterMaxSeconds > 0 { - expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) - } - - // Need to add instead of increment. 
- err = this.client.Add(&memcache.Item{ - Key: cacheKey.Key, - Value: []byte(strconv.FormatUint(hitsAddend, 10)), - Expiration: int32(expirationSeconds), - }) - if err == memcache.ErrNotStored { - // There was a race condition to do this add. We should be able to increment - // now instead. - _, err := this.client.Increment(cacheKey.Key, hitsAddend) - if err != nil { - logger.Errorf("Failed to increment key %s after failing to add: %s", cacheKey.Key, err) - continue - } - } else if err != nil { - logger.Errorf("Failed to add key %s: %s", cacheKey.Key, err) - continue - } - } else if err != nil { - logger.Errorf("Failed to increment key %s: %s", cacheKey.Key, err) - continue - } +func NewRateLimitCacheImplFromSettings(s settings.Settings, timeSource utils.TimeSource, jitterRand *rand.Rand, + localCache *freecache.Cache, scope stats.Scope) (limiter.RateLimitCache, error) { + if s.RateLimitAlgorithm == settings.FixedRateLimit { + return NewFixedRateLimitCacheImpl( + CollectStats(memcache.New(s.MemcacheHostPort), scope.Scope("memcache")), + timeSource, + jitterRand, + s.ExpirationJitterMaxSeconds, + localCache, + s.NearLimitRatio, + s.CacheKeyPrefix), nil } -} - -func (this *rateLimitMemcacheImpl) Flush() { - this.waitGroup.Wait() -} - -func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRand *rand.Rand, - expirationJitterMaxSeconds int64, localCache *freecache.Cache, scope stats.Scope, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { - return &rateLimitMemcacheImpl{ - client: client, - timeSource: timeSource, - jitterRand: jitterRand, - expirationJitterMaxSeconds: expirationJitterMaxSeconds, - localCache: localCache, - nearLimitRatio: nearLimitRatio, - baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio, cacheKeyPrefix), + if s.RateLimitAlgorithm == settings.WindowedRateLimit { + + return NewWindowedRateLimitCacheImpl( + 
CollectStats(memcache.New(s.MemcacheHostPort), scope.Scope("memcache")), + timeSource, + jitterRand, + s.ExpirationJitterMaxSeconds, + localCache, + s.NearLimitRatio, + s.CacheKeyPrefix), nil } -} - -func NewRateLimitCacheImplFromSettings(s settings.Settings, timeSource utils.TimeSource, jitterRand *rand.Rand, - localCache *freecache.Cache, scope stats.Scope) limiter.RateLimitCache { - return NewRateLimitCacheImpl( - CollectStats(memcache.New(s.MemcacheHostPort), scope.Scope("memcache")), - timeSource, - jitterRand, - s.ExpirationJitterMaxSeconds, - localCache, - scope, - s.NearLimitRatio, - s.CacheKeyPrefix, - ) + return nil, fmt.Errorf("Unknown rate limit algorithm. %s\n", s.RateLimitAlgorithm) } diff --git a/src/memcached/client.go b/src/memcached/driver/client.go similarity index 87% rename from src/memcached/client.go rename to src/memcached/driver/client.go index 55c0ec318..7cbcbc1cb 100644 --- a/src/memcached/client.go +++ b/src/memcached/driver/client.go @@ -1,4 +1,4 @@ -package memcached +package driver import ( "github.com/bradfitz/gomemcache/memcache" @@ -11,4 +11,5 @@ type Client interface { GetMulti(keys []string) (map[string]*memcache.Item, error) Increment(key string, delta uint64) (newValue uint64, err error) Add(item *memcache.Item) error + Set(item *memcache.Item) error } diff --git a/src/memcached/fixed_cache_impl.go b/src/memcached/fixed_cache_impl.go new file mode 100644 index 000000000..ff2cb50cb --- /dev/null +++ b/src/memcached/fixed_cache_impl.go @@ -0,0 +1,166 @@ +package memcached + +import ( + "context" + "math/rand" + "strconv" + "sync" + + "github.com/bradfitz/gomemcache/memcache" + "github.com/coocood/freecache" + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/envoyproxy/ratelimit/src/algorithm" + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/src/memcached/driver" + "github.com/envoyproxy/ratelimit/src/utils" + 
logger "github.com/sirupsen/logrus" +) + +type fixedRateLimitCacheImpl struct { + client driver.Client + timeSource utils.TimeSource + jitterRand *rand.Rand + expirationJitterMaxSeconds int64 + localCache *freecache.Cache + waitGroup sync.WaitGroup + nearLimitRatio float32 + algorithm *algorithm.WindowImpl +} + +var _ limiter.RateLimitCache = (*fixedRateLimitCacheImpl)(nil) + +func (this *fixedRateLimitCacheImpl) DoLimit( + ctx context.Context, + request *pb.RateLimitRequest, + limits []*config.RateLimit) []*pb.RateLimitResponse_DescriptorStatus { + + logger.Debugf("starting cache lookup") + + // request.HitsAddend could be 0 (default value) if not specified by the caller in the Ratelimit request. + hitsAddend := utils.MaxInt64(1, int64(request.HitsAddend)) + + // First build a list of all cache keys that we are actually going to hit. + cacheKeys := this.algorithm.GenerateCacheKeys(request, limits, hitsAddend, this.timeSource.UnixNow()) + + isOverLimitWithLocalCache := make([]bool, len(request.Descriptors)) + keysToGet := make([]string, 0, len(request.Descriptors)) + + for i, cacheKey := range cacheKeys { + if cacheKey.Key == "" { + continue + } + + // Check if key is over the limit in local cache. + if this.algorithm.IsOverLimitWithLocalCache(cacheKey.Key) { + isOverLimitWithLocalCache[i] = true + logger.Debugf("cache key is over the limit: %s", cacheKey.Key) + continue + } + + logger.Debugf("looking up cache key: %s", cacheKey.Key) + keysToGet = append(keysToGet, cacheKey.Key) + } + + // Now fetch from memcache. 
+ responseDescriptorStatuses := make([]*pb.RateLimitResponse_DescriptorStatus, + len(request.Descriptors)) + + var memcacheValues map[string]*memcache.Item + var err error + + if len(keysToGet) > 0 { + memcacheValues, err = this.client.GetMulti(keysToGet) + if err != nil { + logger.Errorf("Error multi-getting memcache keys (%s): %s", keysToGet, err) + } + } + + for i, cacheKey := range cacheKeys { + rawMemcacheValue, ok := memcacheValues[cacheKey.Key] + var result int64 + if ok { + decoded, err := strconv.ParseInt(string(rawMemcacheValue.Value), 10, 32) + if err != nil { + logger.Errorf("Unexpected non-numeric value in memcached: %v", rawMemcacheValue) + } else { + result = decoded + } + } + + resultAfterIncrease := result + hitsAddend + responseDescriptorStatuses[i] = this.algorithm.GetResponseDescriptorStatus(cacheKey.Key, limits[i], resultAfterIncrease, isOverLimitWithLocalCache[i], int64(hitsAddend)) + } + + this.waitGroup.Add(1) + go this.increaseAsync(cacheKeys, isOverLimitWithLocalCache, limits, uint64(hitsAddend)) + + if AutoFlushForIntegrationTests { + this.Flush() + } + + return responseDescriptorStatuses +} + +func (this *fixedRateLimitCacheImpl) increaseAsync(cacheKeys []utils.CacheKey, isOverLimitWithLocalCache []bool, + limits []*config.RateLimit, hitsAddend uint64) { + defer this.waitGroup.Done() + for i, cacheKey := range cacheKeys { + if cacheKey.Key == "" || isOverLimitWithLocalCache[i] { + continue + } + + _, err := this.client.Increment(cacheKey.Key, hitsAddend) + if err == memcache.ErrCacheMiss { + expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) + if this.expirationJitterMaxSeconds > 0 { + expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) + } + + // Need to add instead of increment. 
+ err = this.client.Add(&memcache.Item{ + Key: cacheKey.Key, + Value: []byte(strconv.FormatUint(hitsAddend, 10)), + Expiration: int32(expirationSeconds), + }) + + if err == memcache.ErrNotStored { + // There was a race condition to do this add. We should be able to increment + // now instead. + _, err := this.client.Increment(cacheKey.Key, hitsAddend) + if err != nil { + logger.Errorf("Failed to increment key %s after failing to add: %s", cacheKey.Key, err) + continue + } + } else if err != nil { + logger.Errorf("Failed to add key %s: %s", cacheKey.Key, err) + continue + } + } else if err != nil { + logger.Errorf("Failed to increment key %s: %s", cacheKey.Key, err) + continue + } + } +} + +func (this *fixedRateLimitCacheImpl) Flush() { + this.waitGroup.Wait() +} + +func NewFixedRateLimitCacheImpl(client driver.Client, timeSource utils.TimeSource, jitterRand *rand.Rand, + expirationJitterMaxSeconds int64, localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { + return &fixedRateLimitCacheImpl{ + client: client, + timeSource: timeSource, + jitterRand: jitterRand, + expirationJitterMaxSeconds: expirationJitterMaxSeconds, + localCache: localCache, + nearLimitRatio: nearLimitRatio, + algorithm: algorithm.NewWindow( + algorithm.NewFixedWindowAlgorithm(timeSource, localCache, nearLimitRatio, cacheKeyPrefix), + cacheKeyPrefix, + localCache, + timeSource, + ), + } +} diff --git a/src/memcached/stats_collecting_client.go b/src/memcached/stats_collecting_client.go index 12b67bad5..b2f1408b2 100644 --- a/src/memcached/stats_collecting_client.go +++ b/src/memcached/stats_collecting_client.go @@ -2,12 +2,12 @@ package memcached import ( "github.com/bradfitz/gomemcache/memcache" + "github.com/envoyproxy/ratelimit/src/memcached/driver" stats "github.com/lyft/gostats" ) type statsCollectingClient struct { - c Client - + c driver.Client multiGetSuccess stats.Counter multiGetError stats.Counter incrementSuccess stats.Counter @@ -16,11 
+16,13 @@ type statsCollectingClient struct { addSuccess stats.Counter addError stats.Counter addNotStored stats.Counter + setSuccess stats.Counter + setError stats.Counter keysRequested stats.Counter keysFound stats.Counter } -func CollectStats(c Client, scope stats.Scope) Client { +func CollectStats(c driver.Client, scope stats.Scope) driver.Client { return statsCollectingClient{ c: c, multiGetSuccess: scope.NewCounterWithTags("multiget", map[string]string{"code": "success"}), @@ -31,6 +33,8 @@ func CollectStats(c Client, scope stats.Scope) Client { addSuccess: scope.NewCounterWithTags("add", map[string]string{"code": "success"}), addError: scope.NewCounterWithTags("add", map[string]string{"code": "error"}), addNotStored: scope.NewCounterWithTags("add", map[string]string{"code": "not_stored"}), + setSuccess: scope.NewCounterWithTags("set", map[string]string{"code": "success"}), + setError: scope.NewCounterWithTags("set", map[string]string{"code": "error"}), keysRequested: scope.NewCounter("keys_requested"), keysFound: scope.NewCounter("keys_found"), } @@ -78,3 +82,16 @@ func (scc statsCollectingClient) Add(item *memcache.Item) error { return err } + +func (scc statsCollectingClient) Set(item *memcache.Item) error { + err := scc.c.Set(item) + + switch err { + case nil: + scc.setSuccess.Inc() + default: + scc.setError.Inc() + } + + return err +} diff --git a/src/memcached/windowed_cache_impl.go b/src/memcached/windowed_cache_impl.go new file mode 100644 index 000000000..cbc20088c --- /dev/null +++ b/src/memcached/windowed_cache_impl.go @@ -0,0 +1,160 @@ +package memcached + +import ( + "context" + "math/rand" + "strconv" + "sync" + + "github.com/bradfitz/gomemcache/memcache" + "github.com/coocood/freecache" + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/envoyproxy/ratelimit/src/algorithm" + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/limiter" + 
"github.com/envoyproxy/ratelimit/src/memcached/driver" + "github.com/envoyproxy/ratelimit/src/utils" + logger "github.com/sirupsen/logrus" +) + +type windowedRateLimitCacheImpl struct { + client driver.Client + timeSource utils.TimeSource + jitterRand *rand.Rand + expirationJitterMaxSeconds int64 + localCache *freecache.Cache + waitGroup sync.WaitGroup + nearLimitRatio float32 + algorithm *algorithm.WindowImpl +} + +var _ limiter.RateLimitCache = (*windowedRateLimitCacheImpl)(nil) + +const DummyCacheKeyTime = 0 + +func (this *windowedRateLimitCacheImpl) DoLimit( + ctx context.Context, + request *pb.RateLimitRequest, + limits []*config.RateLimit) []*pb.RateLimitResponse_DescriptorStatus { + + logger.Debugf("starting cache lookup") + + // request.HitsAddend could be 0 (default value) if not specified by the caller in the Ratelimit request. + hitsAddend := utils.MaxInt64(1, int64(request.HitsAddend)) + + // First build a list of all cache keys that we are actually going to hit. + cacheKeys := this.algorithm.GenerateCacheKeys(request, limits, hitsAddend, DummyCacheKeyTime) + + isOverLimitWithLocalCache := make([]bool, len(request.Descriptors)) + keysToGet := make([]string, 0, len(request.Descriptors)) + + for i, cacheKey := range cacheKeys { + if cacheKey.Key == "" { + continue + } + + // Check if key is over the limit in local cache. + if this.algorithm.IsOverLimitWithLocalCache(cacheKey.Key) { + isOverLimitWithLocalCache[i] = true + logger.Debugf("cache key is over the limit: %s", cacheKey.Key) + continue + } + + logger.Debugf("looking up cache key: %s", cacheKey.Key) + keysToGet = append(keysToGet, cacheKey.Key) + } + + // Now fetch from memcached. 
+ responseDescriptorStatuses := make([]*pb.RateLimitResponse_DescriptorStatus, + len(request.Descriptors)) + + var memcacheValues map[string]*memcache.Item + var err error + + if len(keysToGet) > 0 { + memcacheValues, err = this.client.GetMulti(keysToGet) + if err != nil { + logger.Errorf("Error multi-getting memcache keys (%s): %s", keysToGet, err) + } + } + + newTats := make([]int64, len(cacheKeys)) + isOverLimit := make([]bool, len(cacheKeys)) + expirationSeconds := make([]int64, len(cacheKeys)) + + for i, cacheKey := range cacheKeys { + rawMemcacheValue, ok := memcacheValues[cacheKey.Key] + var tat int64 + if ok { + tat, err = strconv.ParseInt(string(rawMemcacheValue.Value), 10, 64) + if err != nil { + logger.Errorf("Unexpected non-numeric value in memcached: %v", rawMemcacheValue) + } + } + + responseDescriptorStatuses[i] = this.algorithm.GetResponseDescriptorStatus(cacheKey.Key, limits[i], tat, isOverLimitWithLocalCache[i], int64(hitsAddend)) + + if responseDescriptorStatuses[i].Code == pb.RateLimitResponse_OVER_LIMIT { + isOverLimit[i] = true + } else { + isOverLimit[i] = false + } + + newTats[i] = this.algorithm.GetResultsAfterIncrease() + expirationSeconds[i] = this.algorithm.GetExpirationSeconds() + if this.expirationJitterMaxSeconds > 0 { + expirationSeconds[i] += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) + } + } + + this.waitGroup.Add(1) + go this.increaseAsync(isOverLimitWithLocalCache, isOverLimit, cacheKeys, expirationSeconds, newTats) + + if AutoFlushForIntegrationTests { + this.Flush() + } + + return responseDescriptorStatuses +} + +func (this *windowedRateLimitCacheImpl) increaseAsync(isOverLimitWithLocalCache []bool, isOverLimit []bool, cacheKeys []utils.CacheKey, expirationSeconds []int64, newTats []int64) { + defer this.waitGroup.Done() + for i, cacheKey := range cacheKeys { + if cacheKey.Key == "" || isOverLimitWithLocalCache[i] { + continue + } + + err := this.client.Set(&memcache.Item{ + Key: cacheKey.Key, + Value: 
[]byte(strconv.FormatInt(newTats[i], 10)), + Expiration: int32(expirationSeconds[i]), + }) + + if err != nil { + logger.Errorf("Failed to set key %s: %s", cacheKey.Key, err) + continue + } + } +} + +func (this *windowedRateLimitCacheImpl) Flush() { + this.waitGroup.Wait() +} + +func NewWindowedRateLimitCacheImpl(client driver.Client, timeSource utils.TimeSource, jitterRand *rand.Rand, + expirationJitterMaxSeconds int64, localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { + return &windowedRateLimitCacheImpl{ + client: client, + timeSource: timeSource, + jitterRand: jitterRand, + expirationJitterMaxSeconds: expirationJitterMaxSeconds, + localCache: localCache, + nearLimitRatio: nearLimitRatio, + algorithm: algorithm.NewWindow( + algorithm.NewRollingWindowAlgorithm(timeSource, localCache, nearLimitRatio, cacheKeyPrefix), + cacheKeyPrefix, + localCache, + timeSource, + ), + } +} diff --git a/src/redis/cache_impl.go b/src/redis/cache_impl.go index 7e619b66a..09aba59a6 100644 --- a/src/redis/cache_impl.go +++ b/src/redis/cache_impl.go @@ -1,33 +1,49 @@ package redis import ( + "fmt" "math/rand" "github.com/coocood/freecache" + "github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/src/redis/driver" "github.com/envoyproxy/ratelimit/src/server" "github.com/envoyproxy/ratelimit/src/settings" "github.com/envoyproxy/ratelimit/src/utils" ) -func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64) limiter.RateLimitCache { - var perSecondPool Client +func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64) (limiter.RateLimitCache, error) { + var perSecondPool driver.Client if s.RedisPerSecond { - perSecondPool = 
NewClientImpl(srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondTls, s.RedisPerSecondAuth, + perSecondPool = driver.NewClientImpl(srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondTls, s.RedisPerSecondAuth, s.RedisPerSecondType, s.RedisPerSecondUrl, s.RedisPerSecondPoolSize, s.RedisPipelineWindow, s.RedisPipelineLimit) } - var otherPool Client - otherPool = NewClientImpl(srv.Scope().Scope("redis_pool"), s.RedisTls, s.RedisAuth, s.RedisType, s.RedisUrl, s.RedisPoolSize, + var otherPool driver.Client + otherPool = driver.NewClientImpl(srv.Scope().Scope("redis_pool"), s.RedisTls, s.RedisAuth, s.RedisType, s.RedisUrl, s.RedisPoolSize, s.RedisPipelineWindow, s.RedisPipelineLimit) - return NewFixedRateLimitCacheImpl( - otherPool, - perSecondPool, - timeSource, - jitterRand, - expirationJitterMaxSeconds, - localCache, - s.NearLimitRatio, - s.CacheKeyPrefix, - ) + if s.RateLimitAlgorithm == settings.FixedRateLimit { + return NewFixedRateLimitCacheImpl( + otherPool, + perSecondPool, + timeSource, + jitterRand, + expirationJitterMaxSeconds, + localCache, + s.NearLimitRatio, + s.CacheKeyPrefix), nil + } + if s.RateLimitAlgorithm == settings.WindowedRateLimit { + return NewWindowedRateLimitCacheImpl( + otherPool, + perSecondPool, + timeSource, + jitterRand, + expirationJitterMaxSeconds, + localCache, + s.NearLimitRatio, + s.CacheKeyPrefix), nil + } + return nil, fmt.Errorf("Unknown rate limit algorithm. 
%s\n", s.RateLimitAlgorithm) } diff --git a/src/redis/driver.go b/src/redis/driver/driver.go similarity index 99% rename from src/redis/driver.go rename to src/redis/driver/driver.go index 7ffc0c7b7..70bb8bfdd 100644 --- a/src/redis/driver.go +++ b/src/redis/driver/driver.go @@ -1,4 +1,4 @@ -package redis +package driver import "github.com/mediocregopher/radix/v3" diff --git a/src/redis/driver_impl.go b/src/redis/driver/driver_impl.go similarity index 95% rename from src/redis/driver_impl.go rename to src/redis/driver/driver_impl.go index 18e213f1b..cf563207c 100644 --- a/src/redis/driver_impl.go +++ b/src/redis/driver/driver_impl.go @@ -1,4 +1,4 @@ -package redis +package driver import ( "crypto/tls" @@ -46,7 +46,7 @@ type clientImpl struct { implicitPipelining bool } -func checkError(err error) { +func CheckError(err error) { if err != nil { panic(RedisError(err.Error())) } @@ -114,13 +114,13 @@ func NewClientImpl(scope stats.Scope, useTls bool, auth string, redisType string panic(RedisError("Unrecognized redis type " + redisType)) } - checkError(err) + CheckError(err) // Check if connection is good var pingResponse string - checkError(client.Do(radix.Cmd(&pingResponse, "PING"))) + CheckError(client.Do(radix.Cmd(&pingResponse, "PING"))) if pingResponse != "PONG" { - checkError(fmt.Errorf("connecting redis error: %s", pingResponse)) + CheckError(fmt.Errorf("connecting redis error: %s", pingResponse)) } return &clientImpl{ diff --git a/src/redis/fixed_cache_impl.go b/src/redis/fixed_cache_impl.go index b2b3d3d24..d09e8940f 100644 --- a/src/redis/fixed_cache_impl.go +++ b/src/redis/fixed_cache_impl.go @@ -5,26 +5,28 @@ import ( "github.com/coocood/freecache" pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/envoyproxy/ratelimit/src/algorithm" "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/src/redis/driver" "github.com/envoyproxy/ratelimit/src/utils" 
logger "github.com/sirupsen/logrus" "golang.org/x/net/context" ) type fixedRateLimitCacheImpl struct { - client Client + client driver.Client // Optional Client for a dedicated cache of per second limits. // If this client is nil, then the Cache will use the client for all // limits regardless of unit. If this client is not nil, then it // is used for limits that have a SECOND unit. - perSecondClient Client - baseRateLimiter *limiter.BaseRateLimiter -} - -func pipelineAppend(client Client, pipeline *Pipeline, key string, hitsAddend uint32, result *uint32, expirationSeconds int64) { - *pipeline = client.PipeAppend(*pipeline, result, "INCRBY", key, hitsAddend) - *pipeline = client.PipeAppend(*pipeline, nil, "EXPIRE", key, expirationSeconds) + perSecondClient driver.Client + timeSource utils.TimeSource + jitterRand *rand.Rand + expirationJitterMaxSeconds int64 + localCache *freecache.Cache + nearLimitRatio float32 + algorithm *algorithm.WindowImpl } func (this *fixedRateLimitCacheImpl) DoLimit( @@ -34,15 +36,15 @@ func (this *fixedRateLimitCacheImpl) DoLimit( logger.Debugf("starting cache lookup") - // request.HitsAddend could be 0 (default value) if not specified by the caller in the RateLimit request. - hitsAddend := utils.Max(1, request.HitsAddend) + // request.HitsAddend could be 0 (default value) if not specified by the caller in the Ratelimit request. + hitsAddend := utils.MaxInt64(1, int64(request.HitsAddend)) // First build a list of all cache keys that we are actually going to hit. 
- cacheKeys := this.baseRateLimiter.GenerateCacheKeys(request, limits, hitsAddend) + cacheKeys := this.algorithm.GenerateCacheKeys(request, limits, hitsAddend, this.timeSource.UnixNow()) isOverLimitWithLocalCache := make([]bool, len(request.Descriptors)) - results := make([]uint32, len(request.Descriptors)) - var pipeline, perSecondPipeline Pipeline + results := make([]int64, len(request.Descriptors)) + var pipeline, perSecondPipeline driver.Pipeline // Now, actually setup the pipeline, skipping empty cache keys. for i, cacheKey := range cacheKeys { @@ -51,66 +53,70 @@ func (this *fixedRateLimitCacheImpl) DoLimit( } // Check if key is over the limit in local cache. - if this.baseRateLimiter.IsOverLimitWithLocalCache(cacheKey.Key) { + if this.algorithm.IsOverLimitWithLocalCache(cacheKey.Key) { isOverLimitWithLocalCache[i] = true logger.Debugf("cache key is over the limit: %s", cacheKey.Key) continue } - logger.Debugf("looking up cache key: %s", cacheKey.Key) - expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) - if this.baseRateLimiter.ExpirationJitterMaxSeconds > 0 { - expirationSeconds += this.baseRateLimiter.JitterRand.Int63n(this.baseRateLimiter.ExpirationJitterMaxSeconds) + if this.expirationJitterMaxSeconds > 0 { + expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) } // Use the perSecondConn if it is not nil and the cacheKey represents a per second Limit. 
if this.perSecondClient != nil && cacheKey.PerSecond { if perSecondPipeline == nil { - perSecondPipeline = Pipeline{} + perSecondPipeline = driver.Pipeline{} } - pipelineAppend(this.perSecondClient, &perSecondPipeline, cacheKey.Key, hitsAddend, &results[i], expirationSeconds) + fixedPipelineAppend(this.perSecondClient, &perSecondPipeline, cacheKey.Key, hitsAddend, &results[i], expirationSeconds) } else { if pipeline == nil { - pipeline = Pipeline{} + pipeline = driver.Pipeline{} } - pipelineAppend(this.client, &pipeline, cacheKey.Key, hitsAddend, &results[i], expirationSeconds) + fixedPipelineAppend(this.client, &pipeline, cacheKey.Key, hitsAddend, &results[i], expirationSeconds) } } if pipeline != nil { - checkError(this.client.PipeDo(pipeline)) + driver.CheckError(this.client.PipeDo(pipeline)) } if perSecondPipeline != nil { - checkError(this.perSecondClient.PipeDo(perSecondPipeline)) + driver.CheckError(this.perSecondClient.PipeDo(perSecondPipeline)) } // Now fetch the pipeline. responseDescriptorStatuses := make([]*pb.RateLimitResponse_DescriptorStatus, len(request.Descriptors)) - for i, cacheKey := range cacheKeys { - - limitAfterIncrease := results[i] - limitBeforeIncrease := limitAfterIncrease - hitsAddend - - limitInfo := limiter.NewRateLimitInfo(limits[i], limitBeforeIncrease, limitAfterIncrease, 0, 0) - - responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key, - limitInfo, isOverLimitWithLocalCache[i], hitsAddend) + for i, cacheKey := range cacheKeys { + responseDescriptorStatuses[i] = this.algorithm.GetResponseDescriptorStatus(cacheKey.Key, limits[i], results[i], isOverLimitWithLocalCache[i], int64(hitsAddend)) } return responseDescriptorStatuses } -// Flush() is a no-op with redis since quota reads and updates happen synchronously. 
func (this *fixedRateLimitCacheImpl) Flush() {} -func NewFixedRateLimitCacheImpl(client Client, perSecondClient Client, timeSource utils.TimeSource, - jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { +func fixedPipelineAppend(client driver.Client, pipeline *driver.Pipeline, key string, hitsAddend int64, result *int64, expirationSeconds int64) { + *pipeline = client.PipeAppend(*pipeline, result, "INCRBY", key, hitsAddend) + *pipeline = client.PipeAppend(*pipeline, nil, "EXPIRE", key, expirationSeconds) +} + +func NewFixedRateLimitCacheImpl(client driver.Client, perSecondClient driver.Client, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { return &fixedRateLimitCacheImpl{ - client: client, - perSecondClient: perSecondClient, - baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio, cacheKeyPrefix), + client: client, + perSecondClient: perSecondClient, + timeSource: timeSource, + jitterRand: jitterRand, + expirationJitterMaxSeconds: expirationJitterMaxSeconds, + localCache: localCache, + nearLimitRatio: nearLimitRatio, + algorithm: algorithm.NewWindow( + algorithm.NewFixedWindowAlgorithm(timeSource, localCache, nearLimitRatio, cacheKeyPrefix), + cacheKeyPrefix, + localCache, + timeSource, + ), } } diff --git a/src/redis/windowed_cache_impl.go b/src/redis/windowed_cache_impl.go new file mode 100644 index 000000000..1f9556869 --- /dev/null +++ b/src/redis/windowed_cache_impl.go @@ -0,0 +1,168 @@ +package redis + +import ( + "math/rand" + + "github.com/coocood/freecache" + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/envoyproxy/ratelimit/src/algorithm" + "github.com/envoyproxy/ratelimit/src/config" + 
"github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/src/redis/driver" + "github.com/envoyproxy/ratelimit/src/utils" + logger "github.com/sirupsen/logrus" + "golang.org/x/net/context" +) + +// This rolling window limit is implemented using the Generic Cell Rate Algorithm (GCRA). +// GCRA tracks the remaining limit through a timestamp called the "theoretical arrival time" (TAT). +// The cost of a request is expressed as a multiple of the "emission interval", i.e. the interval between requests if they were spread evenly across the window. +// TAT is seeded with the current request's arrival time if not already set; the request's cost is then added to it. +// Subtracting the window duration from TAT gives the earliest time at which a request may be allowed. +// A request is allowed as long as that time is in the past. +// The updated TAT is stored so it can seed the next request. +// https://blog.ian.stapletoncordas.co/2018/12/understanding-generic-cell-rate-limiting.html + +type windowedRateLimitCacheImpl struct { + client driver.Client + // Optional Client for a dedicated cache of per second limits. + // If this client is nil, then the Cache will use the client for all + // limits regardless of unit. If this client is not nil, then it + // is used for limits that have a SECOND unit. + perSecondClient driver.Client + timeSource utils.TimeSource + jitterRand *rand.Rand + expirationJitterMaxSeconds int64 + localCache *freecache.Cache + nearLimitRatio float32 + algorithm *algorithm.WindowImpl +} + +const DummyCacheKeyTime = 0 + +func (this *windowedRateLimitCacheImpl) DoLimit( + ctx context.Context, + request *pb.RateLimitRequest, + limits []*config.RateLimit) []*pb.RateLimitResponse_DescriptorStatus { + + logger.Debugf("starting windowed cache lookup") + + // request.HitsAddend could be 0 (default value) if not specified by the caller in the Ratelimit request. + hitsAddend := utils.MaxInt64(1, int64(request.HitsAddend)) + + // First build a list of all cache keys that we are actually going to hit. 
+ cacheKeys := this.algorithm.GenerateCacheKeys(request, limits, hitsAddend, DummyCacheKeyTime) + + isOverLimitWithLocalCache := make([]bool, len(request.Descriptors)) + tats := make([]int64, len(request.Descriptors)) + var pipeline, perSecondPipeline driver.Pipeline + + // Now, actually setup the pipeline, skipping empty cache keys. + for i, cacheKey := range cacheKeys { + if cacheKey.Key == "" { + continue + } + + // Check if key is over the limit in local cache. + if this.algorithm.IsOverLimitWithLocalCache(cacheKey.Key) { + isOverLimitWithLocalCache[i] = true + logger.Debugf("cache key is over the limit: %s", cacheKey.Key) + continue + } + + logger.Debugf("looking up tat for cache key: %s", cacheKey.Key) + + expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) + + // Use the perSecondConn if it is not nil and the cacheKey represents a per second Limit. + if this.perSecondClient != nil && cacheKey.PerSecond { + if perSecondPipeline == nil { + perSecondPipeline = driver.Pipeline{} + } + windowedPipelineAppend(this.perSecondClient, &perSecondPipeline, cacheKey.Key, &tats[i], expirationSeconds) + } else { + if pipeline == nil { + pipeline = driver.Pipeline{} + } + windowedPipelineAppend(this.client, &pipeline, cacheKey.Key, &tats[i], expirationSeconds) + } + } + + if pipeline != nil { + driver.CheckError(this.client.PipeDo(pipeline)) + pipeline = nil + } + if perSecondPipeline != nil { + driver.CheckError(this.perSecondClient.PipeDo(perSecondPipeline)) + perSecondPipeline = nil + } + + responseDescriptorStatuses := make([]*pb.RateLimitResponse_DescriptorStatus, len(request.Descriptors)) + + for i, cacheKey := range cacheKeys { + responseDescriptorStatuses[i] = this.algorithm.GetResponseDescriptorStatus(cacheKey.Key, limits[i], int64(tats[i]), isOverLimitWithLocalCache[i], int64(hitsAddend)) + + if cacheKey.Key == "" || isOverLimitWithLocalCache[i] { + continue + } + + // Store new tat for initial tat of next requests + newTat := 
this.algorithm.GetResultsAfterIncrease() + expirationSeconds := this.algorithm.GetExpirationSeconds() + + if this.expirationJitterMaxSeconds > 0 { + expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) + } + if this.perSecondClient != nil && cacheKey.PerSecond { + if perSecondPipeline == nil { + perSecondPipeline = driver.Pipeline{} + } + windowedSetNewTatPipelineAppend(this.perSecondClient, &perSecondPipeline, cacheKey.Key, newTat, expirationSeconds) + } else { + if pipeline == nil { + pipeline = driver.Pipeline{} + } + windowedSetNewTatPipelineAppend(this.client, &pipeline, cacheKey.Key, newTat, expirationSeconds) + } + } + + if pipeline != nil { + driver.CheckError(this.client.PipeDo(pipeline)) + } + if perSecondPipeline != nil { + driver.CheckError(this.perSecondClient.PipeDo(perSecondPipeline)) + } + + return responseDescriptorStatuses +} + +func (this *windowedRateLimitCacheImpl) Flush() {} + +func windowedPipelineAppend(client driver.Client, pipeline *driver.Pipeline, key string, result *int64, expirationSeconds int64) { + *pipeline = client.PipeAppend(*pipeline, nil, "SETNX", key, int64(0)) + *pipeline = client.PipeAppend(*pipeline, nil, "EXPIRE", key, expirationSeconds) + *pipeline = client.PipeAppend(*pipeline, result, "GET", key) +} + +func windowedSetNewTatPipelineAppend(client driver.Client, pipeline *driver.Pipeline, key string, newTat int64, expirationSeconds int64) { + *pipeline = client.PipeAppend(*pipeline, nil, "SET", key, newTat) + *pipeline = client.PipeAppend(*pipeline, nil, "EXPIRE", key, expirationSeconds) +} + +func NewWindowedRateLimitCacheImpl(client driver.Client, perSecondClient driver.Client, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { + return &windowedRateLimitCacheImpl{ + client: client, + perSecondClient: perSecondClient, + timeSource: timeSource, + jitterRand: 
jitterRand, + expirationJitterMaxSeconds: expirationJitterMaxSeconds, + localCache: localCache, + nearLimitRatio: nearLimitRatio, + algorithm: algorithm.NewWindow( + algorithm.NewRollingWindowAlgorithm(timeSource, localCache, nearLimitRatio, cacheKeyPrefix), + cacheKeyPrefix, + localCache, + timeSource, + ), + } +} diff --git a/src/server/server_impl.go b/src/server/server_impl.go index b60d1e329..30f79f3d9 100644 --- a/src/server/server_impl.go +++ b/src/server/server_impl.go @@ -19,8 +19,8 @@ import ( "github.com/coocood/freecache" pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/settings" + "github.com/envoyproxy/ratelimit/src/utils" "github.com/golang/protobuf/jsonpb" "github.com/gorilla/mux" reuseport "github.com/kavu/go_reuseport" @@ -190,7 +190,7 @@ func newServer(s settings.Settings, name string, store stats.Store, localCache * ret.scope = ret.store.ScopeWithTags(name, s.ExtraTags) ret.store.AddStatGenerator(stats.NewRuntimeStats(ret.scope.Scope("go"))) if localCache != nil { - ret.store.AddStatGenerator(limiter.NewLocalCacheStats(localCache, ret.scope.Scope("localcache"))) + ret.store.AddStatGenerator(utils.NewLocalCacheStats(localCache, ret.scope.Scope("localcache"))) } // setup runtime diff --git a/src/service/ratelimit.go b/src/service/ratelimit.go index 126bb776b..5442f8e2d 100644 --- a/src/service/ratelimit.go +++ b/src/service/ratelimit.go @@ -9,7 +9,7 @@ import ( "github.com/envoyproxy/ratelimit/src/assert" "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" - "github.com/envoyproxy/ratelimit/src/redis" + "github.com/envoyproxy/ratelimit/src/redis/driver" "github.com/lyft/goruntime/loader" stats "github.com/lyft/gostats" logger "github.com/sirupsen/logrus" @@ -168,7 +168,7 @@ func (this *service) ShouldRateLimit( logger.Debugf("caught error during call") finalResponse = nil switch t := err.(type) { - case 
redis.RedisError: + case driver.RedisError: { this.stats.shouldRateLimit.redisError.Inc() finalError = t diff --git a/src/service/ratelimit_legacy.go b/src/service/ratelimit_legacy.go index 17112675c..98e1ae269 100644 --- a/src/service/ratelimit_legacy.go +++ b/src/service/ratelimit_legacy.go @@ -5,7 +5,7 @@ import ( pb_struct "github.com/envoyproxy/go-control-plane/envoy/extensions/common/ratelimit/v3" pb_legacy "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v2" pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - "github.com/lyft/gostats" + stats "github.com/lyft/gostats" "golang.org/x/net/context" ) diff --git a/src/service_cmd/runner/runner.go b/src/service_cmd/runner/runner.go index afa1b1442..4f7ee2edf 100644 --- a/src/service_cmd/runner/runner.go +++ b/src/service_cmd/runner/runner.go @@ -47,20 +47,28 @@ func (runner *Runner) GetStatsStore() stats.Store { func createLimiter(srv server.Server, s settings.Settings, localCache *freecache.Cache) limiter.RateLimitCache { switch s.BackendType { case "redis", "": - return redis.NewRateLimiterCacheImplFromSettings( + cacheImpl, err := redis.NewRateLimiterCacheImplFromSettings( s, localCache, srv, utils.NewTimeSourceImpl(), rand.New(utils.NewLockedSource(time.Now().Unix())), s.ExpirationJitterMaxSeconds) + if err != nil { + logger.Fatalf("Could not setup redis ratelimit cache. %v\n", err) + } + return cacheImpl case "memcache": - return memcached.NewRateLimitCacheImplFromSettings( + cacheImpl, err := memcached.NewRateLimitCacheImplFromSettings( s, utils.NewTimeSourceImpl(), rand.New(utils.NewLockedSource(time.Now().Unix())), localCache, srv.Scope()) + if err != nil { + logger.Fatalf("Could not setup memcache ratelimit cache. 
%v\n", err) + } + return cacheImpl default: logger.Fatalf("Invalid setting for BackendType: %s", s.BackendType) panic("This line should not be reachable") diff --git a/src/settings/settings.go b/src/settings/settings.go index 8468076ec..2c2e6b915 100644 --- a/src/settings/settings.go +++ b/src/settings/settings.go @@ -7,6 +7,9 @@ import ( "google.golang.org/grpc" ) +const FixedRateLimit = "FIXED_WINDOW" +const WindowedRateLimit = "ROLLING_WINDOW" + type Settings struct { // runtime options GrpcUnaryInterceptor grpc.ServerOption @@ -59,6 +62,9 @@ type Settings struct { // Memcache settings MemcacheHostPort string `envconfig:"MEMCACHE_HOST_PORT" default:""` + + // Algorithm settings + RateLimitAlgorithm string `envconfig:"RATE_LIMIT_ALGORITHM" default:"FIXED_WINDOW"` } type Option func(*Settings) diff --git a/src/limiter/cache_key.go b/src/utils/cache_key_generator.go similarity index 63% rename from src/limiter/cache_key.go rename to src/utils/cache_key_generator.go index 797cc2e11..4b8fe234f 100644 --- a/src/limiter/cache_key.go +++ b/src/utils/cache_key_generator.go @@ -1,4 +1,4 @@ -package limiter +package utils import ( "bytes" @@ -7,45 +7,46 @@ import ( pb_struct "github.com/envoyproxy/go-control-plane/envoy/extensions/common/ratelimit/v3" pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/envoyproxy/ratelimit/src/assert" "github.com/envoyproxy/ratelimit/src/config" - "github.com/envoyproxy/ratelimit/src/utils" ) +type CacheKey struct { + Key string + // True if the key corresponds to a limit with a SECOND unit. False otherwise. + PerSecond bool +} + type CacheKeyGenerator struct { prefix string // bytes.Buffer pool used to efficiently generate cache keys. 
bufferPool sync.Pool } -func NewCacheKeyGenerator(prefix string) CacheKeyGenerator { - return CacheKeyGenerator{ - prefix: prefix, - bufferPool: sync.Pool{ - New: func() interface{} { - return new(bytes.Buffer) - }, - }, +func (this *CacheKeyGenerator) GenerateCacheKeys(request *pb.RateLimitRequest, + limits []*config.RateLimit, hitsAddend uint32, time int64) []CacheKey { + assert.Assert(len(request.Descriptors) == len(limits)) + cacheKeys := make([]CacheKey, len(request.Descriptors)) + for i := 0; i < len(request.Descriptors); i++ { + // generateCacheKey() returns an empty string in the key if there is no limit + // so that we can keep the arrays all the same size. + cacheKeys[i] = this.GenerateCacheKey(request.Domain, request.Descriptors[i], limits[i], time) + // Increase statistics for limits hit by their respective requests. + if limits[i] != nil { + limits[i].Stats.TotalHits.Add(uint64(hitsAddend)) + } } -} - -type CacheKey struct { - Key string - // True if the key corresponds to a limit with a SECOND unit. False otherwise. - PerSecond bool -} - -func isPerSecondLimit(unit pb.RateLimitResponse_RateLimit_Unit) bool { - return unit == pb.RateLimitResponse_RateLimit_SECOND + return cacheKeys } // Generate a cache key for a limit lookup. // @param domain supplies the cache key domain. // @param descriptor supplies the descriptor to generate the key for. // @param limit supplies the rate limit to generate the key for (may be nil). -// @param now supplies the current unix time. +// @param time supplies the current unix time. // @return CacheKey struct. 
func (this *CacheKeyGenerator) GenerateCacheKey( - domain string, descriptor *pb_struct.RateLimitDescriptor, limit *config.RateLimit, now int64) CacheKey { + domain string, descriptor *pb_struct.RateLimitDescriptor, limit *config.RateLimit, time int64) CacheKey { if limit == nil { return CacheKey{ @@ -69,10 +70,25 @@ func (this *CacheKeyGenerator) GenerateCacheKey( b.WriteByte('_') } - divider := utils.UnitToDivider(limit.Limit.Unit) - b.WriteString(strconv.FormatInt((now/divider)*divider, 10)) + divider := UnitToDivider(limit.Limit.Unit) + b.WriteString(strconv.FormatInt((time/divider)*divider, 10)) return CacheKey{ Key: b.String(), PerSecond: isPerSecondLimit(limit.Limit.Unit)} } + +func NewCacheKeyGenerator(prefix string) CacheKeyGenerator { + return CacheKeyGenerator{ + prefix: prefix, + bufferPool: sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, + }, + } +} + +func isPerSecondLimit(unit pb.RateLimitResponse_RateLimit_Unit) bool { + return unit == pb.RateLimitResponse_RateLimit_SECOND +} diff --git a/src/utils/jitter_rand_source.go b/src/utils/jitter_rand_source.go new file mode 100644 index 000000000..33eebe4d2 --- /dev/null +++ b/src/utils/jitter_rand_source.go @@ -0,0 +1,9 @@ +package utils + +// Interface for a rand Source for expiration jitter. +type JitterRandSource interface { + // @return a non-negative pseudo-random 63-bit integer as an int64. + Int63() int64 + // @param seed initializes pseudo-random generator to a deterministic state. + Seed(seed int64) +} diff --git a/src/utils/jitter_rand_source_impl.go b/src/utils/jitter_rand_source_impl.go new file mode 100644 index 000000000..39283d60e --- /dev/null +++ b/src/utils/jitter_rand_source_impl.go @@ -0,0 +1,29 @@ +package utils + +import ( + "math/rand" + "sync" +) + +// rand for jitter. 
+type lockedSource struct { + lk sync.Mutex + src rand.Source +} + +func NewLockedSource(seed int64) JitterRandSource { + return &lockedSource{src: rand.NewSource(seed)} +} + +func (r *lockedSource) Int63() (n int64) { + r.lk.Lock() + n = r.src.Int63() + r.lk.Unlock() + return +} + +func (r *lockedSource) Seed(seed int64) { + r.lk.Lock() + r.src.Seed(seed) + r.lk.Unlock() +} diff --git a/src/limiter/local_cache_stats.go b/src/utils/local_cache_stats.go similarity index 98% rename from src/limiter/local_cache_stats.go rename to src/utils/local_cache_stats.go index d0d59dc27..be3f93899 100644 --- a/src/limiter/local_cache_stats.go +++ b/src/utils/local_cache_stats.go @@ -1,4 +1,4 @@ -package limiter +package utils import ( "github.com/coocood/freecache" @@ -17,6 +17,17 @@ type localCacheStats struct { overwriteCount stats.Gauge } +func (stats localCacheStats) GenerateStats() { + stats.evacuateCount.Set(uint64(stats.cache.EvacuateCount())) + stats.expiredCount.Set(uint64(stats.cache.ExpiredCount())) + stats.entryCount.Set(uint64(stats.cache.EntryCount())) + stats.averageAccessTime.Set(uint64(stats.cache.AverageAccessTime())) + stats.hitCount.Set(uint64(stats.cache.HitCount())) + stats.missCount.Set(uint64(stats.cache.MissCount())) + stats.lookupCount.Set(uint64(stats.cache.LookupCount())) + stats.overwriteCount.Set(uint64(stats.cache.OverwriteCount())) +} + func NewLocalCacheStats(localCache *freecache.Cache, scope stats.Scope) stats.StatGenerator { return localCacheStats{ cache: localCache, @@ -30,14 +41,3 @@ func NewLocalCacheStats(localCache *freecache.Cache, scope stats.Scope) stats.St overwriteCount: scope.NewGauge("overwriteCount"), } } - -func (stats localCacheStats) GenerateStats() { - stats.evacuateCount.Set(uint64(stats.cache.EvacuateCount())) - stats.expiredCount.Set(uint64(stats.cache.ExpiredCount())) - stats.entryCount.Set(uint64(stats.cache.EntryCount())) - stats.averageAccessTime.Set(uint64(stats.cache.AverageAccessTime())) - 
stats.hitCount.Set(uint64(stats.cache.HitCount())) - stats.missCount.Set(uint64(stats.cache.MissCount())) - stats.lookupCount.Set(uint64(stats.cache.LookupCount())) - stats.overwriteCount.Set(uint64(stats.cache.OverwriteCount())) -} diff --git a/src/utils/time.go b/src/utils/time.go index e7978cc6c..d4d56b51b 100644 --- a/src/utils/time.go +++ b/src/utils/time.go @@ -1,48 +1,9 @@ package utils -import ( - "math/rand" - "sync" - "time" -) - -// Interface for a rand Source for expiration jitter. -type JitterRandSource interface { - // @return a non-negative pseudo-random 63-bit integer as an int64. - Int63() int64 - // @param seed initializes pseudo-random generator to a deterministic state. - Seed(seed int64) -} - -type timeSourceImpl struct{} - -func NewTimeSourceImpl() TimeSource { - return &timeSourceImpl{} -} - -func (this *timeSourceImpl) UnixNow() int64 { - return time.Now().Unix() -} - -// rand for jitter. -type lockedSource struct { - lk sync.Mutex - src rand.Source -} - -func NewLockedSource(seed int64) JitterRandSource { - return &lockedSource{src: rand.NewSource(seed)} -} - -func (r *lockedSource) Int63() (n int64) { - r.lk.Lock() - n = r.src.Int63() - r.lk.Unlock() - return -} - -func (r *lockedSource) Seed(seed int64) { - r.lk.Lock() - r.src.Seed(seed) - r.lk.Unlock() +// Interface for a time source. +type TimeSource interface { + // @return the current unix time in seconds. + UnixNow() int64 + // @return the current unix time in nanoseconds. 
+ UnixNanoNow() int64 } diff --git a/src/utils/time_impl.go b/src/utils/time_impl.go new file mode 100644 index 000000000..f20075256 --- /dev/null +++ b/src/utils/time_impl.go @@ -0,0 +1,19 @@ +package utils + +import ( + "time" +) + +type timeSourceImpl struct{} + +func (this *timeSourceImpl) UnixNow() int64 { + return time.Now().Unix() +} + +func (this *timeSourceImpl) UnixNanoNow() int64 { + return time.Now().UnixNano() +} + +func NewTimeSourceImpl() TimeSource { + return &timeSourceImpl{} +} diff --git a/src/utils/utilities.go b/src/utils/utilities.go index e6029f5be..edbcc9645 100644 --- a/src/utils/utilities.go +++ b/src/utils/utilities.go @@ -1,15 +1,13 @@ package utils import ( + "time" + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" "github.com/golang/protobuf/ptypes/duration" ) -// Interface for a time source. -type TimeSource interface { - // @return the current unix time in seconds. - UnixNow() int64 -} +const secondToNanosecondRate = 1e9 // Convert a rate limit into a time divider. // @param unit supplies the unit to convert. 
@@ -29,15 +27,52 @@ func UnitToDivider(unit pb.RateLimitResponse_RateLimit_Unit) int64 { panic("should not get here") } -func CalculateReset(currentLimit *pb.RateLimitResponse_RateLimit, timeSource TimeSource) *duration.Duration { - sec := UnitToDivider(currentLimit.Unit) - now := timeSource.UnixNow() - return &duration.Duration{Seconds: sec - now%sec} +func MaxInt(a int, b int) int { + if a > b { + return a + } + return b } -func Max(a uint32, b uint32) uint32 { +func MaxInt64(a int64, b int64) int64 { if a > b { return a } return b } + +func MinInt64(a int64, b int64) int64 { + if a < b { + return a + } + return b +} + +func MaxUint32(a uint32, b uint32) uint32 { + if a > b { + return a + } + return b +} + +func NanosecondsToSeconds(nanoseconds int64) int64 { + return nanoseconds / secondToNanosecondRate +} + +func NanosecondsToDuration(nanoseconds int64) *duration.Duration { + nanos := nanoseconds + secs := nanos / secondToNanosecondRate + nanos -= secs * secondToNanosecondRate + return &duration.Duration{Seconds: secs, Nanos: int32(nanos)} +} + +func SecondsToNanoseconds(second int64) int64 { + time.Now() + return second * secondToNanosecondRate +} + +func CalculateFixedReset(currentLimit *pb.RateLimitResponse_RateLimit, timeSource TimeSource) *duration.Duration { + sec := UnitToDivider(currentLimit.Unit) + now := timeSource.UnixNow() + return &duration.Duration{Seconds: sec - now%sec} +} diff --git a/test/algorithm/base_window_test.go b/test/algorithm/base_window_test.go new file mode 100644 index 000000000..dee8c3b91 --- /dev/null +++ b/test/algorithm/base_window_test.go @@ -0,0 +1,157 @@ +package algorithm + +import ( + "testing" + + "github.com/coocood/freecache" + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/envoyproxy/ratelimit/src/algorithm" + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/utils" + "github.com/envoyproxy/ratelimit/test/common" + mock_utils 
"github.com/envoyproxy/ratelimit/test/mocks/utils" + "github.com/golang/mock/gomock" + "github.com/golang/protobuf/ptypes/duration" + stats "github.com/lyft/gostats" + "github.com/stretchr/testify/assert" +) + +func TestGetResponseDescriptorStatus(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + + // Fixed Window algorithm + fixedAlgorithm := algorithm.NewFixedWindowAlgorithm(timeSource, nil, 0.8, "") + baseAlgorithm := algorithm.NewWindow(fixedAlgorithm, "", nil, timeSource) + + key := "key_value" + limit := config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore) + var results int64 = 1 + var hitsAddend int64 = 1 + isOverLimitWithLocalCache := false + + timeSource.EXPECT().UnixNow().Return(int64(1)).MaxTimes(2) + + expectedResult := &pb.RateLimitResponse_DescriptorStatus{ + Code: pb.RateLimitResponse_OK, + CurrentLimit: limit.Limit, + LimitRemaining: 9, + DurationUntilReset: utils.CalculateFixedReset(limit.Limit, timeSource)} + + actualResult := baseAlgorithm.GetResponseDescriptorStatus(key, limit, results, isOverLimitWithLocalCache, hitsAddend) + assert.Equal(expectedResult, actualResult) + + // Rolling Window algorithm + rollingAlgorithm := algorithm.NewRollingWindowAlgorithm(timeSource, nil, 0.8, "") + baseAlgorithm = algorithm.NewWindow(rollingAlgorithm, "", nil, timeSource) + + key = "key_value" + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore) + results = 0 + hitsAddend = 1 + isOverLimitWithLocalCache = false + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1) + + expectedResult = &pb.RateLimitResponse_DescriptorStatus{ + Code: pb.RateLimitResponse_OK, + CurrentLimit: limit.Limit, + LimitRemaining: 9, + DurationUntilReset: &duration.Duration{Nanos: 1e8}} + + actualResult = 
baseAlgorithm.GetResponseDescriptorStatus(key, limit, results, isOverLimitWithLocalCache, hitsAddend) + assert.Equal(expectedResult, actualResult) +} + +func TestIsOverLimitWithLocalCache(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + key := "key_value" + + timeSource := mock_utils.NewMockTimeSource(controller) + + // Fixed Window algorithm + fixedLocalCache := freecache.NewCache(100) + + fixedAlgorithm := algorithm.NewFixedWindowAlgorithm(timeSource, fixedLocalCache, 0.8, "") + baseAlgorithm := algorithm.NewWindow(fixedAlgorithm, "", fixedLocalCache, timeSource) + + assert.Equal(false, baseAlgorithm.IsOverLimitWithLocalCache(key)) + + fixedLocalCache.Set([]byte(key), []byte{}, 1) + assert.Equal(true, baseAlgorithm.IsOverLimitWithLocalCache(key)) + + // Rolling Window algorithm + rollingLocalCache := freecache.NewCache(100) + + rollingAlgorithm := algorithm.NewRollingWindowAlgorithm(timeSource, rollingLocalCache, 0.8, "") + baseAlgorithm = algorithm.NewWindow(rollingAlgorithm, "", rollingLocalCache, timeSource) + + assert.Equal(false, baseAlgorithm.IsOverLimitWithLocalCache(key)) + + rollingLocalCache.Set([]byte(key), []byte{}, 1) + assert.Equal(true, baseAlgorithm.IsOverLimitWithLocalCache(key)) +} + +func TestGenerateCacheKeys(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + + var hitsAddend int64 = 1 + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limit := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + + // Fixed Window algorithm + fixedAlgorithm := algorithm.NewFixedWindowAlgorithm(timeSource, nil, 0.8, "") + baseAlgorithm := algorithm.NewWindow(fixedAlgorithm, "", nil, timeSource) + + 
timeSource.EXPECT().UnixNow().Return(int64(1)).MaxTimes(1) + + expectedResult := []utils.CacheKey([]utils.CacheKey{{Key: "domain_key_value_1", PerSecond: true}}) + actualResult := baseAlgorithm.GenerateCacheKeys(request, limit, hitsAddend, 1) + assert.Equal(expectedResult, actualResult) + + // Rolling Window algorithm + rollingAlgorithm := algorithm.NewRollingWindowAlgorithm(timeSource, nil, 0.8, "") + baseAlgorithm = algorithm.NewWindow(rollingAlgorithm, "", nil, timeSource) + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1) + + expectedResult = []utils.CacheKey([]utils.CacheKey{{Key: "domain_key_value_0", PerSecond: true}}) + actualResult = baseAlgorithm.GenerateCacheKeys(request, limit, hitsAddend, 0) + assert.Equal(expectedResult, actualResult) +} + +func TestPopulateStats(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + statsStore := stats.NewStore(stats.NewNullSink(), false) + limit := config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore) + + algorithm.PopulateStats(limit, 1, 0, 0) + assert.Equal(uint64(1), limit.Stats.NearLimit.Value()) + assert.Equal(uint64(0), limit.Stats.OverLimit.Value()) + assert.Equal(uint64(0), limit.Stats.OverLimitWithLocalCache.Value()) + + algorithm.PopulateStats(limit, 0, 1, 0) + assert.Equal(uint64(1), limit.Stats.NearLimit.Value()) + assert.Equal(uint64(1), limit.Stats.OverLimit.Value()) + assert.Equal(uint64(0), limit.Stats.OverLimitWithLocalCache.Value()) + + algorithm.PopulateStats(limit, 0, 0, 1) + assert.Equal(uint64(1), limit.Stats.NearLimit.Value()) + assert.Equal(uint64(1), limit.Stats.OverLimit.Value()) + assert.Equal(uint64(1), limit.Stats.OverLimitWithLocalCache.Value()) +} diff --git a/test/algorithm/fixed_window_test.go b/test/algorithm/fixed_window_test.go new file mode 100644 index 000000000..d3ce07dd1 --- /dev/null +++ b/test/algorithm/fixed_window_test.go @@ -0,0 +1,109 @@ +package algorithm + 
+import ( + "testing" + + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/envoyproxy/ratelimit/src/algorithm" + "github.com/envoyproxy/ratelimit/src/config" + mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" + "github.com/golang/mock/gomock" + "github.com/golang/protobuf/ptypes/duration" + stats "github.com/lyft/gostats" + "github.com/stretchr/testify/assert" +) + +func TestFixedIsOverLimit(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + algorithm := algorithm.NewFixedWindowAlgorithm(timeSource, nil, 0.8, "") + + var result int64 = 1 + var hitsAddend int64 = 1 + limit := config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore) + + actualIsOverLimit, actualLimitRemaining, actualDurationUntilReset := algorithm.IsOverLimit(limit, result, hitsAddend) + + assert.Equal(false, actualIsOverLimit) + assert.Equal(int64(9), actualLimitRemaining) + assert.Equal(1, actualDurationUntilReset) + + result = 10 + hitsAddend = 1 + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore) + + actualIsOverLimit, actualLimitRemaining, actualDurationUntilReset = algorithm.IsOverLimit(limit, result, hitsAddend) + + assert.Equal(false, actualIsOverLimit) + assert.Equal(int64(0), actualLimitRemaining) + assert.Equal(1, actualDurationUntilReset) + + result = 11 + hitsAddend = 1 + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore) + + actualIsOverLimit, actualLimitRemaining, actualDurationUntilReset = algorithm.IsOverLimit(limit, result, hitsAddend) + + assert.Equal(true, actualIsOverLimit) + assert.Equal(int64(0), actualLimitRemaining) + assert.Equal(1, actualDurationUntilReset) +} + +func TestFixedCalculateSimpleReset(t *testing.T) { + assert := 
assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + algorithm := algorithm.NewFixedWindowAlgorithm(timeSource, nil, 0.8, "") + + timeSource.EXPECT().UnixNow().Return(int64(1)).MaxTimes(1) + limit := config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore) + + actualResetDuration := algorithm.CalculateSimpleReset(limit, timeSource) + expectedResetDuration := &duration.Duration{Seconds: 1} + assert.Equal(expectedResetDuration, actualResetDuration) + + timeSource.EXPECT().UnixNow().Return(int64(30)).MaxTimes(1) + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key_value", statsStore) + + actualResetDuration = algorithm.CalculateSimpleReset(limit, timeSource) + expectedResetDuration = &duration.Duration{Seconds: 30} + assert.Equal(expectedResetDuration, actualResetDuration) + + timeSource.EXPECT().UnixNow().Return(int64(60)).MaxTimes(1) + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, "key_value", statsStore) + + actualResetDuration = algorithm.CalculateSimpleReset(limit, timeSource) + expectedResetDuration = &duration.Duration{Seconds: 59 * 60} + assert.Equal(expectedResetDuration, actualResetDuration) +} + +func TestFixedCalculateReset(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + algorithm := algorithm.NewFixedWindowAlgorithm(timeSource, nil, 0.8, "") + + limit := config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key_value", statsStore) + + timeSource.EXPECT().UnixNow().Return(int64(45)).MaxTimes(1) + + actualResetDuration := algorithm.CalculateReset(true, limit, timeSource) + expectedResetDuration := &duration.Duration{Seconds: 15} + 
assert.Equal(expectedResetDuration, actualResetDuration) + + timeSource.EXPECT().UnixNow().Return(int64(45)).MaxTimes(1) + + actualResetDuration = algorithm.CalculateReset(false, limit, timeSource) + expectedResetDuration = &duration.Duration{Seconds: 15} + assert.Equal(expectedResetDuration, actualResetDuration) +} diff --git a/test/algorithm/rolling_window_test.go b/test/algorithm/rolling_window_test.go new file mode 100644 index 000000000..5259d1ae5 --- /dev/null +++ b/test/algorithm/rolling_window_test.go @@ -0,0 +1,262 @@ +package algorithm + +import ( + "testing" + + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/envoyproxy/ratelimit/src/algorithm" + "github.com/envoyproxy/ratelimit/src/config" + mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" + "github.com/golang/mock/gomock" + "github.com/golang/protobuf/ptypes/duration" + stats "github.com/lyft/gostats" + "github.com/stretchr/testify/assert" +) + +func TestRollingIsOverLimit(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + algorithm := algorithm.NewRollingWindowAlgorithm(timeSource, nil, 0.8, "") + + var result int64 = 1e9 + var hitsAddend int64 = 1 + limit := config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore) + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1) + actualIsOverLimit, actualLimitRemaining, actualDurationUntilReset := algorithm.IsOverLimit(limit, result, hitsAddend) + + assert.Equal(false, actualIsOverLimit) + assert.Equal(int64(9), actualLimitRemaining) + assert.Equal(0, actualDurationUntilReset) + + result = 0 + hitsAddend = 1 + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore) + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1e9) + 
actualIsOverLimit, actualLimitRemaining, actualDurationUntilReset = algorithm.IsOverLimit(limit, result, hitsAddend) + + assert.Equal(false, actualIsOverLimit) + assert.Equal(int64(9), actualLimitRemaining) + assert.Equal(0, actualDurationUntilReset) + + result = 3600e9 + hitsAddend = 1 + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, "key_value", statsStore) + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(4e9) + actualIsOverLimit, actualLimitRemaining, actualDurationUntilReset = algorithm.IsOverLimit(limit, result, hitsAddend) + + assert.Equal(true, actualIsOverLimit) + assert.Equal(int64(0), actualLimitRemaining) + assert.Equal(359, actualDurationUntilReset) +} + +func TestRollingCalculateSimpleReset(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + algorithm := algorithm.NewRollingWindowAlgorithm(timeSource, nil, 0.8, "") + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1) + limit := config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore) + + actualResetDuration := algorithm.CalculateSimpleReset(limit, timeSource) + expectedResetDuration := &duration.Duration{Seconds: 1} + assert.Equal(expectedResetDuration, actualResetDuration) + + timeSource.EXPECT().UnixNanoNow().Return(int64(30 * 1e9)).MaxTimes(1) + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key_value", statsStore) + + actualResetDuration = algorithm.CalculateSimpleReset(limit, timeSource) + expectedResetDuration = &duration.Duration{Seconds: 30} + assert.Equal(expectedResetDuration, actualResetDuration) + + timeSource.EXPECT().UnixNanoNow().Return(int64(60 * 1e9)).MaxTimes(1) + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, "key_value", statsStore) + + actualResetDuration = 
algorithm.CalculateSimpleReset(limit, timeSource) + expectedResetDuration = &duration.Duration{Seconds: 59 * 60} + assert.Equal(expectedResetDuration, actualResetDuration) +} + +func TestRollingCalculateReset(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + algorithm := algorithm.NewRollingWindowAlgorithm(timeSource, nil, 0.8, "") + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1) + + var results int64 = 0 + var hitsAddend int64 = 1 + limit := config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key_value", statsStore) + + // populating tat, newTat, arriveAt + // that is required to execute CalculateReset + + // period = 1 minute = 60 second + // limit = 10 request/minute + // emissionInterval = 6 second + // request = 1 + // increment = emissionInterval*request = 6 second + + // arriveAt = 1 second + // tat = 0 second + + // newTat should be max(arriveAt,tat)+increment = 7 second + // DurationUntilReset should be newTat-arriveAt = 6 second + + algorithm.IsOverLimit(limit, results, hitsAddend) + isOverLimit := false + + actualResetDuration := algorithm.CalculateReset(isOverLimit, limit, timeSource) + expectedResetDuration := &duration.Duration{Seconds: 6} + assert.Equal(expectedResetDuration, actualResetDuration) + + timeSource.EXPECT().UnixNanoNow().Return(int64(3 * 60 * 1e9)).MaxTimes(1) + + results = 72 * 60 * 1e9 + hitsAddend = 1 + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, "key_value", statsStore) + + // populating tat, newTat, arriveAt + // that is required to execute CalculateReset + + // period = 1 hour = 3600 second + // limit = 10 request/hour + // emissionInterval = 6 minute = 360 second + // request = 1 + // increment = emissionInterval*request = 6 minute = 360 second + + // arriveAt = 3 minute + // tat = 72 minute + + //
newTat should be max(arriveAt,tat)+increment = 78 minute (not used) + // DurationUntilReset should be tat-arriveAt = 69 minute + + algorithm.IsOverLimit(limit, results, hitsAddend) + isOverLimit = true + + actualResetDuration = algorithm.CalculateReset(isOverLimit, limit, timeSource) + expectedResetDuration = &duration.Duration{Seconds: 69 * 60} + assert.Equal(expectedResetDuration, actualResetDuration) +} + +func TestRollingGetExpirationSeconds(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + algorithm := algorithm.NewRollingWindowAlgorithm(timeSource, nil, 0.8, "") + + limit := config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore) + var results int64 = 0 + var hitsAddend int64 = 1 + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1) + algorithm.IsOverLimit(limit, results, hitsAddend) + + expectedResult := int64(1) + actualResult := algorithm.GetExpirationSeconds() + assert.Equal(expectedResult, actualResult) + + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore) + results = 2e9 + hitsAddend = 1 + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9 + 4e6)).MaxTimes(1) + algorithm.IsOverLimit(limit, results, hitsAddend) + + expectedResult = int64(1) + actualResult = algorithm.GetExpirationSeconds() + assert.Equal(expectedResult, actualResult) + + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key_value", statsStore) + results = 0 + hitsAddend = 1 + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1) + algorithm.IsOverLimit(limit, results, hitsAddend) + + expectedResult = int64(7) + actualResult = algorithm.GetExpirationSeconds() + assert.Equal(expectedResult, actualResult) + + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE,
"key_value", statsStore) + results = 60e9 + hitsAddend = 1 + + timeSource.EXPECT().UnixNanoNow().Return(int64(4e9)).MaxTimes(1) + algorithm.IsOverLimit(limit, results, hitsAddend) + + expectedResult = int64(57) + actualResult = algorithm.GetExpirationSeconds() + assert.Equal(expectedResult, actualResult) +} + +func TestRollingGetResultsAfterIncrease(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + algorithm := algorithm.NewRollingWindowAlgorithm(timeSource, nil, 0.8, "") + + limit := config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore) + var results int64 = 0 + var hitsAddend int64 = 1 + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1) + algorithm.IsOverLimit(limit, results, hitsAddend) + + expectedResult := int64(1e9 + 1e8) + actualResult := algorithm.GetResultsAfterIncrease() + assert.Equal(expectedResult, actualResult) + + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore) + results = 2e9 + hitsAddend = 1 + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9 + 4e6)).MaxTimes(1) + algorithm.IsOverLimit(limit, results, hitsAddend) + + expectedResult = int64(2e9) + actualResult = algorithm.GetResultsAfterIncrease() + assert.Equal(expectedResult, actualResult) + + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key_value", statsStore) + results = 0 + hitsAddend = 1 + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1) + algorithm.IsOverLimit(limit, results, hitsAddend) + + expectedResult = int64(7e9) + actualResult = algorithm.GetResultsAfterIncrease() + assert.Equal(expectedResult, actualResult) + + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key_value", statsStore) + results = 60e9 + hitsAddend = 1 + + 
timeSource.EXPECT().UnixNanoNow().Return(int64(4e9)).MaxTimes(1) + algorithm.IsOverLimit(limit, results, hitsAddend) + + expectedResult = int64(60e9) + actualResult = algorithm.GetResultsAfterIncrease() + assert.Equal(expectedResult, actualResult) +} diff --git a/test/limiter/base_limiter_test.go b/test/limiter/base_limiter_test.go deleted file mode 100644 index 94b63d771..000000000 --- a/test/limiter/base_limiter_test.go +++ /dev/null @@ -1,144 +0,0 @@ -package limiter - -import ( - "math/rand" - "testing" - - "github.com/coocood/freecache" - pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - "github.com/envoyproxy/ratelimit/src/config" - "github.com/envoyproxy/ratelimit/src/limiter" - "github.com/envoyproxy/ratelimit/test/common" - mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" - "github.com/golang/mock/gomock" - stats "github.com/lyft/gostats" - "github.com/stretchr/testify/assert" -) - -func TestGenerateCacheKeys(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - timeSource := mock_utils.NewMockTimeSource(controller) - jitterSource := mock_utils.NewMockJitterRandSource(controller) - statsStore := stats.NewStore(stats.NewNullSink(), false) - timeSource.EXPECT().UnixNow().Return(int64(1234)) - baseRateLimit := limiter.NewBaseRateLimit(timeSource, rand.New(jitterSource), 3600, nil, 0.8, "") - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} - assert.Equal(uint64(0), limits[0].Stats.TotalHits.Value()) - cacheKeys := baseRateLimit.GenerateCacheKeys(request, limits, 1) - assert.Equal(1, len(cacheKeys)) - assert.Equal("domain_key_value_1234", cacheKeys[0].Key) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) -} - -func TestGenerateCacheKeysPrefix(t *testing.T) { - assert := assert.New(t) - controller := 
gomock.NewController(t) - defer controller.Finish() - timeSource := mock_utils.NewMockTimeSource(controller) - jitterSource := mock_utils.NewMockJitterRandSource(controller) - statsStore := stats.NewStore(stats.NewNullSink(), false) - timeSource.EXPECT().UnixNow().Return(int64(1234)) - baseRateLimit := limiter.NewBaseRateLimit(timeSource, rand.New(jitterSource), 3600, nil, 0.8, "prefix:") - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} - assert.Equal(uint64(0), limits[0].Stats.TotalHits.Value()) - cacheKeys := baseRateLimit.GenerateCacheKeys(request, limits, 1) - assert.Equal(1, len(cacheKeys)) - assert.Equal("prefix:domain_key_value_1234", cacheKeys[0].Key) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) -} - -func TestOverLimitWithLocalCache(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - localCache := freecache.NewCache(100) - localCache.Set([]byte("key"), []byte("value"), 100) - baseRateLimit := limiter.NewBaseRateLimit(nil, nil, 3600, localCache, 0.8, "") - // Returns true, as local cache contains over limit value for the key. - assert.Equal(true, baseRateLimit.IsOverLimitWithLocalCache("key")) -} - -func TestNoOverLimitWithLocalCache(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - baseRateLimit := limiter.NewBaseRateLimit(nil, nil, 3600, nil, 0.8, "") - // Returns false, as local cache is nil. - assert.Equal(false, baseRateLimit.IsOverLimitWithLocalCache("domain_key_value_1234")) - localCache := freecache.NewCache(100) - baseRateLimitWithLocalCache := limiter.NewBaseRateLimit(nil, nil, 3600, localCache, 0.8, "") - // Returns false, as local cache does not contain value for cache key. 
- assert.Equal(false, baseRateLimitWithLocalCache.IsOverLimitWithLocalCache("domain_key_value_1234")) -} - -func TestGetResponseStatusEmptyKey(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - baseRateLimit := limiter.NewBaseRateLimit(nil, nil, 3600, nil, 0.8, "") - responseStatus := baseRateLimit.GetResponseDescriptorStatus("", nil, false, 1) - assert.Equal(pb.RateLimitResponse_OK, responseStatus.GetCode()) - assert.Equal(uint32(0), responseStatus.GetLimitRemaining()) -} - -func TestGetResponseStatusOverLimitWithLocalCache(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - timeSource := mock_utils.NewMockTimeSource(controller) - timeSource.EXPECT().UnixNow().Return(int64(1234)) - statsStore := stats.NewStore(stats.NewNullSink(), false) - baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8, "") - limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} - limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 6, 4, 5) - // As `isOverLimitWithLocalCache` is passed as `true`, immediate response is returned with no checks of the limits. 
- responseStatus := baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, true, 2) - assert.Equal(pb.RateLimitResponse_OVER_LIMIT, responseStatus.GetCode()) - assert.Equal(uint32(0), responseStatus.GetLimitRemaining()) - assert.Equal(limits[0].Limit, responseStatus.GetCurrentLimit()) - assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(2), limits[0].Stats.OverLimitWithLocalCache.Value()) -} - -func TestGetResponseStatusOverLimit(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - timeSource := mock_utils.NewMockTimeSource(controller) - timeSource.EXPECT().UnixNow().Return(int64(1234)) - statsStore := stats.NewStore(stats.NewNullSink(), false) - localCache := freecache.NewCache(100) - baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, localCache, 0.8, "") - limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} - limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 7, 4, 5) - responseStatus := baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, false, 1) - assert.Equal(pb.RateLimitResponse_OVER_LIMIT, responseStatus.GetCode()) - assert.Equal(uint32(0), responseStatus.GetLimitRemaining()) - assert.Equal(limits[0].Limit, responseStatus.GetCurrentLimit()) - result, _ := localCache.Get([]byte("key")) - // Local cache should have been populated with over the limit key. 
- assert.Equal("", string(result)) - assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) -} - -func TestGetResponseStatusBelowLimit(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - timeSource := mock_utils.NewMockTimeSource(controller) - timeSource.EXPECT().UnixNow().Return(int64(1234)) - statsStore := stats.NewStore(stats.NewNullSink(), false) - baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8, "") - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} - limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 6, 9, 10) - responseStatus := baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, false, 1) - assert.Equal(pb.RateLimitResponse_OK, responseStatus.GetCode()) - assert.Equal(uint32(4), responseStatus.GetLimitRemaining()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(limits[0].Limit, responseStatus.GetCurrentLimit()) -} diff --git a/test/memcached/cache_impl_test.go b/test/memcached/fixed_cache_impl_test.go similarity index 91% rename from test/memcached/cache_impl_test.go rename to test/memcached/fixed_cache_impl_test.go index 847b76f4d..d1f329d3c 100644 --- a/test/memcached/cache_impl_test.go +++ b/test/memcached/fixed_cache_impl_test.go @@ -14,13 +14,12 @@ import ( pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" "github.com/envoyproxy/ratelimit/src/config" - "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/memcached" "github.com/envoyproxy/ratelimit/src/utils" stats "github.com/lyft/gostats" "github.com/envoyproxy/ratelimit/test/common" - mock_memcached "github.com/envoyproxy/ratelimit/test/mocks/memcached" + mock_memcached "github.com/envoyproxy/ratelimit/test/mocks/memcached/driver" mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" 
"github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -34,19 +33,19 @@ func TestMemcached(t *testing.T) { timeSource := mock_utils.NewMockTimeSource(controller) client := mock_memcached.NewMockClient(controller) statsStore := stats.NewStore(stats.NewNullSink(), false) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, statsStore, 0.8, "") + cache := memcached.NewFixedRateLimitCacheImpl(client, timeSource, nil, 0, nil, 0.8, "") + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return( - getMultiResult(map[string]int{"domain_key_value_1234": 4}), nil, + getMultiResult(map[string]int{"domain_key_value_1234": 0}), nil, ) - client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return(uint64(5), nil) - - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + client.EXPECT().Increment("domain_key_value_1234", uint64(1)) assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -69,7 +68,7 @@ func TestMemcached(t *testing.T) { config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, 
"key2_value2_subkey2_subvalue2", statsStore)} assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}, - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[1].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[1].Stats.TotalHits.Value()) assert.Equal(uint64(1), limits[1].Stats.OverLimit.Value()) @@ -99,8 +98,8 @@ func TestMemcached(t *testing.T) { config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_DAY, "key3_value3_subkey3_subvalue3", statsStore)} assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}, - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[1].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) @@ -120,7 +119,7 @@ func TestMemcachedGetError(t *testing.T) { timeSource := mock_utils.NewMockTimeSource(controller) client := mock_memcached.NewMockClient(controller) statsStore := stats.NewStore(stats.NewNullSink(), false) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, 
nil, statsStore, 0.8, "") + cache := memcached.NewFixedRateLimitCacheImpl(client, timeSource, nil, 0, nil, 0.8, "") timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return( @@ -132,7 +131,7 @@ func TestMemcachedGetError(t *testing.T) { limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -149,7 +148,7 @@ func TestMemcachedGetError(t *testing.T) { limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value1", statsStore)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -204,8 +203,8 @@ func TestOverLimitWithLocalCache(t *testing.T) { localCache := freecache.NewCache(100) sink := &common.TestStatSink{} statsStore := stats.NewStore(sink, true) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, 
localCache, statsStore, 0.8, "") - localCacheStats := limiter.NewLocalCacheStats(localCache, statsStore.Scope("localcache")) + cache := memcached.NewFixedRateLimitCacheImpl(client, timeSource, nil, 0, localCache, 0.8, "") + localCacheStats := utils.NewLocalCacheStats(localCache, statsStore.Scope("localcache")) // Test Near Limit Stats. Under Near Limit Ratio timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) @@ -221,7 +220,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -240,7 +239,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -259,7 +258,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, 
LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) @@ -275,7 +274,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Times(0) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(4), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) @@ -296,7 +295,11 @@ func TestNearLimit(t *testing.T) { timeSource := mock_utils.NewMockTimeSource(controller) client := mock_memcached.NewMockClient(controller) statsStore := stats.NewStore(stats.NewNullSink(), false) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, statsStore, 0.8, "") + cache := memcached.NewFixedRateLimitCacheImpl(client, timeSource, nil, 0, nil, 0.8, "") + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) + limits := []*config.RateLimit{ + config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, "key4_value4", statsStore)} // Test Near Limit Stats. 
Under Near Limit Ratio timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) @@ -305,14 +308,9 @@ func TestNearLimit(t *testing.T) { ) client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(11), nil) - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) - - limits := []*config.RateLimit{ - config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, "key4_value4", statsStore)} - assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -327,7 +325,7 @@ func TestNearLimit(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -343,7 +341,7 @@ func TestNearLimit(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, 
timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) @@ -351,102 +349,102 @@ func TestNearLimit(t *testing.T) { // Now test hitsAddend that is greater than 1 // All of it under limit, under near limit + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key5", "value5"}}}, 3) + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key5_value5", statsStore)} + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) client.EXPECT().GetMulti([]string{"domain_key5_value5_1234"}).Return( getMultiResult(map[string]int{"domain_key5_value5_1234": 2}), nil, ) client.EXPECT().Increment("domain_key5_value5_1234", uint64(3)).Return(uint64(5), nil) - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key5", "value5"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key5_value5", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 15, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 15, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) // All of it under limit, some over near limit + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key6", "value6"}}}, 2) + limits = []*config.RateLimit{config.NewRateLimit(8, pb.RateLimitResponse_RateLimit_SECOND, "key6_value6", statsStore)} + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) 
client.EXPECT().GetMulti([]string{"domain_key6_value6_1234"}).Return( getMultiResult(map[string]int{"domain_key6_value6_1234": 5}), nil, ) client.EXPECT().Increment("domain_key6_value6_1234", uint64(2)).Return(uint64(7), nil) - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key6", "value6"}}}, 2) - limits = []*config.RateLimit{config.NewRateLimit(8, pb.RateLimitResponse_RateLimit_SECOND, "key6_value6", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) // All of it under limit, all of it over near limit + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key7", "value7"}}}, 3) + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key7_value7", statsStore)} + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) client.EXPECT().GetMulti([]string{"domain_key7_value7_1234"}).Return( getMultiResult(map[string]int{"domain_key7_value7_1234": 16}), nil, ) client.EXPECT().Increment("domain_key7_value7_1234", uint64(3)).Return(uint64(19), nil) - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key7", "value7"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key7_value7", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: 
utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(3), limits[0].Stats.NearLimit.Value()) // Some of it over limit, all of it over near limit + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key8", "value8"}}}, 3) + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key8_value8", statsStore)} + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) client.EXPECT().GetMulti([]string{"domain_key8_value8_1234"}).Return( getMultiResult(map[string]int{"domain_key8_value8_1234": 19}), nil, ) client.EXPECT().Increment("domain_key8_value8_1234", uint64(3)).Return(uint64(22), nil) - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key8", "value8"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key8_value8", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) // Some of it in all three places + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key9", 
"value9"}}}, 7) + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key9_value9", statsStore)} + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) client.EXPECT().GetMulti([]string{"domain_key9_value9_1234"}).Return( getMultiResult(map[string]int{"domain_key9_value9_1234": 15}), nil, ) client.EXPECT().Increment("domain_key9_value9_1234", uint64(7)).Return(uint64(22), nil) - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key9", "value9"}}}, 7) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key9_value9", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(7), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(4), limits[0].Stats.NearLimit.Value()) // all of it over limit + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key10", "value10"}}}, 3) + limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key10_value10", statsStore)} + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) client.EXPECT().GetMulti([]string{"domain_key10_value10_1234"}).Return( getMultiResult(map[string]int{"domain_key10_value10_1234": 27}), nil, ) client.EXPECT().Increment("domain_key10_value10_1234", uint64(3)).Return(uint64(30), nil) - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key10", "value10"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(10, 
pb.RateLimitResponse_RateLimit_SECOND, "key10_value10", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(3), limits[0].Stats.OverLimit.Value()) @@ -464,7 +462,10 @@ func TestMemcacheWithJitter(t *testing.T) { client := mock_memcached.NewMockClient(controller) jitterSource := mock_utils.NewMockJitterRandSource(controller) statsStore := stats.NewStore(stats.NewNullSink(), false) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, rand.New(jitterSource), 3600, nil, statsStore, 0.8, "") + cache := memcached.NewFixedRateLimitCacheImpl(client, timeSource, rand.New(jitterSource), 3600, nil, 0.8, "") + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) jitterSource.EXPECT().Int63().Return(int64(100)) @@ -484,11 +485,8 @@ func TestMemcacheWithJitter(t *testing.T) { }, ).Return(nil) - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + 
[]*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -505,7 +503,10 @@ func TestMemcacheAdd(t *testing.T) { timeSource := mock_utils.NewMockTimeSource(controller) client := mock_memcached.NewMockClient(controller) statsStore := stats.NewStore(stats.NewNullSink(), false) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, statsStore, 0.8, "") + cache := memcached.NewFixedRateLimitCacheImpl(client, timeSource, nil, 0, nil, 0.8, "") + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} // Test a race condition with the initial add timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) @@ -525,17 +526,17 @@ func TestMemcacheAdd(t *testing.T) { client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return( uint64(2), nil) - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), 
limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) // A rate limit with 1-minute window + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key2", "value2"}}}, 1) + limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key2_value2", statsStore)} + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) client.EXPECT().GetMulti([]string{"domain_key2_value2_1200"}).Return(nil, nil) client.EXPECT().Increment("domain_key2_value2_1200", uint64(1)).Return( @@ -548,11 +549,8 @@ func TestMemcacheAdd(t *testing.T) { }, ).Return(nil) - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key2", "value2"}}}, 1) - limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key2_value2", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) diff --git a/test/memcached/stats_collecting_client_test.go b/test/memcached/stats_collecting_client_test.go index 548b93041..d24f92f54 100644 --- a/test/memcached/stats_collecting_client_test.go +++ b/test/memcached/stats_collecting_client_test.go @@ -6,7 +6,8 @@ import ( "github.com/bradfitz/gomemcache/memcache" "github.com/envoyproxy/ratelimit/src/memcached" - mock_memcached "github.com/envoyproxy/ratelimit/test/mocks/memcached" + + mock_memcached "github.com/envoyproxy/ratelimit/test/mocks/memcached/driver" "github.com/golang/mock/gomock" stats "github.com/lyft/gostats" 
"github.com/stretchr/testify/assert" @@ -197,3 +198,38 @@ func TestStats_Add(t *testing.T) { "add.__code=not_stored": 1, }, fakeSink.values) } + +func TestStats_Set(t *testing.T) { + fakeSink := &fakeSink{} + + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + client := mock_memcached.NewMockClient(controller) + statsStore := stats.NewStore(fakeSink, false) + + sc := memcached.CollectStats(client, statsStore) + item := &memcache.Item{} + + fakeSink.Reset() + client.EXPECT().Set(item).Return(nil) + err := sc.Set(item) + statsStore.Flush() + + assert.Nil(err) + assert.Equal(map[string]uint64{ + "set.__code=success": 1, + }, fakeSink.values) + + expectedErr := errors.New("expected err") + + fakeSink.Reset() + client.EXPECT().Set(item).Return(expectedErr) + err = sc.Set(item) + statsStore.Flush() + + assert.Equal(expectedErr, err) + assert.Equal(map[string]uint64{ + "set.__code=error": 1, + }, fakeSink.values) +} diff --git a/test/memcached/windowed_cache_impl_test.go b/test/memcached/windowed_cache_impl_test.go new file mode 100644 index 000000000..aad609780 --- /dev/null +++ b/test/memcached/windowed_cache_impl_test.go @@ -0,0 +1,505 @@ +package memcached_test + +import ( + "math/rand" + "strconv" + "testing" + + "github.com/bradfitz/gomemcache/memcache" + "github.com/coocood/freecache" + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/memcached" + "github.com/envoyproxy/ratelimit/src/utils" + "github.com/envoyproxy/ratelimit/test/common" + mock_memcached "github.com/envoyproxy/ratelimit/test/mocks/memcached/driver" + mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" + "github.com/golang/mock/gomock" + "github.com/golang/protobuf/ptypes/duration" + stats "github.com/lyft/gostats" + "github.com/stretchr/testify/assert" +) + +func TestMemcachedWindowed(t *testing.T) { + assert := assert.New(t) + controller 
:= gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + client := mock_memcached.NewMockClient(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + cache := memcached.NewWindowedRateLimitCacheImpl(client, timeSource, nil, 0, nil, 0.8, "") + + // test 1 + // test initial rate limit process + + // periode = 1 second + // limit = 10 request/second + // emissionInterval = 0.1 second + // request = 1 + // increment = emissionInterval*request = 0.1 second + + // arriveAt = 1 second + // tat = 1 second + + // newTat should be max(arriveAt,tat)+increment = 1.1 second + // DurationUntilReset should be newTat-arriveat = 0.1 minute + // expiration should be second(newTat-arriveat)+1 = 0.1 minute + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1) + + client.EXPECT().GetMulti([]string{"domain_key_value_0"}).Return( + getMultiResult(map[string]int{"domain_key_value_0": 1e9}), nil, + ) + client.EXPECT().Set(&memcache.Item{ + Key: "domain_key_value_0", + Value: []byte(strconv.FormatInt(int64(1e9+1e8), 10)), + Expiration: int32(1), + }) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: &duration.Duration{Nanos: 1e8}}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + // test 2 + // test rate limit with multiple description + + // periode = 1 minute = 60 second + // limit = 10 request/minute + // emissionInterval = 6 second + // request = 1 + // increment = emissionInterval*request = 6 second + 
+ // arriveAt = 1 second + // tat = 1 second + + // newTat should be max(arriveAt,tat)+increment = 7 second + // DurationUntilReset should be newTat-arriveat = 6 second + // expiration should be second(newTat-arriveat)+1 = 7 second + + request = common.NewRateLimitRequest( + "domain", + [][][2]string{ + {{"key2", "value2"}}, + {{"key2", "value2"}, {"subkey2", "subvalue2"}}, + }, 1) + limits = []*config.RateLimit{ + nil, + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key2_value2_subkey2_subvalue2", statsStore)} + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1) + + client.EXPECT().GetMulti([]string{"domain_key2_value2_subkey2_subvalue2_0"}).Return( + getMultiResult(map[string]int{"domain_key2_value2_subkey2_subvalue2_0": 1e9}), nil, + ) + client.EXPECT().Set(&memcache.Item{ + Key: "domain_key2_value2_subkey2_subvalue2_0", + Value: []byte(strconv.FormatInt(int64(1e9+6e9), 10)), + Expiration: int32(7), + }) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[1].Limit, LimitRemaining: 9, DurationUntilReset: &duration.Duration{Seconds: 6}}}, + cache.DoLimit(nil, request, limits)) + + assert.Equal(uint64(1), limits[1].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[1].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[1].Stats.NearLimit.Value()) + + // test 3 + // test rate limit with multiple description and different limit configuration + + // periode = 1 hour = 3600 second + // limit = 10 request/hour + // emissionInterval = 6 minute = 360 second + // request = 1 + // increment = emissionInterval*request = 6 minute = 360 second + + // arriveAt = 1 second + // tat = 1 second + + // newTat should be max(arriveAt,tat)+increment = 361 second + // DurationUntilReset should be newTat-arriveat = 360 second + // expiration should be second(newTat-arriveat)+1 = 361 second + + // periode = 1 
day = 86400 second + // limit = 10 request/day + // emissionInterval = 8640 second + // request = 1 + // increment = emissionInterval*request = 8640 second + + // arriveAt = 1 second + // tat = 1 second + + // newTat should be max(arriveAt,tat)+increment = 8641 second + // DurationUntilReset should be newTat-arriveat = 8640 second + // expiration should be second(newTat-arriveat)+1 = 8641 second + + request = common.NewRateLimitRequest( + "domain", + [][][2]string{ + {{"key3", "value3"}}, + {{"key3", "value3"}, {"subkey3", "subvalue3"}}, + }, 1) + limits = []*config.RateLimit{ + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, "key3_value3", statsStore), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_DAY, "key3_value3_subkey3_subvalue3", statsStore)} + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(2) + + client.EXPECT().GetMulti([]string{"domain_key3_value3_0", "domain_key3_value3_subkey3_subvalue3_0"}).Return( + getMultiResult(map[string]int{"domain_key3_value3_0": 1e9, "domain_key3_value3_subkey3_subvalue3_0": 1e9}), nil, + ) + client.EXPECT().Set(&memcache.Item{ + Key: "domain_key3_value3_0", + Value: []byte(strconv.FormatInt(int64(1e9+360e9), 10)), + Expiration: int32(361), + }) + client.EXPECT().Set(&memcache.Item{ + Key: "domain_key3_value3_subkey3_subvalue3_0", + Value: []byte(strconv.FormatInt(int64(1e9+8640e9), 10)), + Expiration: int32(8641), + }) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: &duration.Duration{Seconds: 360}}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[1].Limit, LimitRemaining: 9, DurationUntilReset: &duration.Duration{Seconds: 8640}}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + 
assert.Equal(uint64(1), limits[1].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[1].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[1].Stats.NearLimit.Value()) + + cache.Flush() +} + +func TestNearLimitWindowed(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + client := mock_memcached.NewMockClient(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + cache := memcached.NewWindowedRateLimitCacheImpl(client, timeSource, nil, 0, nil, 0.8, "") + domain := "domain" + + request := common.NewRateLimitRequest(domain, [][][2]string{{{"key4", "value4"}}}, 1) + limits := []*config.RateLimit{ + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key4_value4", statsStore)} + + // Test Near Limit Stats. Under Near Limit Ratio + // periode = 1 minute = 60 second + // limit = 10 request/minute + // emissionInterval = 6 second + // request = 1 + // increment = emissionInterval*request = 6 second + + // arriveAt = 01 second + // tat = 01 second + + // newTat should be max(arriveAt,tat)+increment = 7 second + // expire should be (newtat-arriveat)+1 = 7 second + // DurationUntilReset should be newtat-arriveat = 6 second + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1) + + client.EXPECT().GetMulti([]string{"domain_key4_value4_0"}).Return( + getMultiResult(map[string]int{"domain_key4_value4_0": 1e9}), nil, + ) + client.EXPECT().Set(&memcache.Item{ + Key: "domain_key4_value4_0", + Value: []byte(strconv.FormatInt(int64(1e9+6e9), 10)), + Expiration: int32(7), + }) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: &duration.Duration{Seconds: 6}}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), 
limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + // Test Near Limit Stats. At Near Limit Ratio, still OK + // periode = 1 minute = 60 second + // limit = 10 request/minute + // emissionInterval = 6 second + // request = 1 + // increment = emissionInterval*request = 6 second + + // arriveAt = 07 second + // tat = 54 second + + // newTat should be max(arriveAt,tat)+increment = 60 second + // expire should be (newtat-arriveat)+1 = 54 second + // DurationUntilReset should be newtat-arriveat = 53 second + timeSource.EXPECT().UnixNanoNow().Return(int64(7e9)).MaxTimes(1) + + client.EXPECT().GetMulti([]string{"domain_key4_value4_0"}).Return( + getMultiResult(map[string]int{"domain_key4_value4_0": 54e9}), nil, + ) + client.EXPECT().Set(&memcache.Item{ + Key: "domain_key4_value4_0", + Value: []byte(strconv.FormatInt(int64(60e9), 10)), + Expiration: int32(54), + }) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: &duration.Duration{Seconds: 53}}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + + // Test Near Limit Stats. We went OVER_LIMIT, but the near_limit counter only increases + // when we are near limit, not after we have passed the limit. 
+ // periode = 1 minute = 60 second + // limit = 10 request/minute + // emissionInterval = 6 second + // request = 1 + // increment = emissionInterval*request = 6 second + + // arriveAt = 04 second + // tat = 60 second + + // newTat should be max(arriveAt,tat)+increment = 66 second + // expire should be (tat-arriveat)+1 = 57 second + // DurationUntilReset should be tat-arriveat = 56 second + timeSource.EXPECT().UnixNanoNow().Return(int64(4e9)).MaxTimes(1) + + client.EXPECT().GetMulti([]string{"domain_key4_value4_0"}).Return( + getMultiResult(map[string]int{"domain_key4_value4_0": 60e9}), nil, + ) + client.EXPECT().Set(&memcache.Item{ + Key: "domain_key4_value4_0", + Value: []byte(strconv.FormatInt(int64(60e9), 10)), + Expiration: int32(57), + }) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: &duration.Duration{Seconds: 56}}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + + cache.Flush() +} + +func TestWindowedOverLimitWithLocalCache(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + client := mock_memcached.NewMockClient(controller) + sink := &common.TestStatSink{} + statsStore := stats.NewStore(sink, true) + localCache := freecache.NewCache(100) + localCacheStats := utils.NewLocalCacheStats(localCache, statsStore.Scope("localcache")) + cache := memcached.NewWindowedRateLimitCacheImpl(client, timeSource, nil, 0, localCache, 0.8, "") + domain := "domain" + + request := common.NewRateLimitRequest(domain, [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{ + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, "key_value", 
statsStore)} + // Test Near Limit Stats. Under Near Limit Ratio + // periode = 60 minute = 3600 second + // limit = 10 request/hour + // emissionInterval = 6 minute + // request = 1 + // increment = emissionInterval*request = 6 minute + + // arriveAt = 1 minute + // tat = 12 minite + + // newTat should be max(arriveAt,tat)+increment = 18 minute + // expire should be (newtat-arriveat)+1 second = 17 minute 1 second + // DurationUntilReset should be newtat-arriveat = 17 minute + timeSource.EXPECT().UnixNanoNow().Return(int64(1 * 60 * 1e9)).MaxTimes(1) + + client.EXPECT().GetMulti([]string{"domain_key_value_0"}).Return( + getMultiResult(map[string]int{"domain_key_value_0": 12 * 60 * 1e9}), nil, + ) + client.EXPECT().Set(&memcache.Item{ + Key: "domain_key_value_0", + Value: []byte(strconv.FormatInt(int64(18*60*1e9), 10)), + Expiration: int32((17 * 60) + 1), + }) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 7, DurationUntilReset: &duration.Duration{Seconds: 17 * 60}}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + // Check the local cache stats. + testLocalCacheStats(localCacheStats, statsStore, sink, 0, 1, 1, 0, 0) + + // Test Near Limit Stats. 
At Near Limit Ratio, still OK + // periode = 60 minute = 3600 second + // limit = 10 request/hour + // emissionInterval = 6 minute + // request = 1 + // increment = emissionInterval*request = 6 minute + + // arriveAt = 12 minute + // tat = 60 minute + + // newTat should be max(arriveAt,tat)+increment = 66 minute + // expire should be (newtat-arriveat)+1 second = 54 minute 1 second + // DurationUntilReset should be newtat-arriveat = 54 minute + timeSource.EXPECT().UnixNanoNow().Return(int64(12 * 60 * 1e9)).MaxTimes(1) + + client.EXPECT().GetMulti([]string{"domain_key_value_0"}).Return( + getMultiResult(map[string]int{"domain_key_value_0": 60 * 60 * 1e9}), nil, + ) + client.EXPECT().Set(&memcache.Item{ + Key: "domain_key_value_0", + Value: []byte(strconv.FormatInt(int64(66*60*1e9), 10)), + Expiration: int32((54 * 60) + 1), + }) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: &duration.Duration{Seconds: 54 * 60}}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + + // Check the local cache stats. 
+ testLocalCacheStats(localCacheStats, statsStore, sink, 0, 2, 2, 0, 0) + + // Test Over limit stats + // periode = 60 minute = 3600 second + // limit = 10 request/hour + // emissionInterval = 6 minute + // request = 1 + // increment = emissionInterval*request = 6 minute + + // arriveAt = 2 minute + // tat = 72 minute + + // newTat should be max(arriveAt,tat)+increment = 78 minute (not used) + // expire should be (tat-arriveat)+1 second = 70 minute 1 second + // DurationUntilReset should be tat-arriveat = 70 minute + timeSource.EXPECT().UnixNanoNow().Return(int64(2 * 60 * 1e9)).MaxTimes(1) + + client.EXPECT().GetMulti([]string{"domain_key_value_0"}).Return( + getMultiResult(map[string]int{"domain_key_value_0": 72 * 60 * 1e9}), nil, + ) + client.EXPECT().Set(&memcache.Item{ + Key: "domain_key_value_0", + Value: []byte(strconv.FormatInt(int64(72*60*1e9), 10)), + Expiration: int32((70 * 60) + 1), + }) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: &duration.Duration{Seconds: 70 * 60}}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + + // Check the local cache stats. 
+ testLocalCacheStats(localCacheStats, statsStore, sink, 0, 2, 3, 0, 1) + + // Test Over limit stats with local cache + // periode = 60 minute = 3600 second + // limit = 10 request/hour + // emissionInterval = 6 minute + // request = 1 + // increment = emissionInterval*request = 6 minute + + // arriveAt = 3 minute + // tat = 72 minute + + // newTat should be max(arriveAt,tat)+increment = 78 minute (not used) + // expire should be (tat-arriveat)+1 second = 69 minute 1 second + // DurationUntilReset should be secondsToReset-(arriveAt%secondsToReset) = 57 minute + timeSource.EXPECT().UnixNanoNow().Return(int64(3 * 60 * 1e9)).MaxTimes(1) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: &duration.Duration{Seconds: 57 * 60}}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(4), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + + // Check the local cache stats. 
+ testLocalCacheStats(localCacheStats, statsStore, sink, 1, 3, 4, 0, 1) + + cache.Flush() +} + +func TestRedisWindowedWithJitter(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + client := mock_memcached.NewMockClient(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + jitterSource := mock_utils.NewMockJitterRandSource(controller) + cache := memcached.NewWindowedRateLimitCacheImpl(client, timeSource, rand.New(jitterSource), 3600, nil, 0.8, "") + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + + // periode = 1 second + // limit = 10 request/second + // emissionInterval = 1/10 second + // request = 1 + // increment = emissionInterval*request = 1/10 second + + // arriveAt = 1 second + // tat = 1 second + + // newTat should be max(arriveAt,tat)+increment = 1,1 second + // expire should be (tat-arriveat)+1 second = 1 second + // DurationUntilReset should be newTat-arriveat = 0.1 second + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1) + + client.EXPECT().GetMulti([]string{"domain_key_value_0"}).Return( + getMultiResult(map[string]int{"domain_key_value_0": 1e9}), nil, + ) + client.EXPECT().Set(&memcache.Item{ + Key: "domain_key_value_0", + Value: []byte(strconv.FormatInt(int64(1e9+1e8), 10)), + Expiration: int32(100 + 1), + }) + + jitterSource.EXPECT().Int63().Return(int64(100)) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: &duration.Duration{Nanos: 1e8}}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + 
assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + cache.Flush() +} diff --git a/test/mocks/algorithm/ratelimit_algorithm.go b/test/mocks/algorithm/ratelimit_algorithm.go new file mode 100644 index 000000000..c2d5dd035 --- /dev/null +++ b/test/mocks/algorithm/ratelimit_algorithm.go @@ -0,0 +1,134 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./src/algorithm/ratelimit_algorithm.go + +// Package mock_algorithm is a generated GoMock package. +package mock_algorithm + +import ( + envoy_service_ratelimit_v3 "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + config "github.com/envoyproxy/ratelimit/src/config" + utils "github.com/envoyproxy/ratelimit/src/utils" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockRatelimitAlgorithm is a mock of RatelimitAlgorithm interface +type MockRatelimitAlgorithm struct { + ctrl *gomock.Controller + recorder *MockRatelimitAlgorithmMockRecorder +} + +// MockRatelimitAlgorithmMockRecorder is the mock recorder for MockRatelimitAlgorithm +type MockRatelimitAlgorithmMockRecorder struct { + mock *MockRatelimitAlgorithm +} + +// NewMockRatelimitAlgorithm creates a new mock instance +func NewMockRatelimitAlgorithm(ctrl *gomock.Controller) *MockRatelimitAlgorithm { + mock := &MockRatelimitAlgorithm{ctrl: ctrl} + mock.recorder = &MockRatelimitAlgorithmMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockRatelimitAlgorithm) EXPECT() *MockRatelimitAlgorithmMockRecorder { + return m.recorder +} + +// IsOverLimit mocks base method +func (m *MockRatelimitAlgorithm) IsOverLimit(limit *config.RateLimit, results, hitsAddend int64) (bool, int64, int) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsOverLimit", limit, results, hitsAddend) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(int64) + ret2, _ := ret[2].(int) + return ret0, ret1, ret2 +} + +// IsOverLimit indicates an expected call of IsOverLimit +func (mr 
*MockRatelimitAlgorithmMockRecorder) IsOverLimit(limit, results, hitsAddend interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsOverLimit", reflect.TypeOf((*MockRatelimitAlgorithm)(nil).IsOverLimit), limit, results, hitsAddend) +} + +// IsOverLimitWithLocalCache mocks base method +func (m *MockRatelimitAlgorithm) IsOverLimitWithLocalCache(key string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsOverLimitWithLocalCache", key) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsOverLimitWithLocalCache indicates an expected call of IsOverLimitWithLocalCache +func (mr *MockRatelimitAlgorithmMockRecorder) IsOverLimitWithLocalCache(key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsOverLimitWithLocalCache", reflect.TypeOf((*MockRatelimitAlgorithm)(nil).IsOverLimitWithLocalCache), key) +} + +// GetResponseDescriptorStatus mocks base method +func (m *MockRatelimitAlgorithm) GetResponseDescriptorStatus(key string, limit *config.RateLimit, results int64, isOverLimitWithLocalCache bool, hitsAddend int64) *envoy_service_ratelimit_v3.RateLimitResponse_DescriptorStatus { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetResponseDescriptorStatus", key, limit, results, isOverLimitWithLocalCache, hitsAddend) + ret0, _ := ret[0].(*envoy_service_ratelimit_v3.RateLimitResponse_DescriptorStatus) + return ret0 +} + +// GetResponseDescriptorStatus indicates an expected call of GetResponseDescriptorStatus +func (mr *MockRatelimitAlgorithmMockRecorder) GetResponseDescriptorStatus(key, limit, results, isOverLimitWithLocalCache, hitsAddend interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponseDescriptorStatus", reflect.TypeOf((*MockRatelimitAlgorithm)(nil).GetResponseDescriptorStatus), key, limit, results, isOverLimitWithLocalCache, hitsAddend) +} + +// GetNewTat mocks base method +func (m 
*MockRatelimitAlgorithm) GetNewTat() int64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNewTat") + ret0, _ := ret[0].(int64) + return ret0 +} + +// GetNewTat indicates an expected call of GetNewTat +func (mr *MockRatelimitAlgorithmMockRecorder) GetNewTat() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNewTat", reflect.TypeOf((*MockRatelimitAlgorithm)(nil).GetNewTat)) +} + +// GetArrivedAt mocks base method +func (m *MockRatelimitAlgorithm) GetArrivedAt() int64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetArrivedAt") + ret0, _ := ret[0].(int64) + return ret0 +} + +// GetArrivedAt indicates an expected call of GetArrivedAt +func (mr *MockRatelimitAlgorithmMockRecorder) GetArrivedAt() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetArrivedAt", reflect.TypeOf((*MockRatelimitAlgorithm)(nil).GetArrivedAt)) +} + +// GenerateCacheKeys mocks base method +func (m *MockRatelimitAlgorithm) GenerateCacheKeys(request *envoy_service_ratelimit_v3.RateLimitRequest, limits []*config.RateLimit, hitsAddend int64) []utils.CacheKey { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GenerateCacheKeys", request, limits, hitsAddend) + ret0, _ := ret[0].([]utils.CacheKey) + return ret0 +} + +// GenerateCacheKeys indicates an expected call of GenerateCacheKeys +func (mr *MockRatelimitAlgorithmMockRecorder) GenerateCacheKeys(request, limits, hitsAddend interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateCacheKeys", reflect.TypeOf((*MockRatelimitAlgorithm)(nil).GenerateCacheKeys), request, limits, hitsAddend) +} + +// PopulateStats mocks base method +func (m *MockRatelimitAlgorithm) PopulateStats(limit *config.RateLimit, nearLimit, overLimit, overLimitWithLocalCache uint64) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "PopulateStats", limit, nearLimit, overLimit, overLimitWithLocalCache) +} + +// PopulateStats indicates 
an expected call of PopulateStats +func (mr *MockRatelimitAlgorithmMockRecorder) PopulateStats(limit, nearLimit, overLimit, overLimitWithLocalCache interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopulateStats", reflect.TypeOf((*MockRatelimitAlgorithm)(nil).PopulateStats), limit, nearLimit, overLimit, overLimitWithLocalCache) +} diff --git a/test/mocks/limiter/limiter.go b/test/mocks/limiter/rate_limit_cache.go similarity index 76% rename from test/mocks/limiter/limiter.go rename to test/mocks/limiter/rate_limit_cache.go index 48f995a1f..5e9520b3b 100644 --- a/test/mocks/limiter/limiter.go +++ b/test/mocks/limiter/rate_limit_cache.go @@ -1,14 +1,14 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/envoyproxy/ratelimit/src/limiter (interfaces: RateLimitCache) +// Source: ./src/limiter/rate_limit_cache.go // Package mock_limiter is a generated GoMock package. package mock_limiter import ( - context "context" envoy_service_ratelimit_v3 "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" config "github.com/envoyproxy/ratelimit/src/config" gomock "github.com/golang/mock/gomock" + context "golang.org/x/net/context" reflect "reflect" ) @@ -36,17 +36,17 @@ func (m *MockRateLimitCache) EXPECT() *MockRateLimitCacheMockRecorder { } // DoLimit mocks base method -func (m *MockRateLimitCache) DoLimit(arg0 context.Context, arg1 *envoy_service_ratelimit_v3.RateLimitRequest, arg2 []*config.RateLimit) []*envoy_service_ratelimit_v3.RateLimitResponse_DescriptorStatus { +func (m *MockRateLimitCache) DoLimit(ctx context.Context, request *envoy_service_ratelimit_v3.RateLimitRequest, limits []*config.RateLimit) []*envoy_service_ratelimit_v3.RateLimitResponse_DescriptorStatus { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DoLimit", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "DoLimit", ctx, request, limits) ret0, _ := ret[0].([]*envoy_service_ratelimit_v3.RateLimitResponse_DescriptorStatus) return 
ret0 } // DoLimit indicates an expected call of DoLimit -func (mr *MockRateLimitCacheMockRecorder) DoLimit(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockRateLimitCacheMockRecorder) DoLimit(ctx, request, limits interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoLimit", reflect.TypeOf((*MockRateLimitCache)(nil).DoLimit), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoLimit", reflect.TypeOf((*MockRateLimitCache)(nil).DoLimit), ctx, request, limits) } // Flush mocks base method diff --git a/test/mocks/memcached/client.go b/test/mocks/memcached/driver/client.go similarity index 60% rename from test/mocks/memcached/client.go rename to test/mocks/memcached/driver/client.go index 433105bd0..3bc29d8d9 100644 --- a/test/mocks/memcached/client.go +++ b/test/mocks/memcached/driver/client.go @@ -1,13 +1,14 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/envoyproxy/ratelimit/src/memcached (interfaces: Client) +// Source: ./src/memcached/driver/client.go // Package mock_memcached is a generated GoMock package. 
package mock_memcached import ( + reflect "reflect" + memcache "github.com/bradfitz/gomemcache/memcache" gomock "github.com/golang/mock/gomock" - reflect "reflect" ) // MockClient is a mock of Client interface @@ -33,46 +34,60 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder { return m.recorder } -// Add mocks base method -func (m *MockClient) Add(arg0 *memcache.Item) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Add", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// Add indicates an expected call of Add -func (mr *MockClientMockRecorder) Add(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockClient)(nil).Add), arg0) -} - // GetMulti mocks base method -func (m *MockClient) GetMulti(arg0 []string) (map[string]*memcache.Item, error) { +func (m *MockClient) GetMulti(keys []string) (map[string]*memcache.Item, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMulti", arg0) + ret := m.ctrl.Call(m, "GetMulti", keys) ret0, _ := ret[0].(map[string]*memcache.Item) ret1, _ := ret[1].(error) return ret0, ret1 } // GetMulti indicates an expected call of GetMulti -func (mr *MockClientMockRecorder) GetMulti(arg0 interface{}) *gomock.Call { +func (mr *MockClientMockRecorder) GetMulti(keys interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMulti", reflect.TypeOf((*MockClient)(nil).GetMulti), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMulti", reflect.TypeOf((*MockClient)(nil).GetMulti), keys) } // Increment mocks base method -func (m *MockClient) Increment(arg0 string, arg1 uint64) (uint64, error) { +func (m *MockClient) Increment(key string, delta uint64) (uint64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Increment", arg0, arg1) + ret := m.ctrl.Call(m, "Increment", key, delta) ret0, _ := ret[0].(uint64) ret1, _ := ret[1].(error) return ret0, ret1 } // Increment indicates 
an expected call of Increment -func (mr *MockClientMockRecorder) Increment(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockClientMockRecorder) Increment(key, delta interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Increment", reflect.TypeOf((*MockClient)(nil).Increment), key, delta) +} + +// Add mocks base method +func (m *MockClient) Add(item *memcache.Item) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Add", item) + ret0, _ := ret[0].(error) + return ret0 +} + +// Add indicates an expected call of Add +func (mr *MockClientMockRecorder) Add(item interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockClient)(nil).Add), item) +} + +// Set mocks base method +func (m *MockClient) Set(item *memcache.Item) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Set", item) + ret0, _ := ret[0].(error) + return ret0 +} + +// Set indicates an expected call of Set +func (mr *MockClientMockRecorder) Set(item interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Increment", reflect.TypeOf((*MockClient)(nil).Increment), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockClient)(nil).Set), item) } diff --git a/test/mocks/redis/redis.go b/test/mocks/redis/driver/redis.go similarity index 71% rename from test/mocks/redis/redis.go rename to test/mocks/redis/driver/redis.go index 032b500dc..e9de3ffd4 100644 --- a/test/mocks/redis/redis.go +++ b/test/mocks/redis/driver/redis.go @@ -1,11 +1,11 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/envoyproxy/ratelimit/src/redis (interfaces: Client) +// Source: github.com/envoyproxy/ratelimit/src/redis/driver (interfaces: Client) -// Package mock_redis is a generated GoMock package. -package mock_redis +// Package mock_driver is a generated GoMock package. 
+package mock_driver import ( - redis "github.com/envoyproxy/ratelimit/src/redis" + driver "github.com/envoyproxy/ratelimit/src/redis/driver" gomock "github.com/golang/mock/gomock" reflect "reflect" ) @@ -33,51 +33,70 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder { return m.recorder } -// Close mocks base method -func (m *MockClient) Close() error { +// DoCmd mocks base method +func (m *MockClient) DoCmd(rcv interface{}, cmd, key string, args ...interface{}) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") + varargs := []interface{}{rcv, cmd, key} + for _, a := range args { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DoCmd", varargs...) ret0, _ := ret[0].(error) return ret0 } -// Close indicates an expected call of Close -func (mr *MockClientMockRecorder) Close() *gomock.Call { +// DoCmd indicates an expected call of DoCmd +func (mr *MockClientMockRecorder) DoCmd(rcv, cmd, key interface{}, args ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockClient)(nil).Close)) + varargs := append([]interface{}{rcv, cmd, key}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoCmd", reflect.TypeOf((*MockClient)(nil).DoCmd), varargs...) } -// DoCmd mocks base method -func (m *MockClient) DoCmd(arg0 interface{}, arg1, arg2 string, arg3 ...interface{}) error { +// PipeAppend mocks base method +func (m *MockClient) PipeAppend(pipeline driver.Pipeline, rcv interface{}, cmd, key string, args ...interface{}) driver.Pipeline { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []interface{}{pipeline, rcv, cmd, key} + for _, a := range args { varargs = append(varargs, a) } - ret := m.ctrl.Call(m, "DoCmd", varargs...) + ret := m.ctrl.Call(m, "PipeAppend", varargs...) 
+ ret0, _ := ret[0].(driver.Pipeline) + return ret0 +} + +// PipeAppend indicates an expected call of PipeAppend +func (mr *MockClientMockRecorder) PipeAppend(pipeline, rcv, cmd, key interface{}, args ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{pipeline, rcv, cmd, key}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PipeAppend", reflect.TypeOf((*MockClient)(nil).PipeAppend), varargs...) +} + +// PipeDo mocks base method +func (m *MockClient) PipeDo(pipeline driver.Pipeline) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PipeDo", pipeline) ret0, _ := ret[0].(error) return ret0 } -// DoCmd indicates an expected call of DoCmd -func (mr *MockClientMockRecorder) DoCmd(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +// PipeDo indicates an expected call of PipeDo +func (mr *MockClientMockRecorder) PipeDo(pipeline interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoCmd", reflect.TypeOf((*MockClient)(nil).DoCmd), varargs...) 
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PipeDo", reflect.TypeOf((*MockClient)(nil).PipeDo), pipeline) } -// ImplicitPipeliningEnabled mocks base method -func (m *MockClient) ImplicitPipeliningEnabled() bool { +// Close mocks base method +func (m *MockClient) Close() error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ImplicitPipeliningEnabled") - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) return ret0 } -// ImplicitPipeliningEnabled indicates an expected call of ImplicitPipeliningEnabled -func (mr *MockClientMockRecorder) ImplicitPipeliningEnabled() *gomock.Call { +// Close indicates an expected call of Close +func (mr *MockClientMockRecorder) Close() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ImplicitPipeliningEnabled", reflect.TypeOf((*MockClient)(nil).ImplicitPipeliningEnabled)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockClient)(nil).Close)) } // NumActiveConns mocks base method @@ -94,35 +113,16 @@ func (mr *MockClientMockRecorder) NumActiveConns() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NumActiveConns", reflect.TypeOf((*MockClient)(nil).NumActiveConns)) } -// PipeAppend mocks base method -func (m *MockClient) PipeAppend(arg0 redis.Pipeline, arg1 interface{}, arg2, arg3 string, arg4 ...interface{}) redis.Pipeline { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "PipeAppend", varargs...) - ret0, _ := ret[0].(redis.Pipeline) - return ret0 -} - -// PipeAppend indicates an expected call of PipeAppend -func (mr *MockClientMockRecorder) PipeAppend(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) 
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PipeAppend", reflect.TypeOf((*MockClient)(nil).PipeAppend), varargs...) -} - -// PipeDo mocks base method -func (m *MockClient) PipeDo(arg0 redis.Pipeline) error { +// ImplicitPipeliningEnabled mocks base method +func (m *MockClient) ImplicitPipeliningEnabled() bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PipeDo", arg0) - ret0, _ := ret[0].(error) + ret := m.ctrl.Call(m, "ImplicitPipeliningEnabled") + ret0, _ := ret[0].(bool) return ret0 } -// PipeDo indicates an expected call of PipeDo -func (mr *MockClientMockRecorder) PipeDo(arg0 interface{}) *gomock.Call { +// ImplicitPipeliningEnabled indicates an expected call of ImplicitPipeliningEnabled +func (mr *MockClientMockRecorder) ImplicitPipeliningEnabled() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PipeDo", reflect.TypeOf((*MockClient)(nil).PipeDo), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ImplicitPipeliningEnabled", reflect.TypeOf((*MockClient)(nil).ImplicitPipeliningEnabled)) } diff --git a/test/mocks/utils/utils.go b/test/mocks/utils/jitter_rand_source.go similarity index 53% rename from test/mocks/utils/utils.go rename to test/mocks/utils/jitter_rand_source.go index 1812f4f0f..2289cf8b9 100644 --- a/test/mocks/utils/utils.go +++ b/test/mocks/utils/jitter_rand_source.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/envoyproxy/ratelimit/src/utils (interfaces: TimeSource,JitterRandSource) +// Source: ./src/utils/jitter_rand_source.go // Package mock_utils is a generated GoMock package. 
package mock_utils @@ -9,43 +9,6 @@ import ( reflect "reflect" ) -// MockTimeSource is a mock of TimeSource interface -type MockTimeSource struct { - ctrl *gomock.Controller - recorder *MockTimeSourceMockRecorder -} - -// MockTimeSourceMockRecorder is the mock recorder for MockTimeSource -type MockTimeSourceMockRecorder struct { - mock *MockTimeSource -} - -// NewMockTimeSource creates a new mock instance -func NewMockTimeSource(ctrl *gomock.Controller) *MockTimeSource { - mock := &MockTimeSource{ctrl: ctrl} - mock.recorder = &MockTimeSourceMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockTimeSource) EXPECT() *MockTimeSourceMockRecorder { - return m.recorder -} - -// UnixNow mocks base method -func (m *MockTimeSource) UnixNow() int64 { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UnixNow") - ret0, _ := ret[0].(int64) - return ret0 -} - -// UnixNow indicates an expected call of UnixNow -func (mr *MockTimeSourceMockRecorder) UnixNow() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnixNow", reflect.TypeOf((*MockTimeSource)(nil).UnixNow)) -} - // MockJitterRandSource is a mock of JitterRandSource interface type MockJitterRandSource struct { ctrl *gomock.Controller @@ -84,13 +47,13 @@ func (mr *MockJitterRandSourceMockRecorder) Int63() *gomock.Call { } // Seed mocks base method -func (m *MockJitterRandSource) Seed(arg0 int64) { +func (m *MockJitterRandSource) Seed(seed int64) { m.ctrl.T.Helper() - m.ctrl.Call(m, "Seed", arg0) + m.ctrl.Call(m, "Seed", seed) } // Seed indicates an expected call of Seed -func (mr *MockJitterRandSourceMockRecorder) Seed(arg0 interface{}) *gomock.Call { +func (mr *MockJitterRandSourceMockRecorder) Seed(seed interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seed", reflect.TypeOf((*MockJitterRandSource)(nil).Seed), arg0) + return 
mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seed", reflect.TypeOf((*MockJitterRandSource)(nil).Seed), seed) } diff --git a/test/mocks/utils/time.go b/test/mocks/utils/time.go new file mode 100644 index 000000000..148bdc521 --- /dev/null +++ b/test/mocks/utils/time.go @@ -0,0 +1,61 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./src/utils/time.go + +// Package mock_utils is a generated GoMock package. +package mock_utils + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockTimeSource is a mock of TimeSource interface +type MockTimeSource struct { + ctrl *gomock.Controller + recorder *MockTimeSourceMockRecorder +} + +// MockTimeSourceMockRecorder is the mock recorder for MockTimeSource +type MockTimeSourceMockRecorder struct { + mock *MockTimeSource +} + +// NewMockTimeSource creates a new mock instance +func NewMockTimeSource(ctrl *gomock.Controller) *MockTimeSource { + mock := &MockTimeSource{ctrl: ctrl} + mock.recorder = &MockTimeSourceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockTimeSource) EXPECT() *MockTimeSourceMockRecorder { + return m.recorder +} + +// UnixNow mocks base method +func (m *MockTimeSource) UnixNow() int64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnixNow") + ret0, _ := ret[0].(int64) + return ret0 +} + +// UnixNow indicates an expected call of UnixNow +func (mr *MockTimeSourceMockRecorder) UnixNow() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnixNow", reflect.TypeOf((*MockTimeSource)(nil).UnixNow)) +} + +// UnixNanoNow mocks base method +func (m *MockTimeSource) UnixNanoNow() int64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnixNanoNow") + ret0, _ := ret[0].(int64) + return ret0 +} + +// UnixNanoNow indicates an expected call of UnixNanoNow +func (mr *MockTimeSourceMockRecorder) UnixNanoNow() *gomock.Call { + mr.mock.ctrl.T.Helper() + return 
mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnixNanoNow", reflect.TypeOf((*MockTimeSource)(nil).UnixNanoNow)) +} diff --git a/test/redis/bench_test.go b/test/redis/bench_test.go index 6c190ea7b..970cbcbea 100644 --- a/test/redis/bench_test.go +++ b/test/redis/bench_test.go @@ -2,18 +2,20 @@ package redis_test import ( "context" + "math/rand" "runtime" "testing" "time" pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + stats "github.com/lyft/gostats" + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/redis" + "github.com/envoyproxy/ratelimit/src/redis/driver" + "github.com/envoyproxy/ratelimit/src/settings" "github.com/envoyproxy/ratelimit/src/utils" - stats "github.com/lyft/gostats" - - "math/rand" - "github.com/envoyproxy/ratelimit/test/common" ) @@ -38,13 +40,21 @@ func BenchmarkParallelDoLimit(b *testing.B) { }) } - mkDoLimitBench := func(pipelineWindow time.Duration, pipelineLimit int) func(*testing.B) { + mkDoLimitBench := func(pipelineWindow time.Duration, pipelineLimit int, rateLimitAlgorithm string) func(*testing.B) { return func(b *testing.B) { statsStore := stats.NewStore(stats.NewNullSink(), false) - client := redis.NewClientImpl(statsStore, false, "", "single", "127.0.0.1:6379", poolSize, pipelineWindow, pipelineLimit) + client := driver.NewClientImpl(statsStore, false, "", "single", "127.0.0.1:6379", poolSize, pipelineWindow, pipelineLimit) defer client.Close() - cache := redis.NewFixedRateLimitCacheImpl(client, nil, utils.NewTimeSourceImpl(), rand.New(utils.NewLockedSource(time.Now().Unix())), 10, nil, 0.8, "") + var cache limiter.RateLimitCache + timeSource := utils.NewTimeSourceImpl() + if rateLimitAlgorithm == settings.FixedRateLimit { + cache = redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(utils.NewLockedSource(time.Now().Unix())), 10, nil, 0.8, "") + } else if rateLimitAlgorithm == settings.WindowedRateLimit { + cache = 
redis.NewWindowedRateLimitCacheImpl(client, nil, timeSource, rand.New(utils.NewLockedSource(time.Now().Unix())), 10, nil, 0.8, "") + } else { + b.Fatalf("unknown rate limit type %s", rateLimitAlgorithm) + } request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) limits := []*config.RateLimit{config.NewRateLimit(1000000000, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} @@ -65,30 +75,59 @@ func BenchmarkParallelDoLimit(b *testing.B) { } } - b.Run("no pipeline", mkDoLimitBench(0, 0)) - - b.Run("pipeline 35us 1", mkDoLimitBench(35*time.Microsecond, 1)) - b.Run("pipeline 75us 1", mkDoLimitBench(75*time.Microsecond, 1)) - b.Run("pipeline 150us 1", mkDoLimitBench(150*time.Microsecond, 1)) - b.Run("pipeline 300us 1", mkDoLimitBench(300*time.Microsecond, 1)) - - b.Run("pipeline 35us 2", mkDoLimitBench(35*time.Microsecond, 2)) - b.Run("pipeline 75us 2", mkDoLimitBench(75*time.Microsecond, 2)) - b.Run("pipeline 150us 2", mkDoLimitBench(150*time.Microsecond, 2)) - b.Run("pipeline 300us 2", mkDoLimitBench(300*time.Microsecond, 2)) - - b.Run("pipeline 35us 4", mkDoLimitBench(35*time.Microsecond, 4)) - b.Run("pipeline 75us 4", mkDoLimitBench(75*time.Microsecond, 4)) - b.Run("pipeline 150us 4", mkDoLimitBench(150*time.Microsecond, 4)) - b.Run("pipeline 300us 4", mkDoLimitBench(300*time.Microsecond, 4)) - - b.Run("pipeline 35us 8", mkDoLimitBench(35*time.Microsecond, 8)) - b.Run("pipeline 75us 8", mkDoLimitBench(75*time.Microsecond, 8)) - b.Run("pipeline 150us 8", mkDoLimitBench(150*time.Microsecond, 8)) - b.Run("pipeline 300us 8", mkDoLimitBench(300*time.Microsecond, 8)) - - b.Run("pipeline 35us 16", mkDoLimitBench(35*time.Microsecond, 16)) - b.Run("pipeline 75us 16", mkDoLimitBench(75*time.Microsecond, 16)) - b.Run("pipeline 150us 16", mkDoLimitBench(150*time.Microsecond, 16)) - b.Run("pipeline 300us 16", mkDoLimitBench(300*time.Microsecond, 16)) + // Fixed ratelimit + b.Run("fixed ratelimit with no pipeline", mkDoLimitBench(0, 0, 
settings.FixedRateLimit)) + + b.Run("fixed ratelimit with pipeline 35us 1", mkDoLimitBench(35*time.Microsecond, 1, settings.FixedRateLimit)) + b.Run("fixed ratelimit with pipeline 75us 1", mkDoLimitBench(75*time.Microsecond, 1, settings.FixedRateLimit)) + b.Run("fixed ratelimit with pipeline 150us 1", mkDoLimitBench(150*time.Microsecond, 1, settings.FixedRateLimit)) + b.Run("fixed ratelimit with pipeline 300us 1", mkDoLimitBench(300*time.Microsecond, 1, settings.FixedRateLimit)) + + b.Run("fixed ratelimit with pipeline 35us 2", mkDoLimitBench(35*time.Microsecond, 2, settings.FixedRateLimit)) + b.Run("fixed ratelimit with pipeline 75us 2", mkDoLimitBench(75*time.Microsecond, 2, settings.FixedRateLimit)) + b.Run("fixed ratelimit with pipeline 150us 2", mkDoLimitBench(150*time.Microsecond, 2, settings.FixedRateLimit)) + b.Run("fixed ratelimit with pipeline 300us 2", mkDoLimitBench(300*time.Microsecond, 2, settings.FixedRateLimit)) + + b.Run("fixed ratelimit with pipeline 35us 4", mkDoLimitBench(35*time.Microsecond, 4, settings.FixedRateLimit)) + b.Run("fixed ratelimit with pipeline 75us 4", mkDoLimitBench(75*time.Microsecond, 4, settings.FixedRateLimit)) + b.Run("fixed ratelimit with pipeline 150us 4", mkDoLimitBench(150*time.Microsecond, 4, settings.FixedRateLimit)) + b.Run("fixed ratelimit with pipeline 300us 4", mkDoLimitBench(300*time.Microsecond, 4, settings.FixedRateLimit)) + + b.Run("fixed ratelimit with pipeline 35us 8", mkDoLimitBench(35*time.Microsecond, 8, settings.FixedRateLimit)) + b.Run("fixed ratelimit with pipeline 75us 8", mkDoLimitBench(75*time.Microsecond, 8, settings.FixedRateLimit)) + b.Run("fixed ratelimit with pipeline 150us 8", mkDoLimitBench(150*time.Microsecond, 8, settings.FixedRateLimit)) + b.Run("fixed ratelimit with pipeline 300us 8", mkDoLimitBench(300*time.Microsecond, 8, settings.FixedRateLimit)) + + b.Run("fixed ratelimit with pipeline 35us 16", mkDoLimitBench(35*time.Microsecond, 16, settings.FixedRateLimit)) + b.Run("fixed ratelimit 
with pipeline 75us 16", mkDoLimitBench(75*time.Microsecond, 16, settings.FixedRateLimit)) + b.Run("fixed ratelimit with pipeline 150us 16", mkDoLimitBench(150*time.Microsecond, 16, settings.FixedRateLimit)) + b.Run("fixed ratelimit with pipeline 300us 16", mkDoLimitBench(300*time.Microsecond, 16, settings.FixedRateLimit)) + + // Windowed ratelimit + b.Run("windowed ratelimit with no pipeline", mkDoLimitBench(0, 0, settings.WindowedRateLimit)) + + b.Run("windowed ratelimit with pipeline 35us 1", mkDoLimitBench(35*time.Microsecond, 1, settings.WindowedRateLimit)) + b.Run("windowed ratelimit with pipeline 75us 1", mkDoLimitBench(75*time.Microsecond, 1, settings.WindowedRateLimit)) + b.Run("windowed ratelimit with pipeline 150us 1", mkDoLimitBench(150*time.Microsecond, 1, settings.WindowedRateLimit)) + b.Run("windowed ratelimit with pipeline 300us 1", mkDoLimitBench(300*time.Microsecond, 1, settings.WindowedRateLimit)) + + b.Run("windowed ratelimit with pipeline 35us 2", mkDoLimitBench(35*time.Microsecond, 2, settings.WindowedRateLimit)) + b.Run("windowed ratelimit with pipeline 75us 2", mkDoLimitBench(75*time.Microsecond, 2, settings.WindowedRateLimit)) + b.Run("windowed ratelimit with pipeline 150us 2", mkDoLimitBench(150*time.Microsecond, 2, settings.WindowedRateLimit)) + b.Run("windowed ratelimit with pipeline 300us 2", mkDoLimitBench(300*time.Microsecond, 2, settings.WindowedRateLimit)) + + b.Run("windowed ratelimit with pipeline 35us 4", mkDoLimitBench(35*time.Microsecond, 4, settings.WindowedRateLimit)) + b.Run("windowed ratelimit with pipeline 75us 4", mkDoLimitBench(75*time.Microsecond, 4, settings.WindowedRateLimit)) + b.Run("windowed ratelimit with pipeline 150us 4", mkDoLimitBench(150*time.Microsecond, 4, settings.WindowedRateLimit)) + b.Run("windowed ratelimit with pipeline 300us 4", mkDoLimitBench(300*time.Microsecond, 4, settings.WindowedRateLimit)) + + b.Run("windowed ratelimit with pipeline 35us 8", mkDoLimitBench(35*time.Microsecond, 8, 
settings.WindowedRateLimit)) + b.Run("windowed ratelimit with pipeline 75us 8", mkDoLimitBench(75*time.Microsecond, 8, settings.WindowedRateLimit)) + b.Run("windowed ratelimit with pipeline 150us 8", mkDoLimitBench(150*time.Microsecond, 8, settings.WindowedRateLimit)) + b.Run("windowed ratelimit with pipeline 300us 8", mkDoLimitBench(300*time.Microsecond, 8, settings.WindowedRateLimit)) + + b.Run("windowed ratelimit with pipeline 35us 16", mkDoLimitBench(35*time.Microsecond, 16, settings.WindowedRateLimit)) + b.Run("windowed ratelimit with pipeline 75us 16", mkDoLimitBench(75*time.Microsecond, 16, settings.WindowedRateLimit)) + b.Run("windowed ratelimit with pipeline 150us 16", mkDoLimitBench(150*time.Microsecond, 16, settings.WindowedRateLimit)) + b.Run("windowed ratelimit with pipeline 300us 16", mkDoLimitBench(300*time.Microsecond, 16, settings.WindowedRateLimit)) } diff --git a/test/redis/driver_impl_test.go b/test/redis/driver_impl_test.go index ab488e239..c1c5376ef 100644 --- a/test/redis/driver_impl_test.go +++ b/test/redis/driver_impl_test.go @@ -4,8 +4,9 @@ import ( "testing" "time" + "github.com/envoyproxy/ratelimit/src/redis/driver" + "github.com/alicebob/miniredis/v2" - "github.com/envoyproxy/ratelimit/src/redis" "github.com/lyft/gostats" "github.com/stretchr/testify/assert" ) @@ -35,8 +36,8 @@ func testNewClientImpl(t *testing.T, pipelineWindow time.Duration, pipelineLimit redisAuth := "123" statsStore := stats.NewStore(stats.NewNullSink(), false) - mkRedisClient := func(auth, addr string) redis.Client { - return redis.NewClientImpl(statsStore, false, auth, "single", addr, 1, pipelineWindow, pipelineLimit) + mkRedisClient := func(auth, addr string) driver.Client { + return driver.NewClientImpl(statsStore, false, auth, "single", addr, 1, pipelineWindow, pipelineLimit) } t.Run("connection refused", func(t *testing.T) { @@ -50,7 +51,7 @@ func testNewClientImpl(t *testing.T, pipelineWindow time.Duration, pipelineLimit redisSrv := mustNewRedisServer() defer 
redisSrv.Close() - var client redis.Client + var client driver.Client assert.NotPanics(t, func() { client = mkRedisClient("", redisSrv.Addr()) }) @@ -102,8 +103,8 @@ func TestNewClientImpl(t *testing.T) { func TestDoCmd(t *testing.T) { statsStore := stats.NewStore(stats.NewNullSink(), false) - mkRedisClient := func(addr string) redis.Client { - return redis.NewClientImpl(statsStore, false, "", "single", addr, 1, 0, 0) + mkRedisClient := func(addr string) driver.Client { + return driver.NewClientImpl(statsStore, false, "", "single", addr, 1, 0, 0) } t.Run("SETGET ok", func(t *testing.T) { @@ -147,8 +148,8 @@ func testPipeDo(t *testing.T, pipelineWindow time.Duration, pipelineLimit int) f return func(t *testing.T) { statsStore := stats.NewStore(stats.NewNullSink(), false) - mkRedisClient := func(addr string) redis.Client { - return redis.NewClientImpl(statsStore, false, "", "single", addr, 1, pipelineWindow, pipelineLimit) + mkRedisClient := func(addr string) driver.Client { + return driver.NewClientImpl(statsStore, false, "", "single", addr, 1, pipelineWindow, pipelineLimit) } t.Run("SETGET ok", func(t *testing.T) { @@ -158,7 +159,7 @@ func testPipeDo(t *testing.T, pipelineWindow time.Duration, pipelineLimit int) f client := mkRedisClient(redisSrv.Addr()) var res string - pipeline := redis.Pipeline{} + pipeline := driver.Pipeline{} pipeline = client.PipeAppend(pipeline, nil, "SET", "foo", "bar") pipeline = client.PipeAppend(pipeline, &res, "GET", "foo") @@ -174,10 +175,10 @@ func testPipeDo(t *testing.T, pipelineWindow time.Duration, pipelineLimit int) f var res uint32 hits := uint32(1) - assert.Nil(t, client.PipeDo(client.PipeAppend(redis.Pipeline{}, &res, "INCRBY", "a", hits))) + assert.Nil(t, client.PipeDo(client.PipeAppend(driver.Pipeline{}, &res, "INCRBY", "a", hits))) assert.Equal(t, hits, res) - assert.Nil(t, client.PipeDo(client.PipeAppend(redis.Pipeline{}, &res, "INCRBY", "a", hits))) + assert.Nil(t, client.PipeDo(client.PipeAppend(driver.Pipeline{}, &res, 
"INCRBY", "a", hits))) assert.Equal(t, uint32(2), res) }) @@ -185,7 +186,7 @@ func testPipeDo(t *testing.T, pipelineWindow time.Duration, pipelineLimit int) f redisSrv := mustNewRedisServer() client := mkRedisClient(redisSrv.Addr()) - assert.Nil(t, nil, client.PipeDo(client.PipeAppend(redis.Pipeline{}, nil, "SET", "foo", "bar"))) + assert.Nil(t, nil, client.PipeDo(client.PipeAppend(driver.Pipeline{}, nil, "SET", "foo", "bar"))) redisSrv.Close() @@ -194,7 +195,7 @@ func testPipeDo(t *testing.T, pipelineWindow time.Duration, pipelineLimit int) f assert.Contains(t, err.Error(), "EOF") } - expectErrContainEOF(t, client.PipeDo(client.PipeAppend(redis.Pipeline{}, nil, "GET", "foo"))) + expectErrContainEOF(t, client.PipeDo(client.PipeAppend(driver.Pipeline{}, nil, "GET", "foo"))) }) } } diff --git a/test/redis/fixed_cache_impl_test.go b/test/redis/fixed_cache_impl_test.go index 080d617d5..031c311c3 100644 --- a/test/redis/fixed_cache_impl_test.go +++ b/test/redis/fixed_cache_impl_test.go @@ -1,25 +1,24 @@ package redis_test import ( - "testing" - "github.com/coocood/freecache" + "github.com/golang/mock/gomock" + stats "github.com/lyft/gostats" "github.com/mediocregopher/radix/v3" + "github.com/stretchr/testify/assert" pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/redis" + redis_driver "github.com/envoyproxy/ratelimit/src/redis/driver" "github.com/envoyproxy/ratelimit/src/utils" - stats "github.com/lyft/gostats" - - "math/rand" - "github.com/envoyproxy/ratelimit/test/common" - mock_redis "github.com/envoyproxy/ratelimit/test/mocks/redis" + mock_driver "github.com/envoyproxy/ratelimit/test/mocks/redis/driver" mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" + + "math/rand" + "testing" ) func TestRedis(t *testing.T) { @@ -27,7 +26,7 @@ func 
TestRedis(t *testing.T) { t.Run("WithPerSecondRedis", testRedis(true)) } -func pipeAppend(pipeline redis.Pipeline, rcv interface{}, cmd, key string, args ...interface{}) redis.Pipeline { +func pipeAppend(pipeline redis_driver.Pipeline, rcv interface{}, cmd, key string, args ...interface{}) redis_driver.Pipeline { return append(pipeline, radix.FlatCmd(rcv, cmd, key, args...)) } @@ -37,48 +36,47 @@ func testRedis(usePerSecondRedis bool) func(*testing.T) { controller := gomock.NewController(t) defer controller.Finish() - client := mock_redis.NewMockClient(controller) - perSecondClient := mock_redis.NewMockClient(controller) + client := mock_driver.NewMockClient(controller) + perSecondClient := mock_driver.NewMockClient(controller) timeSource := mock_utils.NewMockTimeSource(controller) + + var clientUsed *mock_driver.MockClient + if usePerSecondRedis { + clientUsed = perSecondClient + } else { + clientUsed = client + } + var cache limiter.RateLimitCache if usePerSecondRedis { cache = redis.NewFixedRateLimitCacheImpl(client, perSecondClient, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "") } else { cache = redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "") } + statsStore := stats.NewStore(stats.NewNullSink(), false) + domain := "domain" + + // Test 1 + request := common.NewRateLimitRequest(domain, [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - var clientUsed *mock_redis.MockClient - if usePerSecondRedis { - clientUsed = perSecondClient - } else { - clientUsed = client - } - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key_value_1234", uint32(1)).SetArg(1, uint32(5)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key_value_1234", int64(1)).SetArg(1, 
int64(1)).DoAndReturn(pipeAppend) clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_1234", int64(1)).DoAndReturn(pipeAppend) clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - clientUsed = client - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key2_value2_subkey2_subvalue2_1200", uint32(1)).SetArg(1, uint32(11)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key2_value2_subkey2_subvalue2_1200", int64(60)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) - + // Test 2 request = common.NewRateLimitRequest( - "domain", + domain, [][][2]string{ {{"key2", "value2"}}, {{"key2", "value2"}, {"subkey2", "subvalue2"}}, @@ -86,26 +84,25 @@ func testRedis(usePerSecondRedis bool) func(*testing.T) { limits = []*config.RateLimit{ nil, config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key2_value2_subkey2_subvalue2", statsStore)} + + clientUsed = client + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + + 
clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key2_value2_subkey2_subvalue2_1200", int64(1)).SetArg(1, int64(11)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key2_value2_subkey2_subvalue2_1200", int64(60)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) + assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}, - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[1].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[1].Stats.TotalHits.Value()) assert.Equal(uint64(1), limits[1].Stats.OverLimit.Value()) assert.Equal(uint64(0), limits[1].Stats.NearLimit.Value()) - clientUsed = client - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(5) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key3_value3_997200", uint32(1)).SetArg(1, uint32(11)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key3_value3_997200", int64(3600)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key3_value3_subkey3_subvalue3_950400", uint32(1)).SetArg(1, uint32(13)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key3_value3_subkey3_subvalue3_950400", int64(86400)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) - + // Test 3 request = common.NewRateLimitRequest( - "domain", + domain, [][][2]string{ {{"key3", "value3"}}, {{"key3", "value3"}, {"subkey3", "subvalue3"}}, @@ -113,17 
+110,27 @@ func testRedis(usePerSecondRedis bool) func(*testing.T) { limits = []*config.RateLimit{ config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, "key3_value3", statsStore), config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_DAY, "key3_value3_subkey3_subvalue3", statsStore)} + + clientUsed = client + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(5) + + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key3_value3_997200", int64(1)).SetArg(1, int64(11)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key3_value3_997200", int64(3600)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key3_value3_subkey3_subvalue3_950400", int64(1)).SetArg(1, int64(13)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key3_value3_subkey3_subvalue3_950400", int64(86400)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) + assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}, - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[1].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(0), 
limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(1), limits[1].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[1].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[1].Stats.NearLimit.Value()) } } @@ -168,29 +175,29 @@ func TestOverLimitWithLocalCache(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - client := mock_redis.NewMockClient(controller) + client := mock_driver.NewMockClient(controller) timeSource := mock_utils.NewMockTimeSource(controller) localCache := freecache.NewCache(100) cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, localCache, 0.8, "") sink := &common.TestStatSink{} statsStore := stats.NewStore(sink, true) - localCacheStats := limiter.NewLocalCacheStats(localCache, statsStore.Scope("localcache")) + localCacheStats := utils.NewLocalCacheStats(localCache, statsStore.Scope("localcache")) + domain := "domain" // Test Near Limit Stats. 
Under Near Limit Ratio - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(11)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) - - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) - + request := common.NewRateLimitRequest(domain, [][][2]string{{{"key4", "value4"}}}, 1) limits := []*config.RateLimit{ config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, "key4_value4", statsStore)} + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", int64(1)).SetArg(1, int64(11)).DoAndReturn(pipeAppend) + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) + client.EXPECT().PipeDo(gomock.Any()).Return(nil) + assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -202,14 +209,14 @@ func TestOverLimitWithLocalCache(t *testing.T) { // Test Near Limit Stats. 
At Near Limit Ratio, still OK timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(13)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) + + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", int64(1)).SetArg(1, int64(13)).DoAndReturn(pipeAppend) + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) client.EXPECT().PipeDo(gomock.Any()).Return(nil) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -221,14 +228,14 @@ func TestOverLimitWithLocalCache(t *testing.T) { // Test Over limit stats timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(16)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) + + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", int64(1)).SetArg(1, int64(16)).DoAndReturn(pipeAppend) + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) 
client.EXPECT().PipeDo(gomock.Any()).Return(nil) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) @@ -240,12 +247,10 @@ func TestOverLimitWithLocalCache(t *testing.T) { // Test Over limit stats with local cache timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).Times(0) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).Times(0) + assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(4), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) @@ -261,26 +266,26 @@ func TestNearLimit(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - client := mock_redis.NewMockClient(controller) + client := mock_driver.NewMockClient(controller) timeSource := mock_utils.NewMockTimeSource(controller) cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "") statsStore := stats.NewStore(stats.NewNullSink(), 
false) + domain := "domain" // Test Near Limit Stats. Under Near Limit Ratio - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(11)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) - - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) - + request := common.NewRateLimitRequest(domain, [][][2]string{{{"key4", "value4"}}}, 1) limits := []*config.RateLimit{ config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, "key4_value4", statsStore)} + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", int64(1)).SetArg(1, int64(11)).DoAndReturn(pipeAppend) + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) + client.EXPECT().PipeDo(gomock.Any()).Return(nil) + assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -288,14 +293,14 @@ func TestNearLimit(t *testing.T) { // Test Near Limit Stats. 
At Near Limit Ratio, still OK timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(13)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) + + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", int64(1)).SetArg(1, int64(13)).DoAndReturn(pipeAppend) + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) client.EXPECT().PipeDo(gomock.Any()).Return(nil) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -304,14 +309,14 @@ func TestNearLimit(t *testing.T) { // Test Near Limit Stats. We went OVER_LIMIT, but the near_limit counter only increases // when we are near limit, not after we have passed the limit. 
timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(16)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) + + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", int64(1)).SetArg(1, int64(16)).DoAndReturn(pipeAppend) + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) client.EXPECT().PipeDo(gomock.Any()).Return(nil) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) @@ -319,96 +324,102 @@ func TestNearLimit(t *testing.T) { // Now test hitsAddend that is greater than 1 // All of it under limit, under near limit + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key5", "value5"}}}, 3) + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key5_value5", statsStore)} + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key5_value5_1234", uint32(3)).SetArg(1, uint32(5)).DoAndReturn(pipeAppend) + + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key5_value5_1234", int64(3)).SetArg(1, int64(5)).DoAndReturn(pipeAppend) client.EXPECT().PipeAppend(gomock.Any(), 
gomock.Any(), "EXPIRE", "domain_key5_value5_1234", int64(1)).DoAndReturn(pipeAppend) client.EXPECT().PipeDo(gomock.Any()).Return(nil) - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key5", "value5"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key5_value5", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 15, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 15, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) // All of it under limit, some over near limit + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key6", "value6"}}}, 2) + limits = []*config.RateLimit{config.NewRateLimit(8, pb.RateLimitResponse_RateLimit_SECOND, "key6_value6", statsStore)} + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key6_value6_1234", uint32(2)).SetArg(1, uint32(7)).DoAndReturn(pipeAppend) + + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key6_value6_1234", int64(2)).SetArg(1, int64(7)).DoAndReturn(pipeAppend) client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key6_value6_1234", int64(1)).DoAndReturn(pipeAppend) client.EXPECT().PipeDo(gomock.Any()).Return(nil) - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key6", "value6"}}}, 2) - limits = []*config.RateLimit{config.NewRateLimit(8, pb.RateLimitResponse_RateLimit_SECOND, "key6_value6", statsStore)} - assert.Equal( 
- []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) // All of it under limit, all of it over near limit + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key7", "value7"}}}, 3) + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key7_value7", statsStore)} + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key7_value7_1234", uint32(3)).SetArg(1, uint32(19)).DoAndReturn(pipeAppend) + + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key7_value7_1234", int64(3)).SetArg(1, int64(19)).DoAndReturn(pipeAppend) client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key7_value7_1234", int64(1)).DoAndReturn(pipeAppend) client.EXPECT().PipeDo(gomock.Any()).Return(nil) - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key7", "value7"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key7_value7", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: 
utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(3), limits[0].Stats.NearLimit.Value()) // Some of it over limit, all of it over near limit + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key8", "value8"}}}, 3) + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key8_value8", statsStore)} + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key8_value8_1234", uint32(3)).SetArg(1, uint32(22)).DoAndReturn(pipeAppend) + + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key8_value8_1234", int64(3)).SetArg(1, int64(22)).DoAndReturn(pipeAppend) client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key8_value8_1234", int64(1)).DoAndReturn(pipeAppend) client.EXPECT().PipeDo(gomock.Any()).Return(nil) - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key8", "value8"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key8_value8", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) // Some of it in all three places + request = 
common.NewRateLimitRequest("domain", [][][2]string{{{"key9", "value9"}}}, 7) + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key9_value9", statsStore)} + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key9_value9_1234", uint32(7)).SetArg(1, uint32(22)).DoAndReturn(pipeAppend) + + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key9_value9_1234", int64(7)).SetArg(1, int64(22)).DoAndReturn(pipeAppend) client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key9_value9_1234", int64(1)).DoAndReturn(pipeAppend) client.EXPECT().PipeDo(gomock.Any()).Return(nil) - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key9", "value9"}}}, 7) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key9_value9", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(7), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(4), limits[0].Stats.NearLimit.Value()) // all of it over limit + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key10", "value10"}}}, 3) + limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key10_value10", statsStore)} + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key10_value10_1234", 
uint32(3)).SetArg(1, uint32(30)).DoAndReturn(pipeAppend) + + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key10_value10_1234", int64(3)).SetArg(1, int64(30)).DoAndReturn(pipeAppend) client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key10_value10_1234", int64(1)).DoAndReturn(pipeAppend) client.EXPECT().PipeDo(gomock.Any()).Return(nil) - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key10", "value10"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key10_value10", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(3), limits[0].Stats.OverLimit.Value()) @@ -420,23 +431,26 @@ func TestRedisWithJitter(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - client := mock_redis.NewMockClient(controller) + client := mock_driver.NewMockClient(controller) timeSource := mock_utils.NewMockTimeSource(controller) jitterSource := mock_utils.NewMockJitterRandSource(controller) cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(jitterSource), 3600, nil, 0.8, "") statsStore := stats.NewStore(stats.NewNullSink(), false) + domain := "domain" + + request := common.NewRateLimitRequest(domain, [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - 
jitterSource.EXPECT().Int63().Return(int64(100)) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key_value_1234", uint32(1)).SetArg(1, uint32(5)).DoAndReturn(pipeAppend) + + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key_value_1234", int64(1)).SetArg(1, int64(5)).DoAndReturn(pipeAppend) client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_1234", int64(101)).DoAndReturn(pipeAppend) client.EXPECT().PipeDo(gomock.Any()).Return(nil) - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + jitterSource.EXPECT().Int63().Return(int64(100)) assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateFixedReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) diff --git a/test/redis/windowed_cache_impl_test.go b/test/redis/windowed_cache_impl_test.go new file mode 100644 index 000000000..90caf9b0c --- /dev/null +++ b/test/redis/windowed_cache_impl_test.go @@ -0,0 +1,464 @@ +package redis_test + +import ( + "math/rand" + "testing" + + "github.com/coocood/freecache" + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/src/redis" + "github.com/envoyproxy/ratelimit/src/utils" + "github.com/envoyproxy/ratelimit/test/common" + redis_driver_mock 
"github.com/envoyproxy/ratelimit/test/mocks/redis/driver" + mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" + + "github.com/golang/mock/gomock" + "github.com/golang/protobuf/ptypes/duration" + stats "github.com/lyft/gostats" + "github.com/stretchr/testify/assert" +) + +func TestRedisWindowed(t *testing.T) { + t.Run("WithoutPerSecondRedis", testRedisWindowed(false)) + t.Run("WithPerSecondRedis", testRedisWindowed(true)) +} + +func testRedisWindowed(usePerSecondRedis bool) func(*testing.T) { + return func(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + client := redis_driver_mock.NewMockClient(controller) + perSecondClient := redis_driver_mock.NewMockClient(controller) + timeSource := mock_utils.NewMockTimeSource(controller) + + var clientUsed *redis_driver_mock.MockClient + if usePerSecondRedis { + clientUsed = perSecondClient + + } else { + clientUsed = client + } + + var cache limiter.RateLimitCache + if usePerSecondRedis { + cache = redis.NewWindowedRateLimitCacheImpl(client, perSecondClient, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "") + } else { + cache = redis.NewWindowedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "") + } + statsStore := stats.NewStore(stats.NewNullSink(), false) + + // Test 1 + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1) + + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SETNX", "domain_key_value_0", int64(0)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_0", int64(1)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "GET", "domain_key_value_0").DoAndReturn(pipeAppend) 
+ clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) + + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SET", "domain_key_value_0", int64(1e9+1e8)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_0", int64(1)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: &duration.Duration{Nanos: 1e8}}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + // Test 2 + request = common.NewRateLimitRequest( + "domain", + [][][2]string{ + {{"key2", "value2"}}, + {{"key2", "value2"}, {"subkey2", "subvalue2"}}, + }, 1) + limits = []*config.RateLimit{ + nil, + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key2_value2_subkey2_subvalue2", statsStore)} + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1) + + clientUsed = client + + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SETNX", "domain_key2_value2_subkey2_subvalue2_0", int64(0)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key2_value2_subkey2_subvalue2_0", int64(60)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "GET", "domain_key2_value2_subkey2_subvalue2_0").DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) + + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SET", "domain_key2_value2_subkey2_subvalue2_0", int64(1e9+6e9)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key2_value2_subkey2_subvalue2_0", int64(7)).DoAndReturn(pipeAppend) + 
clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[1].Limit, LimitRemaining: 9, DurationUntilReset: &duration.Duration{Seconds: 6}}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[1].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[1].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[1].Stats.NearLimit.Value()) + + // Test 3 + request = common.NewRateLimitRequest( + "domain", + [][][2]string{ + {{"key3", "value3"}}, + {{"key3", "value3"}, {"subkey3", "subvalue3"}}, + }, 1) + limits = []*config.RateLimit{ + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, "key3_value3", statsStore), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_DAY, "key3_value3_subkey3_subvalue3", statsStore)} + + timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(2) + + clientUsed = client + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SETNX", "domain_key3_value3_0", int64(0)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key3_value3_0", int64(3600)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "GET", "domain_key3_value3_0").DoAndReturn(pipeAppend) + + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SETNX", "domain_key3_value3_subkey3_subvalue3_0", int64(0)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key3_value3_subkey3_subvalue3_0", int64(24*3600)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "GET", "domain_key3_value3_subkey3_subvalue3_0").DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) + + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SET", "domain_key3_value3_0", 
int64(361*1e9)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key3_value3_0", int64(361)).DoAndReturn(pipeAppend) + + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SET", "domain_key3_value3_subkey3_subvalue3_0", int64((24*360*1e9)+1e9)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key3_value3_subkey3_subvalue3_0", int64((24*360)+1)).DoAndReturn(pipeAppend) + clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: &duration.Duration{Seconds: 360}}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[1].Limit, LimitRemaining: 9, DurationUntilReset: &duration.Duration{Seconds: 24 * 360}}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(1), limits[1].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[1].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[1].Stats.NearLimit.Value()) + } +} + +func TestNearLimitWindowed(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + client := redis_driver_mock.NewMockClient(controller) + timeSource := mock_utils.NewMockTimeSource(controller) + cache := redis.NewWindowedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "") + statsStore := stats.NewStore(stats.NewNullSink(), false) + domain := "domain" + request := common.NewRateLimitRequest(domain, [][][2]string{{{"key4", "value4"}}}, 1) + limits := []*config.RateLimit{ + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key4_value4", statsStore)} + + // Test Near Limit Stats. 
Under Near Limit Ratio
+ // period = 1 minute = 60 second
+ // limit = 10 request/minute
+ // emissionInterval = 6 second
+ // request = 1
+ // increment = emissionInterval*request = 6 second
+
+ // arriveAt = 01 second
+ // tat = 01 second
+
+ // newTat should be max(arriveAt,tat)+increment = 7 second
+ // expire should be (newtat-arriveat)+1 = 7 second
+ // DurationUntilReset should be newtat-arriveat = 6 second
+ timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1)
+
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SETNX", "domain_key4_value4_0", int64(0)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key4_value4_0", int64(60)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "GET", "domain_key4_value4_0").SetArg(1, int64(1e9)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeDo(gomock.Any()).Return(nil)
+
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SET", "domain_key4_value4_0", int64(1e9+6e9)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key4_value4_0", int64(6+1)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeDo(gomock.Any()).Return(nil)
+
+ assert.Equal(
+ []*pb.RateLimitResponse_DescriptorStatus{
+ {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: &duration.Duration{Seconds: 6}}},
+ cache.DoLimit(nil, request, limits))
+ assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value())
+ assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value())
+ assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value())
+
+ // Test Near Limit Stats. 
At Near Limit Ratio, still OK
+ // period = 1 minute = 60 second
+ // limit = 10 request/minute
+ // emissionInterval = 6 second
+ // request = 1
+ // increment = emissionInterval*request = 6 second
+
+ // arriveAt = 07 second
+ // tat = 54 second
+
+ // newTat should be max(arriveAt,tat)+increment = 60 second
+ // expire should be (newtat-arriveat)+1 = 54 second
+ // DurationUntilReset should be newtat-arriveat = 53 second
+ timeSource.EXPECT().UnixNanoNow().Return(int64(7e9)).MaxTimes(1)
+
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SETNX", "domain_key4_value4_0", int64(0)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key4_value4_0", int64(60)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "GET", "domain_key4_value4_0").SetArg(1, int64(54e9)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeDo(gomock.Any()).Return(nil)
+
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SET", "domain_key4_value4_0", int64(54e9+6e9)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key4_value4_0", int64(60-7+1)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeDo(gomock.Any()).Return(nil)
+
+ assert.Equal(
+ []*pb.RateLimitResponse_DescriptorStatus{
+ {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: &duration.Duration{Seconds: 53}}},
+ cache.DoLimit(nil, request, limits))
+ assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value())
+ assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value())
+ assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value())
+
+ // Test Near Limit Stats. We went OVER_LIMIT, but the near_limit counter only increases
+ // when we are near limit, not after we have passed the limit. 
+
+ // period = 1 minute = 60 second
+ // limit = 10 request/minute
+ // emissionInterval = 6 second
+ // request = 1
+ // increment = emissionInterval*request = 6 second
+
+ // arriveAt = 04 second
+ // tat = 60 second
+
+ // newTat should be max(arriveAt,tat)+increment = 66 second
+ // expire should be (tat-arriveat)+1 = 57 second
+ // DurationUntilReset should be tat-arriveat = 56 second
+ timeSource.EXPECT().UnixNanoNow().Return(int64(4e9)).MaxTimes(1)
+
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SETNX", "domain_key4_value4_0", int64(0)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key4_value4_0", int64(60)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "GET", "domain_key4_value4_0").SetArg(1, int64(60e9)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeDo(gomock.Any()).Return(nil)
+
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SET", "domain_key4_value4_0", int64(60e9)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key4_value4_0", int64(60-4+1)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeDo(gomock.Any()).Return(nil)
+
+ assert.Equal(
+ []*pb.RateLimitResponse_DescriptorStatus{
+ {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: &duration.Duration{Seconds: 56}}},
+ cache.DoLimit(nil, request, limits))
+ assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value())
+ assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value())
+ assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value())
+}
+
+func TestWindowedOverLimitWithLocalCache(t *testing.T) {
+ assert := assert.New(t)
+ controller := gomock.NewController(t)
+ defer controller.Finish()
+
+ client := redis_driver_mock.NewMockClient(controller)
+ timeSource := mock_utils.NewMockTimeSource(controller)
+ localCache := freecache.NewCache(100)
+ cache := 
redis.NewWindowedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, localCache, 0.8, "")
+ sink := &common.TestStatSink{}
+ statsStore := stats.NewStore(sink, true)
+ domain := "domain"
+ localCacheStats := utils.NewLocalCacheStats(localCache, statsStore.Scope("localcache"))
+
+ request := common.NewRateLimitRequest(domain, [][][2]string{{{"key", "value"}}}, 1)
+ limits := []*config.RateLimit{
+ config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, "key_value", statsStore)}
+
+ // Test Near Limit Stats. Under Near Limit Ratio
+ // period = 60 minute = 3600 second
+ // limit = 10 request/hour
+ // emissionInterval = 6 minute
+ // request = 1
+ // increment = emissionInterval*request = 6 minute
+
+ // arriveAt = 1 minute
+ // tat = 12 minute
+
+ // newTat should be max(arriveAt,tat)+increment = 18 minute
+ // expire should be (newtat-arriveat)+1 second = 17 minute 1 second
+ // DurationUntilReset should be newtat-arriveat = 17 minute
+ timeSource.EXPECT().UnixNanoNow().Return(int64(1 * 60 * 1e9)).MaxTimes(1)
+
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SETNX", "domain_key_value_0", int64(0)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_0", int64(3600)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "GET", "domain_key_value_0").SetArg(1, int64(12*60*1e9)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeDo(gomock.Any()).Return(nil)
+
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SET", "domain_key_value_0", int64(18*60*1e9)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_0", int64((17*60)+1)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeDo(gomock.Any()).Return(nil)
+
+ assert.Equal(
+ []*pb.RateLimitResponse_DescriptorStatus{
+ {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 7, DurationUntilReset: &duration.Duration{Seconds: 
17 * 60}}},
+ cache.DoLimit(nil, request, limits))
+ assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value())
+ assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value())
+ assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value())
+ assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value())
+
+ // Check the local cache stats.
+ testLocalCacheStats(localCacheStats, statsStore, sink, 0, 1, 1, 0, 0)
+
+ // Test Near Limit Stats. At Near Limit Ratio, still OK
+ // period = 60 minute = 3600 second
+ // limit = 10 request/hour
+ // emissionInterval = 6 minute
+ // request = 1
+ // increment = emissionInterval*request = 6 minute
+
+ // arriveAt = 12 minute
+ // tat = 60 minute
+
+ // newTat should be max(arriveAt,tat)+increment = 66 minute
+ // expire should be (newtat-arriveat)+1 second = 54 minute 1 second
+ // DurationUntilReset should be newtat-arriveat = 54 minute
+ timeSource.EXPECT().UnixNanoNow().Return(int64(12 * 60 * 1e9)).MaxTimes(1)
+
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SETNX", "domain_key_value_0", int64(0)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_0", int64(3600)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "GET", "domain_key_value_0").SetArg(1, int64(60*60*1e9)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeDo(gomock.Any()).Return(nil)
+
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SET", "domain_key_value_0", int64(66*60*1e9)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_0", int64((54*60)+1)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeDo(gomock.Any()).Return(nil)
+
+ assert.Equal(
+ []*pb.RateLimitResponse_DescriptorStatus{
+ {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: &duration.Duration{Seconds: 54 * 60}}},
+ cache.DoLimit(nil, request, limits))
+ assert.Equal(uint64(2), 
limits[0].Stats.TotalHits.Value())
+ assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value())
+ assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value())
+ assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value())
+
+ // Check the local cache stats.
+ testLocalCacheStats(localCacheStats, statsStore, sink, 0, 2, 2, 0, 0)
+
+ // Test Over limit stats
+ // period = 60 minute = 3600 second
+ // limit = 10 request/hour
+ // emissionInterval = 6 minute
+ // request = 1
+ // increment = emissionInterval*request = 6 minute
+
+ // arriveAt = 2 minute
+ // tat = 72 minute
+
+ // newTat should be max(arriveAt,tat)+increment = 78 minute (not used)
+ // expire should be (tat-arriveat)+1 second = 70 minute 1 second
+ // DurationUntilReset should be tat-arriveat = 70 minute
+ timeSource.EXPECT().UnixNanoNow().Return(int64(2 * 60 * 1e9)).MaxTimes(1)
+
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SETNX", "domain_key_value_0", int64(0)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_0", int64(3600)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "GET", "domain_key_value_0").SetArg(1, int64(72*60*1e9)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeDo(gomock.Any()).Return(nil)
+
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SET", "domain_key_value_0", int64(72*60*1e9)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_0", int64((70*60)+1)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeDo(gomock.Any()).Return(nil)
+
+ assert.Equal(
+ []*pb.RateLimitResponse_DescriptorStatus{
+ {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: &duration.Duration{Seconds: 70 * 60}}},
+ cache.DoLimit(nil, request, limits))
+ assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value())
+ assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value())
+ 
assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value())
+ assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value())
+
+ // Check the local cache stats.
+ testLocalCacheStats(localCacheStats, statsStore, sink, 0, 2, 3, 0, 1)
+
+ // Test Over limit stats with local cache
+ // period = 60 minute = 3600 second
+ // limit = 10 request/hour
+ // emissionInterval = 6 minute
+ // request = 1
+ // increment = emissionInterval*request = 6 minute
+
+ // arriveAt = 3 minute
+ // tat = 72 minute
+
+ // newTat should be max(arriveAt,tat)+increment = 78 minute (not used)
+ // expire should be (tat-arriveat)+1 second = 69 minute 1 second
+ // DurationUntilReset should be secondsToReset-(arriveAt%secondsToReset) = 57 minute
+ timeSource.EXPECT().UnixNanoNow().Return(int64(3 * 60 * 1e9)).MaxTimes(1)
+
+ assert.Equal(
+ []*pb.RateLimitResponse_DescriptorStatus{
+ {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: &duration.Duration{Seconds: 57 * 60}}},
+ cache.DoLimit(nil, request, limits))
+ assert.Equal(uint64(4), limits[0].Stats.TotalHits.Value())
+ assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value())
+ assert.Equal(uint64(1), limits[0].Stats.OverLimitWithLocalCache.Value())
+ assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value())
+
+ // Check the local cache stats. 
+
+ testLocalCacheStats(localCacheStats, statsStore, sink, 1, 3, 4, 0, 1)
+}
+
+func TestRedisWindowedWithJitter(t *testing.T) {
+ assert := assert.New(t)
+ controller := gomock.NewController(t)
+ defer controller.Finish()
+
+ client := redis_driver_mock.NewMockClient(controller)
+ timeSource := mock_utils.NewMockTimeSource(controller)
+ jitterSource := mock_utils.NewMockJitterRandSource(controller)
+ cache := redis.NewWindowedRateLimitCacheImpl(client, nil, timeSource, rand.New(jitterSource), 3600, nil, 0.8, "")
+ statsStore := stats.NewStore(stats.NewNullSink(), false)
+
+ request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1)
+ limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)}
+
+ // period = 1 second
+ // limit = 10 request/second
+ // emissionInterval = 1/10 second
+ // request = 1
+ // increment = emissionInterval*request = 1/10 second
+
+ // arriveAt = 1 second
+ // tat = 1 second
+
+ // newTat should be max(arriveAt,tat)+increment = 1.1 second
+ // expire should be (tat-arriveat)+1 second = 1 second
+ // DurationUntilReset should be newTat-arriveat = 0.1 second
+
+ timeSource.EXPECT().UnixNanoNow().Return(int64(1e9)).MaxTimes(1)
+
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SETNX", "domain_key_value_0", int64(0)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_0", int64(1)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "GET", "domain_key_value_0").SetArg(1, int64(1e9)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeDo(gomock.Any()).Return(nil)
+
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "SET", "domain_key_value_0", int64(1e9+1e8)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_0", int64(101)).DoAndReturn(pipeAppend)
+ client.EXPECT().PipeDo(gomock.Any()).Return(nil)
+
+ 
jitterSource.EXPECT().Int63().Return(int64(100)) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: &duration.Duration{Nanos: 1e8}}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) +} diff --git a/test/service/ratelimit_legacy_test.go b/test/service/ratelimit_legacy_test.go index a51ddbe90..510e80232 100644 --- a/test/service/ratelimit_legacy_test.go +++ b/test/service/ratelimit_legacy_test.go @@ -10,7 +10,7 @@ import ( pb_legacy "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v2" pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" "github.com/envoyproxy/ratelimit/src/config" - "github.com/envoyproxy/ratelimit/src/redis" + "github.com/envoyproxy/ratelimit/src/redis/driver" "github.com/envoyproxy/ratelimit/src/service" "github.com/envoyproxy/ratelimit/test/common" "github.com/golang/mock/gomock" @@ -197,7 +197,7 @@ func TestCacheErrorLegacy(test *testing.T) { t.config.EXPECT().GetLimit(nil, "different-domain", req.Descriptors[0]).Return(limits[0]) t.cache.EXPECT().DoLimit(nil, req, limits).Do( func(context.Context, *pb.RateLimitRequest, []*config.RateLimit) { - panic(redis.RedisError("cache error")) + panic(driver.RedisError("cache error")) }) response, err := service.GetLegacyService().ShouldRateLimit(nil, legacyRequest) diff --git a/test/service/ratelimit_test.go b/test/service/ratelimit_test.go index 12c77926a..58c7b25d1 100644 --- a/test/service/ratelimit_test.go +++ b/test/service/ratelimit_test.go @@ -6,7 +6,7 @@ import ( pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" "github.com/envoyproxy/ratelimit/src/config" - "github.com/envoyproxy/ratelimit/src/redis" + "github.com/envoyproxy/ratelimit/src/redis/driver" ratelimit 
"github.com/envoyproxy/ratelimit/src/service" "github.com/envoyproxy/ratelimit/test/common" mock_config "github.com/envoyproxy/ratelimit/test/mocks/config" @@ -205,7 +205,7 @@ func TestCacheError(test *testing.T) { t.config.EXPECT().GetLimit(nil, "different-domain", request.Descriptors[0]).Return(limits[0]) t.cache.EXPECT().DoLimit(nil, request, limits).Do( func(context.Context, *pb.RateLimitRequest, []*config.RateLimit) { - panic(redis.RedisError("cache error")) + panic(driver.RedisError("cache error")) }) response, err := service.ShouldRateLimit(nil, request)