diff --git a/synapse/config/redis.py b/synapse/config/redis.py index 7f621272b8d6..cadf2b9d199c 100644 --- a/synapse/config/redis.py +++ b/synapse/config/redis.py @@ -40,3 +40,8 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: (s.get("host"), int(s.get("port"))) for s in redis_config.get("sentinels", []) ] + self.redis_use_ssl = redis_config.get('use_ssl', False) + self.redis_certificate = redis_config.get("certificate_file", None) + self.redis_private_key = redis_config.get("private_key_file", None) + self.redis_ca_file = redis_config.get("ca_file", None) + self.redis_ca_path = redis_config.get("ca_path", None) diff --git a/synapse/replication/tcp/redis/client_context_factory.py b/synapse/replication/tcp/redis/client_context_factory.py new file mode 100644 index 000000000000..14ae69f1c337 --- /dev/null +++ b/synapse/replication/tcp/redis/client_context_factory.py @@ -0,0 +1,17 @@ +from twisted.internet import ssl + +class ClientContextFactory(ssl.ClientContextFactory): + def __init__(self, redis_config): + self.redis_config = redis_config + + def getContext(self): + ctx = ssl.ClientContextFactory.getContext(self) + if (self.redis_config.redis_certificate): + ctx.use_certificate_file(self.redis_config.redis_certificate)) + if (self.redis_config.private_key): + ctx.use_privatekey_file(self.redis_config.private_key) + if (self.redis_config.ca_file): + ctx.load_verify_locations(cafile=self.redis_config.ca_file) + elif (self.redis_config.ca_path): + ctx.load_verify_locationa(capath=self.redis_config.ca_path) + return ctx \ No newline at end of file diff --git a/synapse/replication/tcp/redis/connection.py b/synapse/replication/tcp/redis/connection.py index 13e59f0c9a12..ebe203492113 100644 --- a/synapse/replication/tcp/redis/connection.py +++ b/synapse/replication/tcp/redis/connection.py @@ -29,6 +29,7 @@ from synapse.logging.context import make_deferred_yieldable from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.replication.tcp.redis.client_context_factory import ClientContextFactory if TYPE_CHECKING: from synapse.server import HomeServer @@ -178,13 +179,24 @@ def __init__( factory.continueTrying = reconnect reactor = hs.get_reactor() - reactor.connectTCP( - host, - port, - factory, - timeout=30, - bindAddress=None, - ) + if (hs.config.redis.redis_use_ssl): + ssl_context_factory = ClientContextFactory(hs.config.redis) + reactor.connectSSL( + host, + port, + factory, + ssl_context_factory, + timeout=30, + bindAddress=None, + ) + else: + reactor.connectTCP( + host, + port, + factory, + timeout=30, + bindAddress=None, + ) self.handler = factory.handler @@ -222,7 +234,11 @@ def __init__( ): self.service_name = service_name self.password = password - self.sentinel = Sentinel(sentinels) + + client_context_factory= None + if (hs.config.redis.redis_use_ssl): + client_context_factory=ClientContextFactory(hs.config.redis) + self.sentinel = Sentinel(sentinels, ssl_client_context_factory=client_context_factory) self.dbid = dbid self.replyTimeout = replyTimeout diff --git a/synapse/replication/tcp/redis/factory.py b/synapse/replication/tcp/redis/factory.py index 74dff38950e5..6b6b59e00444 100644 --- a/synapse/replication/tcp/redis/factory.py +++ b/synapse/replication/tcp/redis/factory.py @@ -17,6 +17,7 @@ from txredisapi import RedisFactory, Sentinel # type: ignore[attr-defined] from synapse.config.redis import RedisConfig +from synapse.replication.tcp.redis.client_context_factory import ClientContextFactory from synapse.replication.tcp.redis.connection import ( IRedisConnection, RedisConnection, @@ -75,11 +76,23 @@ def get_replication_factory( ) reactor = hs.get_reactor() - reactor.connectTCP( - hs.config.redis.redis_host, - hs.config.redis.redis_port, - factory, - timeout=30, - bindAddress=None, - ) + + if (hs.config.redis.redis_use_ssl): + ssl_context_factory = ClientContextFactory(hs.config.redis) + reactor.connectSSL( + hs.config.redis.redis_host, + hs.config.redis.redis_port, + factory, + ssl_context_factory, + timeout=30, + bindAddress=None, + ) + else: + reactor.connectTCP( + hs.config.redis.redis_host, + hs.config.redis.redis_port, + factory, + timeout=30, + bindAddress=None, + ) return factory