diff --git a/broadcaster/_backends/redis.py b/broadcaster/_backends/redis.py index 2c4aeba..8f15e84 100644 --- a/broadcaster/_backends/redis.py +++ b/broadcaster/_backends/redis.py @@ -1,7 +1,8 @@ -import typing +from typing import Any from urllib.parse import urlparse -import asyncio_redis +import redis.asyncio as redis +from redis.asyncio.client import PubSub from .._base import Event from .base import BroadcastBackend @@ -14,25 +15,34 @@ def __init__(self, url: str): self._port = parsed_url.port or 6379 self._password = parsed_url.password or None + self._sub_conn: PubSub | None = None + self._pub_conn: PubSub | None = None + async def connect(self) -> None: kwargs = {"host": self._host, "port": self._port, "password": self._password} - self._pub_conn = await asyncio_redis.Connection.create(**kwargs) - self._sub_conn = await asyncio_redis.Connection.create(**kwargs) - self._subscriber = await self._sub_conn.start_subscribe() + self._pub_conn = redis.Redis(**kwargs).pubsub() + self._sub_conn = redis.Redis(**kwargs).pubsub() async def disconnect(self) -> None: - self._pub_conn.close() - self._sub_conn.close() + await self._pub_conn.close() + await self._sub_conn.close() async def subscribe(self, channel: str) -> None: - await self._subscriber.subscribe([channel]) + await self._sub_conn.subscribe(channel) async def unsubscribe(self, channel: str) -> None: - await self._subscriber.unsubscribe([channel]) + await self._sub_conn.unsubscribe(channel) - async def publish(self, channel: str, message: typing.Any) -> None: - await self._pub_conn.publish(channel, message) + async def publish(self, channel: str, message: Any) -> None: + await self._pub_conn.execute_command("PUBLISH", channel, message) async def next_published(self) -> Event: - message = await self._subscriber.next_published() - return Event(channel=message.channel, message=message.value) + message = None + # get_message with timeout=None can return None + while not message: + # + message = await self._sub_conn.get_message(timeout=None) + return Event( + channel=message["channel"].decode(), + message=message["data"].decode(), + ) diff --git a/broadcaster/_base.py b/broadcaster/_base.py index 4de1417..a52b11f 100644 --- a/broadcaster/_base.py +++ b/broadcaster/_base.py @@ -1,6 +1,6 @@ import asyncio from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, AsyncIterator, Dict, Optional +from typing import Any, AsyncGenerator, Dict, Optional, AsyncIterator from urllib.parse import urlparse @@ -51,6 +51,8 @@ def __init__(self, url: str): self._backend = MemoryBackend(url) + self._listener_task: asyncio.Task | None = None + async def __aenter__(self) -> "Broadcast": await self.connect() return self @@ -60,13 +62,13 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None: async def connect(self) -> None: await self._backend.connect() - self._listener_task = asyncio.create_task(self._listener()) async def disconnect(self) -> None: - if self._listener_task.done(): - self._listener_task.result() - else: - self._listener_task.cancel() + if self._listener_task: + if self._listener_task.done(): + self._listener_task.result() + else: + self._listener_task.cancel() await self._backend.disconnect() async def _listener(self) -> None: @@ -85,6 +87,8 @@ async def subscribe(self, channel: str) -> AsyncIterator["Subscriber"]: try: if not self._subscribers.get(channel): await self._backend.subscribe(channel) + if not self._listener_task: + self._listener_task = asyncio.create_task(self._listener()) self._subscribers[channel] = set([queue]) else: self._subscribers[channel].add(queue) diff --git a/setup.py b/setup.py index 9cdf906..7ef96d0 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ def get_packages(package): package_data={"broadcaster": ["py.typed"]}, include_package_data=True, extras_require={ - "redis": ["asyncio-redis"], + "redis": ["redis"], "postgres": ["asyncpg"], "kafka": ["aiokafka"], "test": ["pytest", "pytest-asyncio"],