diff --git a/Dockerfile.integration b/Dockerfile.integration index efff81438..55eb04b4d 100644 --- a/Dockerfile.integration +++ b/Dockerfile.integration @@ -1,7 +1,7 @@ # Running this docker image runs the integration tests. FROM golang:1.14 -RUN apt-get update -y && apt-get install sudo stunnel4 redis -y && rm -rf /var/lib/apt/lists/* +RUN apt-get update -y && apt-get install sudo stunnel4 redis memcached -y && rm -rf /var/lib/apt/lists/* WORKDIR /workdir diff --git a/Makefile b/Makefile index 60cffa2d3..168aa1a65 100644 --- a/Makefile +++ b/Makefile @@ -97,6 +97,7 @@ tests_with_redis: bootstrap_redis_tls tests_unit mkdir 6389 && cd 6389 && redis-server --port 6389 --cluster-enabled yes --requirepass password123 & mkdir 6390 && cd 6390 && redis-server --port 6390 --cluster-enabled yes --requirepass password123 & mkdir 6391 && cd 6391 && redis-server --port 6391 --cluster-enabled yes --requirepass password123 & + memcached -u root --port 6394 -m 64 & sleep 2 echo "yes" | redis-cli --cluster create -a password123 127.0.0.1:6386 127.0.0.1:6387 127.0.0.1:6388 --cluster-replicas 0 echo "yes" | redis-cli --cluster create -a password123 127.0.0.1:6389 127.0.0.1:6390 127.0.0.1:6391 --cluster-replicas 0 diff --git a/README.md b/README.md index 807810a73..b545b1ff2 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ - [Pipelining](#pipelining) - [One Redis Instance](#one-redis-instance) - [Two Redis Instances](#two-redis-instances) +- [Memcache](#memcache) - [Contact](#contact) @@ -554,6 +555,21 @@ To configure two Redis instances use the following environment variables: This setup will use the Redis server configured with the `_PERSECOND_` vars for per second limits, and the other Redis server for all other limits. +# Memcache + +Experimental Memcache support has been added as an alternative to Redis in v1.5. + +To configure a Memcache instance use the following environment variables instead of the Redis variables: + +1. `MEMCACHE_HOST_PORT=` +1. `BACKEND_TYPE=memcache` + +With memcache mode increments will happen asynchronously, so it's technically possible for +a client to exceed quota briefly if multiple requests happen at exactly the same time. + +Note that Memcache has a max key length of 250 characters, so operations referencing very long +descriptors will fail. + # Contact * [envoy-announce](https://groups.google.com/forum/#!forum/envoy-announce): Low frequency mailing diff --git a/docker-compose.yml b/docker-compose.yml index ac1ab9063..88d3a86e1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,6 +10,15 @@ services: networks: - ratelimit-network + memcached: + image: memcached:alpine + expose: + - 11211 + ports: + - 11211:11211 + networks: + - ratelimit-network + # minimal container that builds the ratelimit service binary and exits. ratelimit-build: image: golang:1.14-alpine @@ -51,6 +60,7 @@ services: - REDIS_URL=redis:6379 - RUNTIME_ROOT=/data - RUNTIME_SUBDIRECTORY=ratelimit + - MEMCACHE_HOST_PORT=memcached:11211 networks: ratelimit-network: diff --git a/go.mod b/go.mod index 5a82cf7e0..ed885178a 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.14 require ( github.com/alicebob/miniredis/v2 v2.11.4 + github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b github.com/cespare/xxhash v1.1.0 // indirect github.com/coocood/freecache v1.1.0 github.com/envoyproxy/go-control-plane v0.9.7 diff --git a/go.sum b/go.sum index b39e91b79..f8b59961e 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6 h1:45bxf7AZMw github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= github.com/alicebob/miniredis/v2 v2.11.4 h1:GsuyeunTx7EllZBU3/6Ji3dhMQZDpC9rLf1luJ+6M5M= github.com/alicebob/miniredis/v2 v2.11.4/go.mod h1:VL3UDEfAH59bSa7MuHMuFToxkqyHh69s/WUbYlOAuyg= +github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b h1:L/QXpzIa3pOvUGt1D1lA5KjYhPBAN/3iWdP7xeFS9F0= +github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b/go.mod h1:H0wQNHz2YrLsuXOZozoeDmnHXkNCRmMW0gwFWDfEZDA= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= diff --git a/src/limiter/cache.go b/src/limiter/cache.go index 9408126ca..5ca07edea 100644 --- a/src/limiter/cache.go +++ b/src/limiter/cache.go @@ -6,20 +6,6 @@ import ( "golang.org/x/net/context" ) -// Interface for a time source. -type TimeSource interface { - // @return the current unix time in seconds. - UnixNow() int64 -} - -// Interface for a rand Source for expiration jitter. -type JitterRandSource interface { - // @return a non-negative pseudo-random 63-bit integer as an int64. - Int63() int64 - // @param seed initializes pseudo-random generator to a deterministic state. - Seed(seed int64) -} - // Interface for interacting with a cache backend for rate limiting. type RateLimitCache interface { // Contact the cache and perform rate limiting for a set of descriptors and limits. @@ -35,4 +21,8 @@ type RateLimitCache interface { ctx context.Context, request *pb.RateLimitRequest, limits []*config.RateLimit) []*pb.RateLimitResponse_DescriptorStatus + + // Waits for any unfinished asynchronous work. This may be used by unit tests, + // since the memcache cache does increments in a background gorountine. + Flush() } diff --git a/src/memcached/cache_impl.go b/src/memcached/cache_impl.go new file mode 100644 index 000000000..52f8fae47 --- /dev/null +++ b/src/memcached/cache_impl.go @@ -0,0 +1,297 @@ +// The memcached limiter uses GetMulti() to check keys in parallel and then does +// increments asynchronously in the backend, since the memcache interface doesn't +// support multi-increment and it seems worthwhile to minimize the number of +// concurrent or sequential RPCs in the critical path. +// +// Another difference from redis is that memcache doesn't create a key implicitly by +// incrementing a missing entry. Instead, when increment fails an explicit "add" needs +// to be called. The process of increment becomes a bit of a dance since we try to +// limit the number of RPCs. First we call increment, then add if the increment +// failed, then increment again if the add failed (which could happen if there was +// a race to call "add"). +// +// Note that max memcache key length is 250 characters. Attempting to get or increment +// a longer key will return memcache.ErrMalformedKey + +package memcached + +import ( + "context" + "math" + "math/rand" + "strconv" + "sync" + + "github.com/coocood/freecache" + stats "github.com/lyft/gostats" + + "github.com/bradfitz/gomemcache/memcache" + + logger "github.com/sirupsen/logrus" + + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + + "github.com/envoyproxy/ratelimit/src/assert" + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/src/settings" + "github.com/envoyproxy/ratelimit/src/utils" +) + +type rateLimitMemcacheImpl struct { + client Client + timeSource utils.TimeSource + jitterRand *rand.Rand + expirationJitterMaxSeconds int64 + cacheKeyGenerator limiter.CacheKeyGenerator + localCache *freecache.Cache + waitGroup sync.WaitGroup + nearLimitRatio float32 +} + +var _ limiter.RateLimitCache = (*rateLimitMemcacheImpl)(nil) + +func max(a uint32, b uint32) uint32 { + if a > b { + return a + } + return b +} + +func (this *rateLimitMemcacheImpl) DoLimit( + ctx context.Context, + request *pb.RateLimitRequest, + limits []*config.RateLimit) []*pb.RateLimitResponse_DescriptorStatus { + + logger.Debugf("starting cache lookup") + + // request.HitsAddend could be 0 (default value) if not specified by the caller in the Ratelimit request. + hitsAddend := max(1, request.HitsAddend) + + // First build a list of all cache keys that we are actually going to hit. generateCacheKey() + // returns an empty string in the key if there is no limit so that we can keep the arrays + // all the same size. + assert.Assert(len(request.Descriptors) == len(limits)) + cacheKeys := make([]limiter.CacheKey, len(request.Descriptors)) + now := this.timeSource.UnixNow() + for i := 0; i < len(request.Descriptors); i++ { + cacheKeys[i] = this.cacheKeyGenerator.GenerateCacheKey(request.Domain, request.Descriptors[i], limits[i], now) + + // Increase statistics for limits hit by their respective requests. + if limits[i] != nil { + limits[i].Stats.TotalHits.Add(uint64(hitsAddend)) + } + } + + isOverLimitWithLocalCache := make([]bool, len(request.Descriptors)) + + keysToGet := make([]string, 0, len(request.Descriptors)) + + for i, cacheKey := range cacheKeys { + if cacheKey.Key == "" { + continue + } + + if this.localCache != nil { + // Get returns the value or not found error. + _, err := this.localCache.Get([]byte(cacheKey.Key)) + if err == nil { + isOverLimitWithLocalCache[i] = true + logger.Debugf("cache key is over the limit: %s", cacheKey.Key) + continue + } + } + + logger.Debugf("looking up cache key: %s", cacheKey.Key) + keysToGet = append(keysToGet, cacheKey.Key) + } + + // Now fetch from memcache. + responseDescriptorStatuses := make([]*pb.RateLimitResponse_DescriptorStatus, + len(request.Descriptors)) + + var memcacheValues map[string]*memcache.Item + var err error + + if len(keysToGet) > 0 { + memcacheValues, err = this.client.GetMulti(keysToGet) + if err != nil { + logger.Errorf("Error multi-getting memcache keys (%s): %s", keysToGet, err) + } + } + + for i, cacheKey := range cacheKeys { + if cacheKey.Key == "" { + responseDescriptorStatuses[i] = + &pb.RateLimitResponse_DescriptorStatus{ + Code: pb.RateLimitResponse_OK, + CurrentLimit: nil, + LimitRemaining: 0, + } + continue + } + + if isOverLimitWithLocalCache[i] { + responseDescriptorStatuses[i] = + &pb.RateLimitResponse_DescriptorStatus{ + Code: pb.RateLimitResponse_OVER_LIMIT, + CurrentLimit: limits[i].Limit, + LimitRemaining: 0, + DurationUntilReset: utils.CalculateReset(limits[i].Limit, this.timeSource), + } + limits[i].Stats.OverLimit.Add(uint64(hitsAddend)) + limits[i].Stats.OverLimitWithLocalCache.Add(uint64(hitsAddend)) + continue + } + + rawMemcacheValue, ok := memcacheValues[cacheKey.Key] + var limitBeforeIncrease uint32 + if ok { + decoded, err := strconv.ParseInt(string(rawMemcacheValue.Value), 10, 32) + if err != nil { + logger.Errorf("Unexpected non-numeric value in memcached: %v", rawMemcacheValue) + } else { + limitBeforeIncrease = uint32(decoded) + } + + } + + limitAfterIncrease := limitBeforeIncrease + hitsAddend + overLimitThreshold := limits[i].Limit.RequestsPerUnit + // The nearLimitThreshold is the number of requests that can be made before hitting the NearLimitRatio. + // We need to know it in both the OK and OVER_LIMIT scenarios. + nearLimitThreshold := uint32(math.Floor(float64(float32(overLimitThreshold) * this.nearLimitRatio))) + + logger.Debugf("cache key: %s current: %d", cacheKey.Key, limitAfterIncrease) + if limitAfterIncrease > overLimitThreshold { + responseDescriptorStatuses[i] = + &pb.RateLimitResponse_DescriptorStatus{ + Code: pb.RateLimitResponse_OVER_LIMIT, + CurrentLimit: limits[i].Limit, + LimitRemaining: 0, + DurationUntilReset: utils.CalculateReset(limits[i].Limit, this.timeSource), + } + + // Increase over limit statistics. Because we support += behavior for increasing the limit, we need to + // assess if the entire hitsAddend were over the limit. That is, if the limit's value before adding the + // N hits was over the limit, then all the N hits were over limit. + // Otherwise, only the difference between the current limit value and the over limit threshold + // were over limit hits. + if limitBeforeIncrease >= overLimitThreshold { + limits[i].Stats.OverLimit.Add(uint64(hitsAddend)) + } else { + limits[i].Stats.OverLimit.Add(uint64(limitAfterIncrease - overLimitThreshold)) + + // If the limit before increase was below the over limit value, then some of the hits were + // in the near limit range. + limits[i].Stats.NearLimit.Add(uint64(overLimitThreshold - max(nearLimitThreshold, limitBeforeIncrease))) + } + if this.localCache != nil { + // Set the TTL of the local_cache to be the entire duration. + // Since the cache_key gets changed once the time crosses over current time slot, the over-the-limit + // cache keys in local_cache lose effectiveness. + // For example, if we have an hour limit on all mongo connections, the cache key would be + // similar to mongo_1h, mongo_2h, etc. In the hour 1 (0h0m - 0h59m), the cache key is mongo_1h, we start + // to get ratelimited in the 50th minute, the ttl of local_cache will be set as 1 hour(0h50m-1h49m). + // In the time of 1h1m, since the cache key becomes different (mongo_2h), it won't get ratelimited. + err := this.localCache.Set([]byte(cacheKey.Key), []byte{}, int(utils.UnitToDivider(limits[i].Limit.Unit))) + if err != nil { + logger.Errorf("Failing to set local cache key: %s", cacheKey.Key) + } + } + } else { + responseDescriptorStatuses[i] = + &pb.RateLimitResponse_DescriptorStatus{ + Code: pb.RateLimitResponse_OK, + CurrentLimit: limits[i].Limit, + LimitRemaining: overLimitThreshold - limitAfterIncrease, + DurationUntilReset: utils.CalculateReset(limits[i].Limit, this.timeSource), + } + + // The limit is OK but we additionally want to know if we are near the limit. + if limitAfterIncrease > nearLimitThreshold { + // Here we also need to assess which portion of the hitsAddend were in the near limit range. + // If all the hits were over the nearLimitThreshold, then all hits are near limit. Otherwise, + // only the difference between the current limit value and the near limit threshold were near + // limit hits. + if limitBeforeIncrease >= nearLimitThreshold { + limits[i].Stats.NearLimit.Add(uint64(hitsAddend)) + } else { + limits[i].Stats.NearLimit.Add(uint64(limitAfterIncrease - nearLimitThreshold)) + } + } + } + } + + this.waitGroup.Add(1) + go this.increaseAsync(cacheKeys, isOverLimitWithLocalCache, limits, uint64(hitsAddend)) + + return responseDescriptorStatuses +} + +func (this *rateLimitMemcacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, isOverLimitWithLocalCache []bool, limits []*config.RateLimit, hitsAddend uint64) { + defer this.waitGroup.Done() + for i, cacheKey := range cacheKeys { + if cacheKey.Key == "" || isOverLimitWithLocalCache[i] { + continue + } + + _, err := this.client.Increment(cacheKey.Key, hitsAddend) + if err == memcache.ErrCacheMiss { + expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) + if this.expirationJitterMaxSeconds > 0 { + expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) + } + + // Need to add instead of increment + err = this.client.Add(&memcache.Item{ + Key: cacheKey.Key, + Value: []byte(strconv.FormatUint(hitsAddend, 10)), + Expiration: int32(expirationSeconds), + }) + if err == memcache.ErrNotStored { + // There was a race condition to do this add. We should be able to increment + // now instead. + _, err := this.client.Increment(cacheKey.Key, hitsAddend) + if err != nil { + logger.Errorf("Failed to increment key %s after failing to add: %s", cacheKey.Key, err) + continue + } + } else if err != nil { + logger.Errorf("Failed to add key %s: %s", cacheKey.Key, err) + continue + } + } else if err != nil { + logger.Errorf("Failed to increment key %s: %s", cacheKey.Key, err) + continue + } + } +} + +func (this *rateLimitMemcacheImpl) Flush() { + this.waitGroup.Wait() +} + +func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, scope stats.Scope, nearLimitRatio float32) limiter.RateLimitCache { + return &rateLimitMemcacheImpl{ + client: client, + timeSource: timeSource, + cacheKeyGenerator: limiter.NewCacheKeyGenerator(), + jitterRand: jitterRand, + expirationJitterMaxSeconds: expirationJitterMaxSeconds, + localCache: localCache, + nearLimitRatio: nearLimitRatio, + } +} + +func NewRateLimitCacheImplFromSettings(s settings.Settings, timeSource utils.TimeSource, jitterRand *rand.Rand, localCache *freecache.Cache, scope stats.Scope) limiter.RateLimitCache { + return NewRateLimitCacheImpl( + memcache.New(s.MemcacheHostPort), + timeSource, + jitterRand, + s.ExpirationJitterMaxSeconds, + localCache, + scope, + s.NearLimitRatio, + ) +} diff --git a/src/memcached/client.go b/src/memcached/client.go new file mode 100644 index 000000000..55c0ec318 --- /dev/null +++ b/src/memcached/client.go @@ -0,0 +1,14 @@ +package memcached + +import ( + "github.com/bradfitz/gomemcache/memcache" +) + +var _ Client = (*memcache.Client)(nil) + +// Interface for memcached, used for mocking. +type Client interface { + GetMulti(keys []string) (map[string]*memcache.Item, error) + Increment(key string, delta uint64) (newValue uint64, err error) + Add(item *memcache.Item) error +} diff --git a/src/redis/cache_impl.go b/src/redis/cache_impl.go index 28607c660..c65447da8 100644 --- a/src/redis/cache_impl.go +++ b/src/redis/cache_impl.go @@ -7,9 +7,10 @@ import ( "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/server" "github.com/envoyproxy/ratelimit/src/settings" + "github.com/envoyproxy/ratelimit/src/utils" ) -func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource limiter.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64) limiter.RateLimitCache { +func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64) limiter.RateLimitCache { var perSecondPool Client if s.RedisPerSecond { perSecondPool = NewClientImpl(srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondTls, s.RedisPerSecondAuth, diff --git a/src/redis/fixed_cache_impl.go b/src/redis/fixed_cache_impl.go index a6fd067a0..6ecb53092 100644 --- a/src/redis/fixed_cache_impl.go +++ b/src/redis/fixed_cache_impl.go @@ -10,7 +10,6 @@ import ( "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/utils" - "github.com/golang/protobuf/ptypes/duration" logger "github.com/sirupsen/logrus" "golang.org/x/net/context" ) @@ -22,7 +21,7 @@ type fixedRateLimitCacheImpl struct { // limits regardless of unit. If this client is not nil, then it // is used for limits that have a SECOND unit. perSecondClient Client - timeSource limiter.TimeSource + timeSource utils.TimeSource jitterRand *rand.Rand expirationJitterMaxSeconds int64 cacheKeyGenerator limiter.CacheKeyGenerator @@ -136,7 +135,7 @@ func (this *fixedRateLimitCacheImpl) DoLimit( Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[i].Limit, LimitRemaining: 0, - DurationUntilReset: CalculateReset(limits[i].Limit, this.timeSource), + DurationUntilReset: utils.CalculateReset(limits[i].Limit, this.timeSource), } limits[i].Stats.OverLimit.Add(uint64(hitsAddend)) limits[i].Stats.OverLimitWithLocalCache.Add(uint64(hitsAddend)) @@ -157,7 +156,7 @@ func (this *fixedRateLimitCacheImpl) DoLimit( Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[i].Limit, LimitRemaining: 0, - DurationUntilReset: CalculateReset(limits[i].Limit, this.timeSource), + DurationUntilReset: utils.CalculateReset(limits[i].Limit, this.timeSource), } // Increase over limit statistics. Because we support += behavior for increasing the limit, we need to @@ -193,7 +192,7 @@ func (this *fixedRateLimitCacheImpl) DoLimit( Code: pb.RateLimitResponse_OK, CurrentLimit: limits[i].Limit, LimitRemaining: overLimitThreshold - limitAfterIncrease, - DurationUntilReset: CalculateReset(limits[i].Limit, this.timeSource), + DurationUntilReset: utils.CalculateReset(limits[i].Limit, this.timeSource), } // The limit is OK but we additionally want to know if we are near the limit. @@ -214,13 +213,10 @@ func (this *fixedRateLimitCacheImpl) DoLimit( return responseDescriptorStatuses } -func CalculateReset(currentLimit *pb.RateLimitResponse_RateLimit, timeSource limiter.TimeSource) *duration.Duration { - sec := utils.UnitToDivider(currentLimit.Unit) - now := timeSource.UnixNow() - return &duration.Duration{Seconds: sec - now%sec} -} +// Flush() is a no-op with redis since quota reads and updates happen synchronously. +func (this *fixedRateLimitCacheImpl) Flush() {} -func NewFixedRateLimitCacheImpl(client Client, perSecondClient Client, timeSource limiter.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, nearLimitRatio float32) limiter.RateLimitCache { +func NewFixedRateLimitCacheImpl(client Client, perSecondClient Client, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, nearLimitRatio float32) limiter.RateLimitCache { return &fixedRateLimitCacheImpl{ client: client, perSecondClient: perSecondClient, diff --git a/src/service_cmd/runner/runner.go b/src/service_cmd/runner/runner.go index 80e8e7814..7bde5e9a6 100644 --- a/src/service_cmd/runner/runner.go +++ b/src/service_cmd/runner/runner.go @@ -16,10 +16,12 @@ import ( "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/src/memcached" "github.com/envoyproxy/ratelimit/src/redis" "github.com/envoyproxy/ratelimit/src/server" ratelimit "github.com/envoyproxy/ratelimit/src/service" "github.com/envoyproxy/ratelimit/src/settings" + "github.com/envoyproxy/ratelimit/src/utils" logger "github.com/sirupsen/logrus" ) @@ -35,6 +37,29 @@ func (runner *Runner) GetStatsStore() stats.Store { return runner.statsStore } +func createLimiter(srv server.Server, s settings.Settings, localCache *freecache.Cache) limiter.RateLimitCache { + switch s.BackendType { + case "redis", "": + return redis.NewRateLimiterCacheImplFromSettings( + s, + localCache, + srv, + utils.NewTimeSourceImpl(), + rand.New(utils.NewLockedSource(time.Now().Unix())), + s.ExpirationJitterMaxSeconds) + case "memcache": + return memcached.NewRateLimitCacheImplFromSettings( + s, + utils.NewTimeSourceImpl(), + rand.New(utils.NewLockedSource(time.Now().Unix())), + localCache, + srv.Scope()) + default: + logger.Fatalf("Invalid setting for BackendType: %s", s.BackendType) + panic("This line should not be reachable") + } +} + func (runner *Runner) Run() { s := settings.NewSettings() @@ -63,13 +88,7 @@ func (runner *Runner) Run() { service := ratelimit.NewService( srv.Runtime(), - redis.NewRateLimiterCacheImplFromSettings( - s, - localCache, - srv, - limiter.NewTimeSourceImpl(), - rand.New(limiter.NewLockedSource(time.Now().Unix())), - s.ExpirationJitterMaxSeconds), + createLimiter(srv, s, localCache), config.NewRateLimitConfigLoaderImpl(), srv.Scope().Scope("service"), s.RuntimeWatchRoot, diff --git a/src/settings/settings.go b/src/settings/settings.go index 0d630858e..38a2474c2 100644 --- a/src/settings/settings.go +++ b/src/settings/settings.go @@ -43,6 +43,8 @@ type Settings struct { ExpirationJitterMaxSeconds int64 `envconfig:"EXPIRATION_JITTER_MAX_SECONDS" default:"300"` LocalCacheSizeInBytes int `envconfig:"LOCAL_CACHE_SIZE_IN_BYTES" default:"0"` NearLimitRatio float32 `envconfig:"NEAR_LIMIT_RATIO" default:"0.8"` + MemcacheHostPort string `envconfig:"MEMCACHE_HOST_PORT" default:""` + BackendType string `envconfig:"BACKEND_TYPE" default:"redis"` } type Option func(*Settings) diff --git a/src/limiter/time.go b/src/utils/time.go similarity index 67% rename from src/limiter/time.go rename to src/utils/time.go index e6a779e70..e7978cc6c 100644 --- a/src/limiter/time.go +++ b/src/utils/time.go @@ -1,4 +1,4 @@ -package limiter +package utils import ( "math/rand" @@ -6,6 +6,14 @@ import ( "time" ) +// Interface for a rand Source for expiration jitter. +type JitterRandSource interface { + // @return a non-negative pseudo-random 63-bit integer as an int64. + Int63() int64 + // @param seed initializes pseudo-random generator to a deterministic state. + Seed(seed int64) +} + type timeSourceImpl struct{} func NewTimeSourceImpl() TimeSource { diff --git a/src/utils/utilities.go b/src/utils/utilities.go index b28619dd5..8bfbb641f 100644 --- a/src/utils/utilities.go +++ b/src/utils/utilities.go @@ -2,8 +2,15 @@ package utils import ( pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/golang/protobuf/ptypes/duration" ) +// Interface for a time source. +type TimeSource interface { + // @return the current unix time in seconds. + UnixNow() int64 +} + // Convert a rate limit into a time divider. // @param unit supplies the unit to convert. // @return the divider to use in time computations. @@ -21,3 +28,9 @@ func UnitToDivider(unit pb.RateLimitResponse_RateLimit_Unit) int64 { panic("should not get here") } + +func CalculateReset(currentLimit *pb.RateLimitResponse_RateLimit, timeSource TimeSource) *duration.Duration { + sec := UnitToDivider(currentLimit.Unit) + now := timeSource.UnixNow() + return &duration.Duration{Seconds: sec - now%sec} +} diff --git a/test/integration/integration_test.go b/test/integration/integration_test.go index b76c8ada6..49c85bef5 100644 --- a/test/integration/integration_test.go +++ b/test/integration/integration_test.go @@ -47,10 +47,10 @@ func newDescriptorStatusLegacy( // TODO: Once adding the ability of stopping the server in the runner (https://github.com/envoyproxy/ratelimit/issues/119), // stop the server at the end of each test, thus we can reuse the grpc port among these integration tests. func TestBasicConfig(t *testing.T) { - t.Run("WithoutPerSecondRedis", testBasicConfig("8083", "false", "0")) - t.Run("WithPerSecondRedis", testBasicConfig("8085", "true", "0")) - t.Run("WithoutPerSecondRedisWithLocalCache", testBasicConfig("18083", "false", "1000")) - t.Run("WithPerSecondRedisWithLocalCache", testBasicConfig("18085", "true", "1000")) + t.Run("WithoutPerSecondRedis", testBasicConfig("8083", "false", "0", "")) + t.Run("WithPerSecondRedis", testBasicConfig("8085", "true", "0", "redis")) + t.Run("WithoutPerSecondRedisWithLocalCache", testBasicConfig("18083", "false", "1000", "")) + t.Run("WithPerSecondRedisWithLocalCache", testBasicConfig("18085", "true", "1000", "redis")) } func TestBasicTLSConfig(t *testing.T) { @@ -86,6 +86,11 @@ func TestBasicReloadConfig(t *testing.T) { t.Run("ReloadWithoutWatchRoot", testBasicConfigReload("8097", "false", "0", "false")) } +func TestBasicConfigMemcache(t *testing.T) { + t.Run("Memcache", testBasicConfig("8098", "false", "0", "memcache")) + t.Run("MemcacheWithLocalCache", testBasicConfig("18099", "false", "1000", "memcache")) +} + func testBasicConfigAuthTLS(grpcPort, perSecond string, local_cache_size string) func(*testing.T) { os.Setenv("REDIS_PERSECOND_URL", "localhost:16382") os.Setenv("REDIS_URL", "localhost:16381") @@ -96,12 +101,13 @@ func testBasicConfigAuthTLS(grpcPort, perSecond string, local_cache_size string) os.Setenv("REDIS_TYPE", "single") os.Setenv("REDIS_PERSECOND_TYPE", "single") - return testBasicBaseConfig(grpcPort, perSecond, local_cache_size) + return testBasicBaseConfig(grpcPort, perSecond, local_cache_size, "") } -func testBasicConfig(grpcPort, perSecond string, local_cache_size string) func(*testing.T) { +func testBasicConfig(grpcPort, perSecond string, local_cache_size string, backend_type string) func(*testing.T) { os.Setenv("REDIS_PERSECOND_URL", "localhost:6380") os.Setenv("REDIS_URL", "localhost:6379") + os.Setenv("MEMCACHE_HOST_PORT", "localhost:6394") os.Setenv("REDIS_AUTH", "") os.Setenv("REDIS_TLS", "false") os.Setenv("REDIS_PERSECOND_AUTH", "") @@ -109,7 +115,7 @@ func testBasicConfig(grpcPort, perSecond string, local_cache_size string) func(* os.Setenv("REDIS_TYPE", "single") os.Setenv("REDIS_PERSECOND_TYPE", "single") - return testBasicBaseConfig(grpcPort, perSecond, local_cache_size) + return testBasicBaseConfig(grpcPort, perSecond, local_cache_size, backend_type) } func testBasicConfigAuth(grpcPort, perSecond string, local_cache_size string) func(*testing.T) { @@ -122,7 +128,7 @@ func testBasicConfigAuth(grpcPort, perSecond string, local_cache_size string) fu os.Setenv("REDIS_TYPE", "single") os.Setenv("REDIS_PERSECOND_TYPE", "single") - return testBasicBaseConfig(grpcPort, perSecond, local_cache_size) + return testBasicBaseConfig(grpcPort, perSecond, local_cache_size, "") } func testBasicConfigAuthWithRedisCluster(grpcPort, perSecond string, local_cache_size string) func(*testing.T) { @@ -137,7 +143,7 @@ func testBasicConfigAuthWithRedisCluster(grpcPort, perSecond string, local_cache os.Setenv("REDIS_PERSECOND_PIPELINE_LIMIT", "8") os.Setenv("REDIS_PIPELINE_LIMIT", "8") - return testBasicBaseConfig(grpcPort, perSecond, local_cache_size) + return testBasicBaseConfig(grpcPort, perSecond, local_cache_size, "") } func testBasicAuthConfigWithRedisSentinel(grpcPort, perSecond string, local_cache_size string) func(*testing.T) { @@ -148,7 +154,7 @@ func testBasicAuthConfigWithRedisSentinel(grpcPort, perSecond string, local_cach os.Setenv("REDIS_URL", "mymaster,localhost:26394,localhost:26395,localhost:26396") os.Setenv("REDIS_TLS", "false") - return testBasicBaseConfig(grpcPort, perSecond, local_cache_size) + return testBasicBaseConfig(grpcPort, perSecond, local_cache_size, "") } func testBasicConfigWithoutWatchRoot(grpcPort, perSecond string, local_cache_size string) func(*testing.T) { @@ -161,7 +167,7 @@ func testBasicConfigWithoutWatchRoot(grpcPort, perSecond string, local_cache_siz os.Setenv("RUNTIME_WATCH_ROOT", "false") os.Setenv("REDIS_TYPE", "single") os.Setenv("REDIS_PERSECOND_TYPE", "single") - return testBasicBaseConfig(grpcPort, perSecond, local_cache_size) + return testBasicBaseConfig(grpcPort, perSecond, local_cache_size, "") } func testBasicConfigWithoutWatchRootWithRedisCluster(grpcPort, perSecond string, local_cache_size string) func(*testing.T) { @@ -177,7 +183,7 @@ func testBasicConfigWithoutWatchRootWithRedisCluster(grpcPort, perSecond string, os.Setenv("REDIS_PERSECOND_PIPELINE_LIMIT", "8") os.Setenv("REDIS_PIPELINE_LIMIT", "8") - return testBasicBaseConfig(grpcPort, perSecond, local_cache_size) + return testBasicBaseConfig(grpcPort, perSecond, local_cache_size, "") } func testBasicConfigWithoutWatchRootWithRedisSentinel(grpcPort, perSecond string, local_cache_size string) func(*testing.T) { @@ -189,7 +195,7 @@ func testBasicConfigWithoutWatchRootWithRedisSentinel(grpcPort, perSecond string os.Setenv("REDIS_PERSECOND_TLS", "false") os.Setenv("RUNTIME_WATCH_ROOT", "false") - return testBasicBaseConfig(grpcPort, perSecond, local_cache_size) + return testBasicBaseConfig(grpcPort, perSecond, local_cache_size, "") } func testBasicConfigReload(grpcPort, perSecond string, local_cache_size, runtimeWatchRoot string) func(*testing.T) { @@ -242,7 +248,7 @@ func getCacheKey(cacheKey string, enableLocalCache bool) string { return cacheKey } -func testBasicBaseConfig(grpcPort, perSecond string, local_cache_size string) func(*testing.T) { +func testBasicBaseConfig(grpcPort, perSecond string, local_cache_size string, backend_type string) func(*testing.T) { return func(t *testing.T) { os.Setenv("REDIS_PERSECOND", perSecond) os.Setenv("PORT", "8082") @@ -254,6 +260,7 @@ func testBasicBaseConfig(grpcPort, perSecond string, local_cache_size string) fu os.Setenv("REDIS_SOCKET_TYPE", "tcp") os.Setenv("LOCAL_CACHE_SIZE_IN_BYTES", local_cache_size) os.Setenv("USE_STATSD", "false") + os.Setenv("BACKEND_TYPE", backend_type) local_cache_size_val, _ := strconv.Atoi(local_cache_size) enable_local_cache := local_cache_size_val > 0 @@ -494,6 +501,7 @@ func TestBasicConfigLegacy(t *testing.T) { os.Setenv("REDIS_AUTH", "") os.Setenv("REDIS_PERSECOND_TLS", "false") os.Setenv("REDIS_PERSECOND_AUTH", "") + os.Setenv("BACKEND_TYPE", "") os.Setenv("REDIS_TYPE", "single") os.Setenv("REDIS_PERSECOND_TYPE", "single") @@ -604,6 +612,7 @@ func testConfigReload(grpcPort, perSecond string, local_cache_size string) func( os.Setenv("REDIS_SOCKET_TYPE", "tcp") os.Setenv("LOCAL_CACHE_SIZE_IN_BYTES", local_cache_size) os.Setenv("USE_STATSD", "false") + os.Setenv("BACKEND_TYPE", "") local_cache_size_val, _ := strconv.Atoi(local_cache_size) enable_local_cache := local_cache_size_val > 0 diff --git a/test/memcached/cache_impl_test.go b/test/memcached/cache_impl_test.go new file mode 100644 index 000000000..fad218407 --- /dev/null +++ b/test/memcached/cache_impl_test.go @@ -0,0 +1,572 @@ +// Adapted from test/redis/cache_impl_test.go, with most test cases being the same +// basic idea. TestMemcacheAdd() is unique to the memcache tests, since redis can create a new key +// simply by incrementing it but memcached cannot. In memcache new keys need to be explicitly +// added. +package memcached_test + +import ( + "math/rand" + "strconv" + "testing" + + "github.com/bradfitz/gomemcache/memcache" + "github.com/coocood/freecache" + + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/src/memcached" + "github.com/envoyproxy/ratelimit/src/utils" + stats "github.com/lyft/gostats" + + "github.com/envoyproxy/ratelimit/test/common" + mock_memcached "github.com/envoyproxy/ratelimit/test/mocks/memcached" + mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func TestMemcached(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + client := mock_memcached.NewMockClient(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, statsStore, 0.8) + + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return( + getMultiResult(map[string]int{"domain_key_value_1234": 4}), nil, + ) + client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return(uint64(5), nil) + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key2_value2_subkey2_subvalue2_1200"}).Return( + getMultiResult(map[string]int{"domain_key2_value2_subkey2_subvalue2_1200": 10}), nil, + ) + client.EXPECT().Increment("domain_key2_value2_subkey2_subvalue2_1200", uint64(1)).Return(uint64(11), nil) + + request = common.NewRateLimitRequest( + "domain", + [][][2]string{ + {{"key2", "value2"}}, + {{"key2", "value2"}, {"subkey2", "subvalue2"}}, + }, 1) + limits = []*config.RateLimit{ + nil, + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key2_value2_subkey2_subvalue2", statsStore)} + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[1].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[1].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[1].Stats.NearLimit.Value()) + + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(5) + client.EXPECT().GetMulti([]string{ + "domain_key3_value3_997200", + "domain_key3_value3_subkey3_subvalue3_950400", + }).Return( + getMultiResult(map[string]int{ + "domain_key3_value3_997200": 10, + "domain_key3_value3_subkey3_subvalue3_950400": 12}), + nil, + ) + client.EXPECT().Increment("domain_key3_value3_997200", uint64(1)).Return(uint64(11), nil) + client.EXPECT().Increment("domain_key3_value3_subkey3_subvalue3_950400", uint64(1)).Return(uint64(13), nil) + + request = common.NewRateLimitRequest( + "domain", + [][][2]string{ + {{"key3", "value3"}}, + {{"key3", "value3"}, {"subkey3", "subvalue3"}}, + }, 1) + limits = []*config.RateLimit{ + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, "key3_value3", statsStore), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_DAY, "key3_value3_subkey3_subvalue3", statsStore)} + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + cache.Flush() +} + +func TestMemcachedGetError(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + client := mock_memcached.NewMockClient(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, statsStore, 0.8) + + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return( + nil, memcache.ErrNoServers, + ) + client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return(uint64(5), nil) + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + // No error, but the key is missing + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key_value1_1234"}).Return( + nil, nil, + ) + client.EXPECT().Increment("domain_key_value1_1234", uint64(1)).Return(uint64(5), nil) + + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value1"}}}, 1) + limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value1", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + cache.Flush() +} + +func testLocalCacheStats(localCacheStats stats.StatGenerator, statsStore stats.Store, sink *common.TestStatSink, + expectedHitCount int, expectedMissCount int, expectedLookUpCount int, expectedExpiredCount int, + expectedEntryCount int) func(*testing.T) { + return func(t *testing.T) { + localCacheStats.GenerateStats() + statsStore.Flush() + + // Check whether all local_cache related stats are available. + _, ok := sink.Record["averageAccessTime"] + assert.Equal(t, true, ok) + hitCount, ok := sink.Record["hitCount"] + assert.Equal(t, true, ok) + missCount, ok := sink.Record["missCount"] + assert.Equal(t, true, ok) + lookupCount, ok := sink.Record["lookupCount"] + assert.Equal(t, true, ok) + _, ok = sink.Record["overwriteCount"] + assert.Equal(t, true, ok) + _, ok = sink.Record["evacuateCount"] + assert.Equal(t, true, ok) + expiredCount, ok := sink.Record["expiredCount"] + assert.Equal(t, true, ok) + entryCount, ok := sink.Record["entryCount"] + assert.Equal(t, true, ok) + + // Check the correctness of hitCount, missCount, lookupCount, expiredCount and entryCount + assert.Equal(t, expectedHitCount, hitCount.(int)) + assert.Equal(t, expectedMissCount, missCount.(int)) + assert.Equal(t, expectedLookUpCount, lookupCount.(int)) + assert.Equal(t, expectedExpiredCount, expiredCount.(int)) + assert.Equal(t, expectedEntryCount, entryCount.(int)) + + sink.Clear() + } +} + +func TestOverLimitWithLocalCache(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + client := mock_memcached.NewMockClient(controller) + localCache := freecache.NewCache(100) + sink := &common.TestStatSink{} + statsStore := stats.NewStore(sink, true) + cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, localCache, statsStore, 0.8) + localCacheStats := limiter.NewLocalCacheStats(localCache, statsStore.Scope("localcache")) + + // Test Near Limit Stats. Under Near Limit Ratio + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( + getMultiResult(map[string]int{"domain_key4_value4_997200": 10}), nil, + ) + client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(5), nil) + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) + + limits := []*config.RateLimit{ + config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, "key4_value4", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + // Check the local cache stats. + testLocalCacheStats(localCacheStats, statsStore, sink, 0, 1, 1, 0, 0) + + // Test Near Limit Stats. At Near Limit Ratio, still OK + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( + getMultiResult(map[string]int{"domain_key4_value4_997200": 12}), nil, + ) + client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(13), nil) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + + // Check the local cache stats. + testLocalCacheStats(localCacheStats, statsStore, sink, 0, 2, 2, 0, 0) + + // Test Over limit stats + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( + getMultiResult(map[string]int{"domain_key4_value4_997200": 15}), nil, + ) + client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(16), nil) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + + // Check the local cache stats. + testLocalCacheStats(localCacheStats, statsStore, sink, 0, 2, 3, 0, 1) + + // Test Over limit stats with local cache + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Times(0) + client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Times(0) + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(4), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + + // Check the local cache stats. + testLocalCacheStats(localCacheStats, statsStore, sink, 1, 3, 4, 0, 1) + + cache.Flush() +} + +func TestNearLimit(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + client := mock_memcached.NewMockClient(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, statsStore, 0.8) + + // Test Near Limit Stats. Under Near Limit Ratio + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( + getMultiResult(map[string]int{"domain_key4_value4_997200": 10}), nil, + ) + client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(11), nil) + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) + + limits := []*config.RateLimit{ + config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, "key4_value4", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + // Test Near Limit Stats. At Near Limit Ratio, still OK + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( + getMultiResult(map[string]int{"domain_key4_value4_997200": 12}), nil, + ) + client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(13), nil) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + + // Test Near Limit Stats. We went OVER_LIMIT, but the near_limit counter only increases + // when we are near limit, not after we have passed the limit. + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( + getMultiResult(map[string]int{"domain_key4_value4_997200": 15}), nil, + ) + client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(16), nil) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + + // Now test hitsAddend that is greater than 1 + // All of it under limit, under near limit + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key5_value5_1234"}).Return( + getMultiResult(map[string]int{"domain_key5_value5_1234": 2}), nil, + ) + client.EXPECT().Increment("domain_key5_value5_1234", uint64(3)).Return(uint64(5), nil) + + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key5", "value5"}}}, 3) + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key5_value5", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 15, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + // All of it under limit, some over near limit + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key6_value6_1234"}).Return( + getMultiResult(map[string]int{"domain_key6_value6_1234": 5}), nil, + ) + client.EXPECT().Increment("domain_key6_value6_1234", uint64(2)).Return(uint64(7), nil) + + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key6", "value6"}}}, 2) + limits = []*config.RateLimit{config.NewRateLimit(8, pb.RateLimitResponse_RateLimit_SECOND, "key6_value6", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + + // All of it under limit, all of it over near limit + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key7_value7_1234"}).Return( + getMultiResult(map[string]int{"domain_key7_value7_1234": 16}), nil, + ) + client.EXPECT().Increment("domain_key7_value7_1234", uint64(3)).Return(uint64(19), nil) + + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key7", "value7"}}}, 3) + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key7_value7", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(3), limits[0].Stats.NearLimit.Value()) + + // Some of it over limit, all of it over near limit + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key8_value8_1234"}).Return( + getMultiResult(map[string]int{"domain_key8_value8_1234": 19}), nil, + ) + client.EXPECT().Increment("domain_key8_value8_1234", uint64(3)).Return(uint64(22), nil) + + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key8", "value8"}}}, 3) + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key8_value8", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + + // Some of it in all three places + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key9_value9_1234"}).Return( + getMultiResult(map[string]int{"domain_key9_value9_1234": 15}), nil, + ) + client.EXPECT().Increment("domain_key9_value9_1234", uint64(7)).Return(uint64(22), nil) + + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key9", "value9"}}}, 7) + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key9_value9", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(7), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(4), limits[0].Stats.NearLimit.Value()) + + // all of it over limit + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key10_value10_1234"}).Return( + getMultiResult(map[string]int{"domain_key10_value10_1234": 27}), nil, + ) + client.EXPECT().Increment("domain_key10_value10_1234", uint64(3)).Return(uint64(30), nil) + + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key10", "value10"}}}, 3) + limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key10_value10", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(3), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + cache.Flush() +} + +func TestMemcacheWithJitter(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + client := mock_memcached.NewMockClient(controller) + jitterSource := mock_utils.NewMockJitterRandSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + cache := memcached.NewRateLimitCacheImpl(client, timeSource, rand.New(jitterSource), 3600, nil, statsStore, 0.8) + + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + jitterSource.EXPECT().Int63().Return(int64(100)) + + // Key is not found in memcache + client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return(nil, nil) + // First increment attempt will fail + client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return( + uint64(0), memcache.ErrCacheMiss) + // Add succeeds + client.EXPECT().Add( + &memcache.Item{ + Key: "domain_key_value_1234", + Value: []byte(strconv.FormatUint(1, 10)), + // 1 second + 100 seconds of jitter + Expiration: int32(101), + }, + ).Return(nil) + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + cache.Flush() +} + +func TestMemcacheAdd(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + client := mock_memcached.NewMockClient(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, statsStore, 0.8) + + // Test a race condition with the initial add + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + + client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return(nil, nil) + client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return( + uint64(0), memcache.ErrCacheMiss) + // Add fails, must have been a race condition + client.EXPECT().Add( + &memcache.Item{ + Key: "domain_key_value_1234", + Value: []byte(strconv.FormatUint(1, 10)), + Expiration: int32(1), + }, + ).Return(memcache.ErrNotStored) + // Should work the second time, since some other client added the key. + client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return( + uint64(2), nil) + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + // A rate limit with 1-minute window + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + client.EXPECT().GetMulti([]string{"domain_key2_value2_1200"}).Return(nil, nil) + client.EXPECT().Increment("domain_key2_value2_1200", uint64(1)).Return( + uint64(0), memcache.ErrCacheMiss) + client.EXPECT().Add( + &memcache.Item{ + Key: "domain_key2_value2_1200", + Value: []byte(strconv.FormatUint(1, 10)), + Expiration: int32(60), + }, + ).Return(nil) + + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key2", "value2"}}}, 1) + limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key2_value2", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + + cache.Flush() +} + +func getMultiResult(vals map[string]int) map[string]*memcache.Item { + result := make(map[string]*memcache.Item, len(vals)) + for k, v := range vals { + result[k] = &memcache.Item{ + Value: []byte(strconv.Itoa(v)), + } + } + return result +} diff --git a/test/mocks/limiter/limiter.go b/test/mocks/limiter/limiter.go index 7e9f3e5b3..48f995a1f 100644 --- a/test/mocks/limiter/limiter.go +++ b/test/mocks/limiter/limiter.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/envoyproxy/ratelimit/src/limiter (interfaces: RateLimitCache,TimeSource,JitterRandSource) +// Source: github.com/envoyproxy/ratelimit/src/limiter (interfaces: RateLimitCache) // Package mock_limiter is a generated GoMock package. package mock_limiter @@ -49,88 +49,14 @@ func (mr *MockRateLimitCacheMockRecorder) DoLimit(arg0, arg1, arg2 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoLimit", reflect.TypeOf((*MockRateLimitCache)(nil).DoLimit), arg0, arg1, arg2) } -// MockTimeSource is a mock of TimeSource interface -type MockTimeSource struct { - ctrl *gomock.Controller - recorder *MockTimeSourceMockRecorder -} - -// MockTimeSourceMockRecorder is the mock recorder for MockTimeSource -type MockTimeSourceMockRecorder struct { - mock *MockTimeSource -} - -// NewMockTimeSource creates a new mock instance -func NewMockTimeSource(ctrl *gomock.Controller) *MockTimeSource { - mock := &MockTimeSource{ctrl: ctrl} - mock.recorder = &MockTimeSourceMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockTimeSource) EXPECT() *MockTimeSourceMockRecorder { - return m.recorder -} - -// UnixNow mocks base method -func (m *MockTimeSource) UnixNow() int64 { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UnixNow") - ret0, _ := ret[0].(int64) - return ret0 -} - -// UnixNow indicates an expected call of UnixNow -func (mr *MockTimeSourceMockRecorder) UnixNow() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnixNow", reflect.TypeOf((*MockTimeSource)(nil).UnixNow)) -} - -// MockJitterRandSource is a mock of JitterRandSource interface -type MockJitterRandSource struct { - ctrl *gomock.Controller - recorder *MockJitterRandSourceMockRecorder -} - -// MockJitterRandSourceMockRecorder is the mock recorder for MockJitterRandSource -type MockJitterRandSourceMockRecorder struct { - mock *MockJitterRandSource -} - -// NewMockJitterRandSource creates a new mock instance -func NewMockJitterRandSource(ctrl *gomock.Controller) *MockJitterRandSource { - mock := &MockJitterRandSource{ctrl: ctrl} - mock.recorder = &MockJitterRandSourceMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockJitterRandSource) EXPECT() *MockJitterRandSourceMockRecorder { - return m.recorder -} - -// Int63 mocks base method -func (m *MockJitterRandSource) Int63() int64 { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Int63") - ret0, _ := ret[0].(int64) - return ret0 -} - -// Int63 indicates an expected call of Int63 -func (mr *MockJitterRandSourceMockRecorder) Int63() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Int63", reflect.TypeOf((*MockJitterRandSource)(nil).Int63)) -} - -// Seed mocks base method -func (m *MockJitterRandSource) Seed(arg0 int64) { +// Flush mocks base method +func (m *MockRateLimitCache) Flush() { m.ctrl.T.Helper() - m.ctrl.Call(m, "Seed", arg0) + m.ctrl.Call(m, "Flush") } -// Seed indicates an expected call of Seed -func (mr *MockJitterRandSourceMockRecorder) Seed(arg0 interface{}) *gomock.Call { +// Flush indicates an expected call of Flush +func (mr *MockRateLimitCacheMockRecorder) Flush() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seed", reflect.TypeOf((*MockJitterRandSource)(nil).Seed), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Flush", reflect.TypeOf((*MockRateLimitCache)(nil).Flush)) } diff --git a/test/mocks/memcached/client.go b/test/mocks/memcached/client.go new file mode 100644 index 000000000..433105bd0 --- /dev/null +++ b/test/mocks/memcached/client.go @@ -0,0 +1,78 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/envoyproxy/ratelimit/src/memcached (interfaces: Client) + +// Package mock_memcached is a generated GoMock package. +package mock_memcached + +import ( + memcache "github.com/bradfitz/gomemcache/memcache" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockClient is a mock of Client interface +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// Add mocks base method +func (m *MockClient) Add(arg0 *memcache.Item) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Add", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Add indicates an expected call of Add +func (mr *MockClientMockRecorder) Add(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockClient)(nil).Add), arg0) +} + +// GetMulti mocks base method +func (m *MockClient) GetMulti(arg0 []string) (map[string]*memcache.Item, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMulti", arg0) + ret0, _ := ret[0].(map[string]*memcache.Item) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMulti indicates an expected call of GetMulti +func (mr *MockClientMockRecorder) GetMulti(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMulti", reflect.TypeOf((*MockClient)(nil).GetMulti), arg0) +} + +// Increment mocks base method +func (m *MockClient) Increment(arg0 string, arg1 uint64) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Increment", arg0, arg1) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Increment indicates an expected call of Increment +func (mr *MockClientMockRecorder) Increment(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Increment", reflect.TypeOf((*MockClient)(nil).Increment), arg0, arg1) +} diff --git a/test/mocks/mocks.go b/test/mocks/mocks.go index 9f8b18cec..2aafcb30c 100644 --- a/test/mocks/mocks.go +++ b/test/mocks/mocks.go @@ -4,5 +4,7 @@ package mocks //go:generate go run github.com/golang/mock/mockgen -destination ./runtime/loader/loader.go github.com/lyft/goruntime/loader IFace //go:generate go run github.com/golang/mock/mockgen -destination ./config/config.go github.com/envoyproxy/ratelimit/src/config RateLimitConfig,RateLimitConfigLoader //go:generate go run github.com/golang/mock/mockgen -destination ./redis/redis.go github.com/envoyproxy/ratelimit/src/redis Client -//go:generate go run github.com/golang/mock/mockgen -destination ./limiter/limiter.go github.com/envoyproxy/ratelimit/src/limiter RateLimitCache,TimeSource,JitterRandSource +//go:generate go run github.com/golang/mock/mockgen -destination ./limiter/limiter.go github.com/envoyproxy/ratelimit/src/limiter RateLimitCache +//go:generate go run github.com/golang/mock/mockgen -destination ./utils/utils.go github.com/envoyproxy/ratelimit/src/utils TimeSource,JitterRandSource +//go:generate go run github.com/golang/mock/mockgen -destination ./memcached/client.go github.com/envoyproxy/ratelimit/src/memcached Client //go:generate go run github.com/golang/mock/mockgen -destination ./rls/rls.go github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3 RateLimitServiceServer diff --git a/test/mocks/utils/utils.go b/test/mocks/utils/utils.go new file mode 100644 index 000000000..1812f4f0f --- /dev/null +++ b/test/mocks/utils/utils.go @@ -0,0 +1,96 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/envoyproxy/ratelimit/src/utils (interfaces: TimeSource,JitterRandSource) + +// Package mock_utils is a generated GoMock package. +package mock_utils + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockTimeSource is a mock of TimeSource interface +type MockTimeSource struct { + ctrl *gomock.Controller + recorder *MockTimeSourceMockRecorder +} + +// MockTimeSourceMockRecorder is the mock recorder for MockTimeSource +type MockTimeSourceMockRecorder struct { + mock *MockTimeSource +} + +// NewMockTimeSource creates a new mock instance +func NewMockTimeSource(ctrl *gomock.Controller) *MockTimeSource { + mock := &MockTimeSource{ctrl: ctrl} + mock.recorder = &MockTimeSourceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockTimeSource) EXPECT() *MockTimeSourceMockRecorder { + return m.recorder +} + +// UnixNow mocks base method +func (m *MockTimeSource) UnixNow() int64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnixNow") + ret0, _ := ret[0].(int64) + return ret0 +} + +// UnixNow indicates an expected call of UnixNow +func (mr *MockTimeSourceMockRecorder) UnixNow() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnixNow", reflect.TypeOf((*MockTimeSource)(nil).UnixNow)) +} + +// MockJitterRandSource is a mock of JitterRandSource interface +type MockJitterRandSource struct { + ctrl *gomock.Controller + recorder *MockJitterRandSourceMockRecorder +} + +// MockJitterRandSourceMockRecorder is the mock recorder for MockJitterRandSource +type MockJitterRandSourceMockRecorder struct { + mock *MockJitterRandSource +} + +// NewMockJitterRandSource creates a new mock instance +func NewMockJitterRandSource(ctrl *gomock.Controller) *MockJitterRandSource { + mock := &MockJitterRandSource{ctrl: ctrl} + mock.recorder = &MockJitterRandSourceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockJitterRandSource) EXPECT() *MockJitterRandSourceMockRecorder { + return m.recorder +} + +// Int63 mocks base method +func (m *MockJitterRandSource) Int63() int64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Int63") + ret0, _ := ret[0].(int64) + return ret0 +} + +// Int63 indicates an expected call of Int63 +func (mr *MockJitterRandSourceMockRecorder) Int63() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Int63", reflect.TypeOf((*MockJitterRandSource)(nil).Int63)) +} + +// Seed mocks base method +func (m *MockJitterRandSource) Seed(arg0 int64) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Seed", arg0) +} + +// Seed indicates an expected call of Seed +func (mr *MockJitterRandSourceMockRecorder) Seed(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seed", reflect.TypeOf((*MockJitterRandSource)(nil).Seed), arg0) +} diff --git a/test/redis/bench_test.go b/test/redis/bench_test.go index e5c179777..4b1766b27 100644 --- a/test/redis/bench_test.go +++ b/test/redis/bench_test.go @@ -8,8 +8,8 @@ import ( pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" "github.com/envoyproxy/ratelimit/src/config" - "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/redis" + "github.com/envoyproxy/ratelimit/src/utils" stats "github.com/lyft/gostats" "math/rand" @@ -44,7 +44,7 @@ func BenchmarkParallelDoLimit(b *testing.B) { client := redis.NewClientImpl(statsStore, false, "", "single", "127.0.0.1:6379", poolSize, pipelineWindow, pipelineLimit) defer client.Close() - cache := redis.NewFixedRateLimitCacheImpl(client, nil, limiter.NewTimeSourceImpl(), rand.New(limiter.NewLockedSource(time.Now().Unix())), 10, nil, 0.8) + cache := redis.NewFixedRateLimitCacheImpl(client, nil, utils.NewTimeSourceImpl(), rand.New(utils.NewLockedSource(time.Now().Unix())), 10, nil, 0.8) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) limits := []*config.RateLimit{config.NewRateLimit(1000000000, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} diff --git a/test/redis/fixed_cache_impl_test.go b/test/redis/fixed_cache_impl_test.go index dea476cba..747cb9798 100644 --- a/test/redis/fixed_cache_impl_test.go +++ b/test/redis/fixed_cache_impl_test.go @@ -1,21 +1,23 @@ package redis_test import ( + "testing" + "github.com/coocood/freecache" "github.com/mediocregopher/radix/v3" - "testing" pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/redis" + "github.com/envoyproxy/ratelimit/src/utils" stats "github.com/lyft/gostats" "math/rand" "github.com/envoyproxy/ratelimit/test/common" - mock_limiter "github.com/envoyproxy/ratelimit/test/mocks/limiter" mock_redis "github.com/envoyproxy/ratelimit/test/mocks/redis" + mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) @@ -37,7 +39,7 @@ func testRedis(usePerSecondRedis bool) func(*testing.T) { client := mock_redis.NewMockClient(controller) perSecondClient := mock_redis.NewMockClient(controller) - timeSource := mock_limiter.NewMockTimeSource(controller) + timeSource := mock_utils.NewMockTimeSource(controller) var cache limiter.RateLimitCache if usePerSecondRedis { cache = redis.NewFixedRateLimitCacheImpl(client, perSecondClient, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8) @@ -62,7 +64,7 @@ func testRedis(usePerSecondRedis bool) func(*testing.T) { limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -86,7 +88,7 @@ func testRedis(usePerSecondRedis bool) func(*testing.T) { config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key2_value2_subkey2_subvalue2", statsStore)} assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}, - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: redis.CalculateReset(limits[1].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[1].Stats.TotalHits.Value()) assert.Equal(uint64(1), limits[1].Stats.OverLimit.Value()) @@ -113,8 +115,8 @@ func testRedis(usePerSecondRedis bool) func(*testing.T) { config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_DAY, "key3_value3_subkey3_subvalue3", statsStore)} assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}, - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: redis.CalculateReset(limits[1].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) @@ -167,7 +169,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { defer controller.Finish() client := mock_redis.NewMockClient(controller) - timeSource := mock_limiter.NewMockTimeSource(controller) + timeSource := mock_utils.NewMockTimeSource(controller) localCache := freecache.NewCache(100) cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, localCache, 0.8) sink := &common.TestStatSink{} @@ -188,7 +190,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -207,7 +209,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -226,7 +228,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) @@ -243,7 +245,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { "EXPIRE", "domain_key4_value4_997200", int64(3600)).Times(0) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(4), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) @@ -260,7 +262,7 @@ func TestNearLimit(t *testing.T) { defer controller.Finish() client := mock_redis.NewMockClient(controller) - timeSource := mock_limiter.NewMockTimeSource(controller) + timeSource := mock_utils.NewMockTimeSource(controller) cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8) statsStore := stats.NewStore(stats.NewNullSink(), false) @@ -278,7 +280,7 @@ func TestNearLimit(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -293,7 +295,7 @@ func TestNearLimit(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -309,7 +311,7 @@ func TestNearLimit(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) @@ -326,7 +328,7 @@ func TestNearLimit(t *testing.T) { limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key5_value5", statsStore)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 15, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 15, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -342,7 +344,7 @@ func TestNearLimit(t *testing.T) { limits = []*config.RateLimit{config.NewRateLimit(8, pb.RateLimitResponse_RateLimit_SECOND, "key6_value6", statsStore)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -358,7 +360,7 @@ func TestNearLimit(t *testing.T) { limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key7_value7", statsStore)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -374,7 +376,7 @@ func TestNearLimit(t *testing.T) { limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key8_value8", statsStore)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) @@ -390,7 +392,7 @@ func TestNearLimit(t *testing.T) { limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key9_value9", statsStore)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(7), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) @@ -406,7 +408,7 @@ func TestNearLimit(t *testing.T) { limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key10_value10", statsStore)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(3), limits[0].Stats.OverLimit.Value()) @@ -419,8 +421,8 @@ func TestRedisWithJitter(t *testing.T) { defer controller.Finish() client := mock_redis.NewMockClient(controller) - timeSource := mock_limiter.NewMockTimeSource(controller) - jitterSource := mock_limiter.NewMockJitterRandSource(controller) + timeSource := mock_utils.NewMockTimeSource(controller) + jitterSource := mock_utils.NewMockJitterRandSource(controller) cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(jitterSource), 3600, nil, 0.8) statsStore := stats.NewStore(stats.NewNullSink(), false) @@ -434,7 +436,7 @@ func TestRedisWithJitter(t *testing.T) { limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: redis.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value())