diff --git a/README.md b/README.md index f3b7d93..1e651f2 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,7 @@ Python 3.8+ * `Broadcast('memory://')` * `Broadcast("redis://localhost:6379")` +* `Broadcast("redis-stream://localhost:6379")` * `Broadcast("postgres://localhost:5432/broadcaster")` * `Broadcast("kafka://localhost:9092")` @@ -96,6 +97,6 @@ state, make sure to strictly pin your requirements to `broadcaster==0.2.0`. To be more capable we'd really want to add some additional backends, provide API support for reading recent event history from persistent stores, and provide a serialization/deserialization API... * Serialization / deserialization to support broadcasting structured data. -* Backends for Redis Streams, Apache Kafka, and RabbitMQ. +* Backends for Apache Kafka and RabbitMQ. * Add support for `subscribe('chatroom', history=100)` for backends which provide persistence. (Redis Streams, Apache Kafka) This will allow applications to subscribe to channel updates, while also being given an initial window onto the most recent events. We *might* also want to support some basic paging operations, to allow applications to scan back in the event history. * Support for pattern subscribes in backends that support it. 
diff --git a/broadcaster/_backends/redis.py b/broadcaster/_backends/redis.py index 2c4aeba..c6583f1 100644 --- a/broadcaster/_backends/redis.py +++ b/broadcaster/_backends/redis.py @@ -1,7 +1,9 @@ +import asyncio import typing from urllib.parse import urlparse import asyncio_redis +import aioredis from .._base import Event from .base import BroadcastBackend @@ -36,3 +38,49 @@ async def publish(self, channel: str, message: typing.Any) -> None: async def next_published(self) -> Event: message = await self._subscriber.next_published() return Event(channel=message.channel, message=message.value) + + +class RedisStreamBackend(BroadcastBackend): + def __init__(self, url: str): + self.conn_url = url.replace("redis-stream", "redis", 1) + self.streams: typing.Dict = dict() + + async def connect(self) -> None: + self._producer = await aioredis.from_url(self.conn_url) + self._consumer = await aioredis.from_url(self.conn_url) + + async def disconnect(self) -> None: + await self._producer.close() + await self._consumer.close() + + async def subscribe(self, channel: str) -> None: + try: + info = await self._consumer.xinfo_stream(channel) + last_id = info["last-generated-id"] + except aioredis.exceptions.ResponseError: + last_id = "0" + self.streams[channel] = last_id + + async def unsubscribe(self, channel: str) -> None: + self.streams.pop(channel, None) + + async def publish(self, channel: str, message: typing.Any) -> None: + await self._producer.xadd(channel, {"message": message}) + + async def wait_for_messages(self) -> typing.List: + messages = None + while not messages: + while not self.streams: + await asyncio.sleep(1) + messages = await self._consumer.xread(self.streams, count=1, block=1000) + return messages + + async def next_published(self) -> Event: + messages = await self.wait_for_messages() + stream, events = messages[0] + _msg_id, message = events[0] + self.streams[stream.decode("utf-8")] = _msg_id.decode("utf-8") + return Event( + channel=stream.decode("utf-8"), + 
message=message.get(b"message").decode("utf-8"), + ) diff --git a/broadcaster/_base.py b/broadcaster/_base.py index 4de1417..70cf23a 100644 --- a/broadcaster/_base.py +++ b/broadcaster/_base.py @@ -36,6 +36,11 @@ def __init__(self, url: str): self._backend = RedisBackend(url) + elif parsed_url.scheme == "redis-stream": + from broadcaster._backends.redis import RedisStreamBackend + + self._backend = RedisStreamBackend(url) + elif parsed_url.scheme in ("postgres", "postgresql"): from broadcaster._backends.postgres import PostgresBackend @@ -85,17 +90,16 @@ async def subscribe(self, channel: str) -> AsyncIterator["Subscriber"]: try: if not self._subscribers.get(channel): await self._backend.subscribe(channel) - self._subscribers[channel] = set([queue]) + self._subscribers[channel] = {queue} else: self._subscribers[channel].add(queue) yield Subscriber(queue) - + finally: self._subscribers[channel].remove(queue) if not self._subscribers.get(channel): del self._subscribers[channel] await self._backend.unsubscribe(channel) - finally: await queue.put(None) diff --git a/pyproject.toml b/pyproject.toml index ef59da4..2132183 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ ] [project.optional-dependencies] -redis = ["asyncio-redis"] +redis = ["asyncio-redis", "aioredis>=2.0.1"] postgres = ["asyncpg"] kafka = ["aiokafka"] test = ["pytest", "pytest-asyncio"] diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index e3313bc..6df2d7e 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -23,6 +23,21 @@ async def test_redis(): assert event.message == "hello" +@pytest.mark.asyncio +async def test_redis_stream(): + async with Broadcast("redis-stream://localhost:6379") as broadcast: + async with broadcast.subscribe("chatroom") as subscriber: + await broadcast.publish("chatroom", "hello") + event = await subscriber.get() + assert event.channel == "chatroom" + assert event.message == "hello" + async with 
broadcast.subscribe("chatroom1") as subscriber: + await broadcast.publish("chatroom1", "hello") + event = await subscriber.get() + assert event.channel == "chatroom1" + assert event.message == "hello" + + @pytest.mark.asyncio async def test_postgres(): async with Broadcast(