diff --git a/README.md b/README.md index 4b263195a..8bf562863 100644 --- a/README.md +++ b/README.md @@ -380,6 +380,8 @@ Ratelimit uses Redis as its caching layer. Ratelimit supports two operation mode 1. One Redis server for all limits. 1. Two Redis instances: one for per second limits and another one for all other limits. +1. `REDIS_DB` and `REDIS_PERSECOND_DB` could be used to configure redis database number + ## One Redis Instance To configure one Redis instance use the following environment variables: @@ -387,6 +389,7 @@ To configure one Redis instance use the following environment variables: 1. `REDIS_SOCKET_TYPE` 1. `REDIS_URL` 1. `REDIS_POOL_SIZE` +1. `REDIS_DB` This setup will use the same Redis server for all limits. @@ -397,10 +400,12 @@ To configure two Redis instances use the following environment variables: 1. `REDIS_SOCKET_TYPE` 1. `REDIS_URL` 1. `REDIS_POOL_SIZE` +1. `REDIS_DB` 1. `REDIS_PERSECOND`: set this to `"true"`. 1. `REDIS_PERSECOND_SOCKET_TYPE` 1. `REDIS_PERSECOND_URL` 1. `REDIS_PERSECOND_POOL_SIZE` +1. `REDIS_PERSECOND_DB` This setup will use the Redis server configured with the `_PERSECOND_` vars for per second limits, and the other Redis server for all other limits. diff --git a/src/redis/driver_impl.go b/src/redis/driver_impl.go index eac51fbf5..032560710 100644 --- a/src/redis/driver_impl.go +++ b/src/redis/driver_impl.go @@ -64,15 +64,41 @@ func (this *poolImpl) Put(c Connection) { } } -func NewPoolImpl(scope stats.Scope, socketType string, url string, poolSize int) Pool { +type DialFunc func(*redis.Client) error + +func NewPoolImpl(scope stats.Scope, socketType string, url string, poolSize int, dfs ...DialFunc) Pool { logger.Warnf("connecting to redis on %s %s with pool size %d", socketType, url, poolSize) - pool, err := pool.New(socketType, url, poolSize) + + df := redis.Dial + if len(dfs) != 0 { + df = func(network, addr string) (*redis.Client, error) { + c, err := redis.Dial(network, addr) + if err != nil { + return nil, err + } + for _, f := range dfs { + dialErr := f(c) + if dialErr != nil { + return nil, dialErr + } + } + return c, nil + } + } + pool, err := pool.NewCustom(socketType, url, poolSize, df) checkError(err) return &poolImpl{ pool: pool, stats: newPoolStats(scope)} } +func WithDatabase(db int) DialFunc { + return func(c *redis.Client) error { + logger.Warnf("connecting to redis database %d", db) + return c.Cmd("select", db).Err + } +} + func (this *connectionImpl) PipeAppend(cmd string, args ...interface{}) { this.client.PipeAppend(cmd, args...) this.pending++ diff --git a/src/service_cmd/runner/runner.go b/src/service_cmd/runner/runner.go index cd8535176..1a45239b0 100644 --- a/src/service_cmd/runner/runner.go +++ b/src/service_cmd/runner/runner.go @@ -29,16 +29,39 @@ func Run() { srv := server.NewServer("ratelimit", settings.GrpcUnaryInterceptor(nil)) + var pool redis.Pool + var dials []redis.DialFunc + if s.RedisDatabase != 0 { + dials = append(dials, redis.WithDatabase(s.RedisDatabase)) + } + pool = redis.NewPoolImpl( + srv.Scope().Scope("redis_pool"), + s.RedisSocketType, + s.RedisUrl, + s.RedisPoolSize, + dials..., + ) + var perSecondPool redis.Pool if s.RedisPerSecond { - perSecondPool = redis.NewPoolImpl(srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondSocketType, s.RedisPerSecondUrl, s.RedisPerSecondPoolSize) + var perSecondDials []redis.DialFunc + if s.RedisPerSecondDatabase != 0 { + perSecondDials = append(perSecondDials, redis.WithDatabase(s.RedisPerSecondDatabase)) + } + perSecondPool = redis.NewPoolImpl( + srv.Scope().Scope("redis_per_second_pool"), + s.RedisPerSecondSocketType, + s.RedisPerSecondUrl, + s.RedisPerSecondPoolSize, + perSecondDials..., + ) } service := ratelimit.NewService( srv.Runtime(), redis.NewRateLimitCacheImpl( - redis.NewPoolImpl(srv.Scope().Scope("redis_pool"), s.RedisSocketType, s.RedisUrl, s.RedisPoolSize), + pool, perSecondPool, redis.NewTimeSourceImpl(), rand.New(redis.NewLockedSource(time.Now().Unix())), diff --git a/src/settings/settings.go b/src/settings/settings.go index 890955477..7ff42cf9d 100644 --- a/src/settings/settings.go +++ b/src/settings/settings.go @@ -21,10 +21,12 @@ type Settings struct { LogLevel string `envconfig:"LOG_LEVEL" default:"WARN"` RedisSocketType string `envconfig:"REDIS_SOCKET_TYPE" default:"unix"` RedisUrl string `envconfig:"REDIS_URL" default:"/var/run/nutcracker/ratelimit.sock"` + RedisDatabase int `envconfig:"REDIS_DB" default:0` RedisPoolSize int `envconfig:"REDIS_POOL_SIZE" default:"10"` RedisPerSecond bool `envconfig:"REDIS_PERSECOND" default:"false"` RedisPerSecondSocketType string `envconfig:"REDIS_PERSECOND_SOCKET_TYPE" default:"unix"` RedisPerSecondUrl string `envconfig:"REDIS_PERSECOND_URL" default:"/var/run/nutcracker/ratelimitpersecond.sock"` + RedisPerSecondDatabase int `envconfig:"REDIS_PERSECOND_DB" default:0` RedisPerSecondPoolSize int `envconfig:"REDIS_PERSECOND_POOL_SIZE" default:"10"` ExpirationJitterMaxSeconds int64 `envconfig:"EXPIRATION_JITTER_MAX_SECONDS" default:"300"` } diff --git a/test/integration/integration_test.go b/test/integration/integration_test.go index c48387f3e..899ab5d77 100644 --- a/test/integration/integration_test.go +++ b/test/integration/integration_test.go @@ -46,6 +46,19 @@ func TestBasicConfig(t *testing.T) { t.Run("WithPerSecondRedis", testBasicConfig("8085", "true")) } +func TestWithDbNumber(t *testing.T) { + // Use same redis configuration + // If database number configuration doesn't work it will lead to key collisions + t.Run("WithoutPerSecondRedisDbNumber", testDbNumber("8087", "false")) + t.Run("WithPerSecondRedisDbNumber", testDbNumber("8089", "true")) +} + +func testDbNumber(grpcPort, perSecond string) func(*testing.T) { + os.Setenv("REDIS_DB", "10") + os.Setenv("REDIS_PERSECOND_DB", "10") + return testBasicConfig(grpcPort, perSecond) +} + func testBasicConfig(grpcPort, perSecond string) func(*testing.T) { return func(t *testing.T) { os.Setenv("REDIS_PERSECOND", perSecond)