diff --git a/src/socketio/async_redis_manager.py b/src/socketio/async_redis_manager.py index b8ac4a0f..1cb7fb9d 100644 --- a/src/socketio/async_redis_manager.py +++ b/src/socketio/async_redis_manager.py @@ -61,7 +61,9 @@ def __init__(self, url='redis://localhost:6379/0', channel='socketio', super().__init__(channel=channel, write_only=write_only, logger=logger) self.redis_url = url self.redis_options = redis_options or {} - self._redis_connect() + self.connected = False + self.redis = None + self.pubsub = None def _get_redis_module_and_error(self): parsed_url = urlparse(self.redis_url) @@ -106,23 +108,23 @@ def _redis_connect(self): self.redis = module.Redis.from_url(self.redis_url, **self.redis_options) self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True) + self.connected = True async def _publish(self, data): # pragma: no cover - retry = True _, error = self._get_redis_module_and_error() - while True: + for retries_left in range(1, -1, -1): # 2 attempts try: - if not retry: + if not self.connected: self._redis_connect() return await self.redis.publish( self.channel, json.dumps(data)) except error as exc: - if retry: + if retries_left > 0: self._get_logger().error( 'Cannot publish to redis... ' 'retrying', extra={"redis_exception": str(exc)}) - retry = False + self.connected = False else: self._get_logger().error( 'Cannot publish to redis... ' @@ -133,11 +135,10 @@ async def _publish(self, data): # pragma: no cover async def _redis_listen_with_retries(self): # pragma: no cover retry_sleep = 1 - connect = False _, error = self._get_redis_module_and_error() while True: try: - if connect: + if not self.connected: self._redis_connect() await self.pubsub.subscribe(self.channel) retry_sleep = 1 @@ -148,7 +149,7 @@ async def _redis_listen_with_retries(self): # pragma: no cover 'retrying in ' f'{retry_sleep} secs', extra={"redis_exception": str(exc)}) - connect = True + self.connected = False await asyncio.sleep(retry_sleep) retry_sleep *= 2 if retry_sleep > 60: @@ -156,7 +157,6 @@ async def _redis_listen_with_retries(self): # pragma: no cover async def _listen(self): # pragma: no cover channel = self.channel.encode('utf-8') - await self.pubsub.subscribe(self.channel) async for message in self._redis_listen_with_retries(): if message['channel'] == channel and \ message['type'] == 'message' and 'data' in message: diff --git a/src/socketio/redis_manager.py b/src/socketio/redis_manager.py index 5e58ef41..827918b7 100644 --- a/src/socketio/redis_manager.py +++ b/src/socketio/redis_manager.py @@ -83,7 +83,9 @@ def __init__(self, url='redis://localhost:6379/0', channel='socketio', super().__init__(channel=channel, write_only=write_only, logger=logger) self.redis_url = url self.redis_options = redis_options or {} - self._redis_connect() + self.connected = False + self.redis = None + self.pubsub = None def initialize(self): # pragma: no cover super().initialize() @@ -143,22 +145,22 @@ def _redis_connect(self): self.redis = module.Redis.from_url(self.redis_url, **self.redis_options) self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True) + self.connected = True def _publish(self, data): # pragma: no cover - retry = True _, error = self._get_redis_module_and_error() - while True: + for retries_left in range(1, -1, -1): # 2 attempts try: - if not retry: + if not self.connected: self._redis_connect() return self.redis.publish(self.channel, json.dumps(data)) except error as exc: - if retry: + if retries_left > 0: logger.error( 'Cannot publish to redis... retrying', extra={"redis_exception": str(exc)} ) - retry = False + self.connected = False else: logger.error( 'Cannot publish to redis... giving up', @@ -168,11 +170,10 @@ def _publish(self, data): # pragma: no cover def _redis_listen_with_retries(self): # pragma: no cover retry_sleep = 1 - connect = False _, error = self._get_redis_module_and_error() while True: try: - if connect: + if not self.connected: self._redis_connect() self.pubsub.subscribe(self.channel) retry_sleep = 1 @@ -181,7 +182,7 @@ def _redis_listen_with_retries(self): # pragma: no cover logger.error('Cannot receive from redis... ' f'retrying in {retry_sleep} secs', extra={"redis_exception": str(exc)}) - connect = True + self.connected = False time.sleep(retry_sleep) retry_sleep *= 2 if retry_sleep > 60: @@ -189,7 +190,6 @@ def _redis_listen_with_retries(self): # pragma: no cover def _listen(self): # pragma: no cover channel = self.channel.encode('utf-8') - self.pubsub.subscribe(self.channel) for message in self._redis_listen_with_retries(): if message['channel'] == channel and \ message['type'] == 'message' and 'data' in message: diff --git a/tests/async/test_redis_manager.py b/tests/async/test_redis_manager.py index 01c0c375..046c26d2 100644 --- a/tests/async/test_redis_manager.py +++ b/tests/async/test_redis_manager.py @@ -12,7 +12,7 @@ def test_redis_not_installed(self): async_redis_manager.aioredis = None with pytest.raises(RuntimeError): - AsyncRedisManager('redis://') + AsyncRedisManager('redis://')._redis_connect() assert AsyncRedisManager('unix:///var/sock/redis.sock') is not None async_redis_manager.aioredis = saved_redis @@ -22,7 +22,7 @@ def test_valkey_not_installed(self): async_redis_manager.aiovalkey = None with pytest.raises(RuntimeError): - AsyncRedisManager('valkey://') + AsyncRedisManager('valkey://')._redis_connect() assert AsyncRedisManager('unix:///var/sock/redis.sock') is not None async_redis_manager.aiovalkey = saved_valkey @@ -34,18 +34,18 @@ def test_redis_valkey_not_installed(self): async_redis_manager.aiovalkey = None with pytest.raises(RuntimeError): - AsyncRedisManager('redis://') + AsyncRedisManager('redis://')._redis_connect() with pytest.raises(RuntimeError): - AsyncRedisManager('valkey://') + AsyncRedisManager('valkey://')._redis_connect() with pytest.raises(RuntimeError): - AsyncRedisManager('unix:///var/sock/redis.sock') + AsyncRedisManager('unix:///var/sock/redis.sock')._redis_connect() async_redis_manager.aioredis = saved_redis async_redis_manager.aiovalkey = saved_valkey def test_bad_url(self): with pytest.raises(ValueError): - AsyncRedisManager('http://localhost:6379') + AsyncRedisManager('http://localhost:6379')._redis_connect() def test_redis_connect(self): urls = [ @@ -72,6 +72,8 @@ def test_redis_connect(self): ] for url in urls: c = AsyncRedisManager(url) + assert c.redis is None + c._redis_connect() assert isinstance(c.redis, redis.asyncio.Redis) def test_valkey_connect(self): @@ -102,6 +104,8 @@ def test_valkey_connect(self): ] for url in urls: c = AsyncRedisManager(url) + assert c.redis is None + c._redis_connect() assert isinstance(c.redis, valkey.asyncio.Valkey) async_redis_manager.aioredis = saved_redis diff --git a/tests/common/test_redis_manager.py b/tests/common/test_redis_manager.py index 3beadf3b..dbc927e4 100644 --- a/tests/common/test_redis_manager.py +++ b/tests/common/test_redis_manager.py @@ -12,7 +12,7 @@ def test_redis_not_installed(self): redis_manager.redis = None with pytest.raises(RuntimeError): - RedisManager('redis://') + RedisManager('redis://')._redis_connect() assert RedisManager('unix:///var/sock/redis.sock') is not None redis_manager.redis = saved_redis @@ -22,7 +22,7 @@ def test_valkey_not_installed(self): redis_manager.valkey = None with pytest.raises(RuntimeError): - RedisManager('valkey://') + RedisManager('valkey://')._redis_connect() assert RedisManager('unix:///var/sock/redis.sock') is not None redis_manager.valkey = saved_valkey @@ -34,18 +34,18 @@ def test_redis_valkey_not_installed(self): redis_manager.valkey = None with pytest.raises(RuntimeError): - RedisManager('redis://') + RedisManager('redis://')._redis_connect() with pytest.raises(RuntimeError): - RedisManager('valkey://') + RedisManager('valkey://')._redis_connect() with pytest.raises(RuntimeError): - RedisManager('unix:///var/sock/redis.sock') + RedisManager('unix:///var/sock/redis.sock')._redis_connect() redis_manager.redis = saved_redis redis_manager.valkey = saved_valkey def test_bad_url(self): with pytest.raises(ValueError): - RedisManager('http://localhost:6379') + RedisManager('http://localhost:6379')._redis_connect() def test_redis_connect(self): urls = [ @@ -72,6 +72,8 @@ def test_redis_connect(self): ] for url in urls: c = RedisManager(url) + assert c.redis is None + c._redis_connect() assert isinstance(c.redis, redis.Redis) def test_valkey_connect(self): @@ -102,6 +104,8 @@ def test_valkey_connect(self): ] for url in urls: c = RedisManager(url) + assert c.redis is None + c._redis_connect() assert isinstance(c.redis, valkey.Valkey) redis_manager.redis = saved_redis