diff --git a/broadcaster/_backends/redis.py b/broadcaster/_backends/redis.py index 096a033..8b4553e 100644 --- a/broadcaster/_backends/redis.py +++ b/broadcaster/_backends/redis.py @@ -1,34 +1,36 @@ -import asyncio_redis +import aioredis +import asyncio import typing -from urllib.parse import urlparse + from .base import BroadcastBackend from .._base import Event class RedisBackend(BroadcastBackend): def __init__(self, url: str): - parsed_url = urlparse(url) - self._host = parsed_url.hostname or "localhost" - self._port = parsed_url.port or 6379 + self.conn_url = url + self.channel = None async def connect(self) -> None: - self._pub_conn = await asyncio_redis.Connection.create(self._host, self._port) - self._sub_conn = await asyncio_redis.Connection.create(self._host, self._port) - self._subscriber = await self._sub_conn.start_subscribe() + loop = asyncio.get_event_loop() + self._pub_conn = await aioredis.create_redis(self.conn_url, loop=loop) + self._sub_conn = await aioredis.create_redis(self.conn_url, loop=loop) async def disconnect(self) -> None: self._pub_conn.close() self._sub_conn.close() async def subscribe(self, channel: str) -> None: - await self._subscriber.subscribe([channel]) + channel = await self._sub_conn.subscribe(channel) + self.channel = channel[0] async def unsubscribe(self, channel: str) -> None: - await self._subscriber.unsubscribe([channel]) + await self._sub_conn.unsubscribe(channel) async def publish(self, channel: str, message: typing.Any) -> None: - await self._pub_conn.publish(channel, message) + await self._pub_conn.publish_json(channel, message) async def next_published(self) -> Event: - message = await self._subscriber.next_published() - return Event(channel=message.channel, message=message.value) + while (await self.channel.wait_message()): + message = await self.channel.get_json() + return Event(channel=self.channel.name.decode("utf8"), message=message) diff --git a/broadcaster/_base.py b/broadcaster/_base.py index 44ec030..73621f7 100644 --- a/broadcaster/_base.py +++ b/broadcaster/_base.py @@ -53,7 +53,6 @@ async def __aexit__(self, *args, **kwargs) -> None: async def connect(self) -> None: await self._backend.connect() - self._listener_task = asyncio.create_task(self._listener()) async def disconnect(self) -> None: if self._listener_task.done(): @@ -65,8 +64,9 @@ async def disconnect(self) -> None: async def _listener(self) -> None: while True: event = await self._backend.next_published() - for queue in list(self._subscribers.get(event.channel, [])): - await queue.put(event) + if event: + for queue in list(self._subscribers.get(event.channel, [])): + await queue.put(event) async def publish(self, channel: str, message: typing.Any) -> None: await self._backend.publish(channel, message) @@ -81,6 +81,7 @@ async def subscribe(self, channel: str) -> 'Subscriber': self._subscribers[channel] = set([queue]) else: self._subscribers[channel].add(queue) + self._listener_task = asyncio.create_task(self._listener()) yield Subscriber(queue) diff --git a/setup.py b/setup.py index 61793bb..04a3985 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ def get_packages(package): # package_data={"starlette": ["py.typed"]}, # data_files=[("", ["LICENSE.md"])], extras_require={ - "redis": ["asyncio-redis"], + "redis": ["aioredis"], "postgres": ["asyncpg"], "kafka": ["aiokafka"] },