From cc3e62f80dc162ee742333547e9f75dbaf7a4622 Mon Sep 17 00:00:00 2001 From: tangxinfa Date: Thu, 26 Dec 2019 13:35:23 +0800 Subject: [PATCH] fix: support auth without tls Signed-off-by: tangxinfa --- .travis.yml | 2 ++ README.md | 4 ++-- src/redis/driver_impl.go | 24 +++++++++++------------- src/service_cmd/runner/runner.go | 13 ++----------- test/integration/integration_test.go | 23 +++++++++++++++++++++++ 5 files changed, 40 insertions(+), 26 deletions(-) diff --git a/.travis.yml b/.travis.yml index ea43ffb90..02bdf2e5b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,4 +8,6 @@ before_script: - redis-server --port 6380 & - redis-server --port 6381 --requirepass password123 & - redis-server --port 6382 --requirepass password123 & +- redis-server --port 6384 --requirepass password123 & +- redis-server --port 6385 --requirepass password123 & script: make check_format tests diff --git a/README.md b/README.md index f28cadedd..67290c441 100644 --- a/README.md +++ b/README.md @@ -387,10 +387,10 @@ Ratelimit uses Redis as its caching layer. Ratelimit supports two operation mode 1. One Redis server for all limits. 1. Two Redis instances: one for per second limits and another one for all other limits. -As well Ratelimit supports TLS connections and authentication over TLS connections. These can be configured using the following environment variables: +As well Ratelimit supports TLS connections and authentication. These can be configured using the following environment variables: 1. `REDIS_TLS` & `REDIS_PERSECOND_TLS`: set to `"true"` to enable a TLS connection for the specific connection type. -1. `REDIS_AUTH` & `REDIS_PERSECOND_AUTH`: set to `"password"` to enable authentication to the redis host. This requires TLS to be enabled as well for the specific connection. +1. `REDIS_AUTH` & `REDIS_PERSECOND_AUTH`: set to `"password"` to enable authentication to the redis host. ## One Redis Instance diff --git a/src/redis/driver_impl.go b/src/redis/driver_impl.go index 15c56f0f0..5405bc8c5 100644 --- a/src/redis/driver_impl.go +++ b/src/redis/driver_impl.go @@ -2,6 +2,7 @@ package redis import ( "crypto/tls" + "net" stats "github.com/lyft/gostats" "github.com/lyft/ratelimit/src/assert" @@ -66,19 +67,16 @@ func (this *poolImpl) Put(c Connection) { } } -func NewPoolImpl(scope stats.Scope, socketType string, url string, poolSize int) Pool { - logger.Warnf("connecting to redis on %s %s with pool size %d", socketType, url, poolSize) - pool, err := pool.New(socketType, url, poolSize) - checkError(err) - return &poolImpl{ - pool: pool, - stats: newPoolStats(scope)} -} - -func NewAuthTLSPoolImpl(scope stats.Scope, auth string, url string, poolSize int) Pool { - logger.Warnf("connecting to redis on tls %s with pool size %d", url, poolSize) +func NewPoolImpl(scope stats.Scope, useTls bool, auth string, url string, poolSize int) Pool { + logger.Warnf("connecting to redis on %s with pool size %d", url, poolSize) df := func(network, addr string) (*redis.Client, error) { - conn, err := tls.Dial("tcp", addr, &tls.Config{}) + var conn net.Conn + var err error + if useTls { + conn, err = tls.Dial("tcp", addr, &tls.Config{}) + } else { + conn, err = net.Dial("tcp", addr) + } if err != nil { return nil, err } @@ -88,7 +86,7 @@ func NewAuthTLSPoolImpl(scope stats.Scope, auth string, url string, poolSize int return nil, err } if auth != "" { - logger.Warnf("enabling authentication to redis on tls %s", url) + logger.Warnf("enabling authentication to redis on %s", url) if err = client.Cmd("AUTH", auth).Err; err != nil { client.Close() return nil, err diff --git a/src/service_cmd/runner/runner.go b/src/service_cmd/runner/runner.go index 2c0d8f83e..30cdde735 100644 --- a/src/service_cmd/runner/runner.go +++ b/src/service_cmd/runner/runner.go @@ -47,19 +47,10 @@ func (runner *Runner) Run() { var perSecondPool redis.Pool if s.RedisPerSecond { - if s.RedisPerSecondAuth != "" || s.RedisPerSecondTls { - perSecondPool = redis.NewAuthTLSPoolImpl(srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondAuth, s.RedisPerSecondUrl, s.RedisPerSecondPoolSize) - } else { - perSecondPool = redis.NewPoolImpl(srv.Scope().Scope("redis_per_second_pool"), s.RedisSocketType, s.RedisPerSecondUrl, s.RedisPerSecondPoolSize) - } - + perSecondPool = redis.NewPoolImpl(srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondTls, s.RedisPerSecondAuth, s.RedisPerSecondUrl, s.RedisPerSecondPoolSize) } var otherPool redis.Pool - if s.RedisAuth != "" || s.RedisTls { - otherPool = redis.NewAuthTLSPoolImpl(srv.Scope().Scope("redis_pool"), s.RedisAuth, s.RedisUrl, s.RedisPoolSize) - } else { - otherPool = redis.NewPoolImpl(srv.Scope().Scope("redis_pool"), s.RedisSocketType, s.RedisUrl, s.RedisPoolSize) - } + otherPool = redis.NewPoolImpl(srv.Scope().Scope("redis_pool"), s.RedisTls, s.RedisAuth, s.RedisUrl, s.RedisPoolSize) var localCache *freecache.Cache if s.LocalCacheSizeInBytes != 0 { diff --git a/test/integration/integration_test.go b/test/integration/integration_test.go index 1965e3466..3f083dadd 100644 --- a/test/integration/integration_test.go +++ b/test/integration/integration_test.go @@ -57,22 +57,43 @@ func TestBasicTLSConfig(t *testing.T) { t.Run("WithPerSecondRedisTLSWithLocalCache", testBasicConfigAuthTLS("18089", "true", "1000")) } +func TestBasicAuthConfig(t *testing.T) { + t.Run("WithoutPerSecondRedisAuth", testBasicConfigAuth("8091", "false", "0")) + t.Run("WithPerSecondRedisAuth", testBasicConfigAuth("8093", "true", "0")) + t.Run("WithoutPerSecondRedisAuthWithLocalCache", testBasicConfigAuth("18091", "false", "1000")) + t.Run("WithPerSecondRedisAuthWithLocalCache", testBasicConfigAuth("18093", "true", "1000")) +} + func testBasicConfigAuthTLS(grpcPort, perSecond string, local_cache_size string) func(*testing.T) { os.Setenv("REDIS_PERSECOND_URL", "localhost:16382") os.Setenv("REDIS_URL", "localhost:16381") os.Setenv("REDIS_AUTH", "password123") + os.Setenv("REDIS_TLS", "true") os.Setenv("REDIS_PERSECOND_AUTH", "password123") + os.Setenv("REDIS_PERSECOND_TLS", "true") return testBasicBaseConfig(grpcPort, perSecond, local_cache_size) } func testBasicConfig(grpcPort, perSecond string, local_cache_size string) func(*testing.T) { os.Setenv("REDIS_PERSECOND_URL", "localhost:6380") os.Setenv("REDIS_URL", "localhost:6379") + os.Setenv("REDIS_AUTH", "") os.Setenv("REDIS_TLS", "false") + os.Setenv("REDIS_PERSECOND_AUTH", "") os.Setenv("REDIS_PERSECOND_TLS", "false") return testBasicBaseConfig(grpcPort, perSecond, local_cache_size) } +func testBasicConfigAuth(grpcPort, perSecond string, local_cache_size string) func(*testing.T) { + os.Setenv("REDIS_PERSECOND_URL", "localhost:6385") + os.Setenv("REDIS_URL", "localhost:6384") + os.Setenv("REDIS_TLS", "false") + os.Setenv("REDIS_AUTH", "password123") + os.Setenv("REDIS_PERSECOND_TLS", "false") + os.Setenv("REDIS_PERSECOND_AUTH", "password123") + return testBasicBaseConfig(grpcPort, perSecond, local_cache_size) +} + func getCacheKey(cacheKey string, enableLocalCache bool) string { if enableLocalCache { return cacheKey + "_local" @@ -214,7 +235,9 @@ func testBasicConfigLegacy(local_cache_size string) func(*testing.T) { os.Setenv("REDIS_PERSECOND_URL", "localhost:6380") os.Setenv("REDIS_URL", "localhost:6379") os.Setenv("REDIS_TLS", "false") + os.Setenv("REDIS_AUTH", "") os.Setenv("REDIS_PERSECOND_TLS", "false") + os.Setenv("REDIS_PERSECOND_AUTH", "") os.Setenv("LOCAL_CACHE_SIZE", local_cache_size) local_cache_size_val, _ := strconv.Atoi(local_cache_size) enable_local_cache := local_cache_size_val > 0