diff --git a/broadcaster/_backends/redis.py b/broadcaster/_backends/redis.py index 2c4aeba..7074d35 100644 --- a/broadcaster/_backends/redis.py +++ b/broadcaster/_backends/redis.py @@ -1,10 +1,12 @@ -import typing +from typing import Any from urllib.parse import urlparse -import asyncio_redis +import redis.asyncio as redis +from redis.asyncio.client import PubSub from .._base import Event from .base import BroadcastBackend +import asyncio class RedisBackend(BroadcastBackend): @@ -13,26 +15,47 @@ def __init__(self, url: str): self._host = parsed_url.hostname or "localhost" self._port = parsed_url.port or 6379 self._password = parsed_url.password or None + self._ssl = parsed_url.scheme == "rediss" + self.kwargs = { + "host": self._host, + "port": self._port, + "password": self._password, + "ssl": self._ssl, + } + + self._sub_conn: PubSub | None = None + self._pub_conn: PubSub | None = None async def connect(self) -> None: - kwargs = {"host": self._host, "port": self._port, "password": self._password} - self._pub_conn = await asyncio_redis.Connection.create(**kwargs) - self._sub_conn = await asyncio_redis.Connection.create(**kwargs) - self._subscriber = await self._sub_conn.start_subscribe() + self._pub_conn = redis.Redis(**self.kwargs).pubsub() + self._sub_conn = redis.Redis(**self.kwargs).pubsub() async def disconnect(self) -> None: - self._pub_conn.close() - self._sub_conn.close() + await self._pub_conn.close() + await self._sub_conn.close() async def subscribe(self, channel: str) -> None: - await self._subscriber.subscribe([channel]) + await self._sub_conn.subscribe(channel) async def unsubscribe(self, channel: str) -> None: - await self._subscriber.unsubscribe([channel]) + await self._sub_conn.unsubscribe(channel) - async def publish(self, channel: str, message: typing.Any) -> None: - await self._pub_conn.publish(channel, message) + async def publish(self, channel: str, message: Any) -> None: + try: + await self._pub_conn.execute_command("PUBLISH", channel, message) + except (redis.ConnectionError, redis.TimeoutError): + await asyncio.sleep(1) + self._pub_conn = redis.Redis(**self.kwargs).pubsub() + await self.publish(channel, message) async def next_published(self) -> Event: - message = await self._subscriber.next_published() - return Event(channel=message.channel, message=message.value) + message = None + while not message: + message = await self._sub_conn.get_message( + ignore_subscribe_messages=True, timeout=None + ) + event = Event( + channel=message["channel"].decode(), + message=message["data"].decode(), + ) + return event diff --git a/broadcaster/_base.py b/broadcaster/_base.py index 4de1417..b570188 100644 --- a/broadcaster/_base.py +++ b/broadcaster/_base.py @@ -1,6 +1,6 @@ import asyncio from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, AsyncIterator, Dict, Optional +from typing import Any, AsyncGenerator, Optional, AsyncIterator from urllib.parse import urlparse @@ -26,13 +26,13 @@ class Unsubscribed(Exception): class Broadcast: def __init__(self, url: str): - from broadcaster._backends.base import BroadcastBackend + from ._backends.base import BroadcastBackend parsed_url = urlparse(url) self._backend: BroadcastBackend - self._subscribers: Dict[str, Any] = {} + self._subscribers: dict[str, set[asyncio.Queue]] = {} if parsed_url.scheme in ("redis", "rediss"): - from broadcaster._backends.redis import RedisBackend + from ._backends.redis import RedisBackend self._backend = RedisBackend(url) @@ -51,6 +51,8 @@ def __init__(self, url: str): self._backend = MemoryBackend(url) + self._listener_task: asyncio.Task | None = None + async def __aenter__(self) -> "Broadcast": await self.connect() return self @@ -60,13 +62,13 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None: async def connect(self) -> None: await self._backend.connect() - self._listener_task = asyncio.create_task(self._listener()) async def disconnect(self) -> None: - if self._listener_task.done(): - self._listener_task.result() - else: - self._listener_task.cancel() + if self._listener_task: + if self._listener_task.done(): + self._listener_task.result() + else: + self._listener_task.cancel() await self._backend.disconnect() async def _listener(self) -> None: @@ -81,10 +83,11 @@ async def publish(self, channel: str, message: Any) -> None: @asynccontextmanager async def subscribe(self, channel: str) -> AsyncIterator["Subscriber"]: queue: asyncio.Queue = asyncio.Queue() - try: if not self._subscribers.get(channel): await self._backend.subscribe(channel) + if not self._listener_task: + self._listener_task = asyncio.create_task(self._listener()) self._subscribers[channel] = set([queue]) else: self._subscribers[channel].add(queue)