diff --git a/.gitignore b/.gitignore index 7b5d431..e001ace 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ test.db .mypy_cache/ starlette.egg-info/ venv/ +broadcaster_noteable.egg-info/ +build/ +dist/ diff --git a/README.md b/README.md index 61709ac..58eff06 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # Broadcaster +**Forked from original [broadcaster](https://github.com/encode/broadcaster) to get patch fixes available for re-use** + Broadcaster helps you develop realtime streaming functionality by providing a simple broadcast API onto a number of different backend services. diff --git a/RELEASING.md b/RELEASING.md new file mode 100644 index 0000000..3b04959 --- /dev/null +++ b/RELEASING.md @@ -0,0 +1,24 @@ +# Releasing + +## Prerequisites + +- Ensure release requirements are installed `pip install -r requirements-release.txt` + +## Push to GitHub + +Change from patch to minor or major for appropriate version updates in `broadcaster/__init__.py`, then push it to git. + +```bash + +git tag +git push upstream && git push upstream --tags +``` + +## Push to PyPI + +```bash +rm -rf dist/* +rm -rf build/* +python setup.py sdist bdist_wheel +twine upload dist/* +``` diff --git a/broadcaster/__init__.py b/broadcaster/__init__.py index 49a4d37..81e8797 100644 --- a/broadcaster/__init__.py +++ b/broadcaster/__init__.py @@ -1,3 +1,3 @@ from ._base import Broadcast, Event -__version__ = "0.2.0" +__version__ = "0.2.1" diff --git a/broadcaster/_backends/memory.py b/broadcaster/_backends/memory.py index 013c028..426a678 100644 --- a/broadcaster/_backends/memory.py +++ b/broadcaster/_backends/memory.py @@ -1,19 +1,26 @@ import asyncio +import logging import typing from .base import BroadcastBackend from .._base import Event +logger = logging.getLogger("broadcaster.memory") + class MemoryBackend(BroadcastBackend): def __init__(self, url: str): self._subscribed: typing.Set = set() - self._published: asyncio.Queue = asyncio.Queue() + self._published: typing.Optional[asyncio.Queue] = None async def connect(self) -> None: - pass + if self._published is not None: + logger.warning("already connected, cannot connect again!") + return + + self._published = asyncio.Queue() async def disconnect(self) -> None: - pass + self._published = None async def subscribe(self, channel: str) -> None: self._subscribed.add(channel) @@ -22,11 +29,19 @@ async def unsubscribe(self, channel: str) -> None: self._subscribed.remove(channel) async def publish(self, channel: str, message: typing.Any) -> None: + if self._published is None: + logger.warning("not connected, unable to publish message") + return + event = Event(channel=channel, message=message) await self._published.put(event) async def next_published(self) -> Event: while True: + if self._published is None: + logger.warning("not connected, unable to retrieve next published message") + continue + event = await self._published.get() if event.channel in self._subscribed: return event diff --git a/broadcaster/_backends/redis.py b/broadcaster/_backends/redis.py index 096a033..2235fce 100644 --- a/broadcaster/_backends/redis.py +++ b/broadcaster/_backends/redis.py @@ -1,34 +1,99 @@ -import asyncio_redis +import aioredis +from aioredis.abc import AbcChannel +from aioredis.pubsub import Receiver +import asyncio +import json +import logging import typing -from urllib.parse import urlparse from .base import BroadcastBackend from .._base import Event +logger = logging.getLogger("broadcaster.redis") + class RedisBackend(BroadcastBackend): def __init__(self, url: str): - parsed_url = urlparse(url) - self._host = parsed_url.hostname or "localhost" - self._port = parsed_url.port or 6379 + self.conn_url = url + + self._pub_conn: typing.Optional[aioredis.Redis] = None + self._sub_conn: typing.Optional[aioredis.Redis] = None + + self._msg_queue: typing.Optional[asyncio.Queue] = None + self._reader_task: typing.Optional[asyncio.Task] = None + self._mpsc: typing.Optional[Receiver] = None async def connect(self) -> None: - self._pub_conn = await asyncio_redis.Connection.create(self._host, self._port) - self._sub_conn = await asyncio_redis.Connection.create(self._host, self._port) - self._subscriber = await self._sub_conn.start_subscribe() + if self._pub_conn or self._sub_conn or self._msg_queue: + logger.warning("connections are already setup but connect called again; not doing anything") + return + + self._pub_conn = await aioredis.create_redis(self.conn_url) + self._sub_conn = await aioredis.create_redis(self.conn_url) + self._msg_queue = asyncio.Queue() # must be created here, to get proper event loop + self._mpsc = Receiver() + self._reader_task = asyncio.create_task(self._reader()) async def disconnect(self) -> None: - self._pub_conn.close() - self._sub_conn.close() + if self._pub_conn and self._sub_conn: + self._pub_conn.close() + self._sub_conn.close() + else: + logger.warning("connections are not setup, invalid call to disconnect") + + self._pub_conn = None + self._sub_conn = None + self._msg_queue = None + + if self._mpsc: + self._mpsc.stop() + else: + logger.warning("redis mpsc receiver is not set, cannot stop it") + + if self._reader_task: + if self._reader_task.done(): + self._reader_task.result() + else: + logger.debug("cancelling reader task") + self._reader_task.cancel() + self._reader_task = None async def subscribe(self, channel: str) -> None: - await self._subscriber.subscribe([channel]) + if not self._sub_conn: + logger.error(f"not connected, cannot subscribe to channel {channel!r}") + return + + await self._sub_conn.subscribe(self._mpsc.channel(channel)) async def unsubscribe(self, channel: str) -> None: - await self._subscriber.unsubscribe([channel]) + if not self._sub_conn: + logger.error(f"not connected, cannot unsubscribe from channel {channel!r}") + return + + await self._sub_conn.unsubscribe(channel) async def publish(self, channel: str, message: typing.Any) -> None: - await self._pub_conn.publish(channel, message) + if not self._pub_conn: + logger.error(f"not connected, cannot publish to channel {channel!r}") + return + + await self._pub_conn.publish_json(channel, message) async def next_published(self) -> Event: - message = await self._subscriber.next_published() - return Event(channel=message.channel, message=message.value) + if not self._msg_queue: + raise RuntimeError("unable to get next_published event, RedisBackend is not connected") + + return await self._msg_queue.get() + + async def _reader(self) -> None: + async for channel, msg in self._mpsc.iter(encoding="utf8", decoder=json.loads): + if not isinstance(channel, AbcChannel): + logger.error(f"invalid channel returned from Receiver().iter() - {channel!r}") + continue + + channel_name = channel.name.decode("utf8") + + if not self._msg_queue: + logger.error(f"unable to put new message from {channel_name} into queue, not connected") + continue + + await self._msg_queue.put(Event(channel=channel_name, message=msg)) diff --git a/broadcaster/_base.py b/broadcaster/_base.py index 44ec030..25ccc27 100644 --- a/broadcaster/_base.py +++ b/broadcaster/_base.py @@ -28,20 +28,24 @@ class Broadcast: def __init__(self, url: str): parsed_url = urlparse(url) self._subscribers = {} - if parsed_url.scheme == 'redis': + if parsed_url.scheme in ('redis', 'rediss'): from ._backends.redis import RedisBackend + self._backend = RedisBackend(url) elif parsed_url.scheme in ('postgres', 'postgresql'): from ._backends.postgres import PostgresBackend + self._backend = PostgresBackend(url) if parsed_url.scheme == 'kafka': from ._backends.kafka import KafkaBackend + self._backend = KafkaBackend(url) elif parsed_url.scheme == 'memory': from ._backends.memory import MemoryBackend + self._backend = MemoryBackend(url) async def __aenter__(self) -> 'Broadcast': diff --git a/docker-compose.yaml b/docker-compose.yaml index 60073b4..9b2b7d9 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -36,3 +36,8 @@ services: - POSTGRES_USER=postgres ports: - 5432:5432 + redis-pass: + image: "redis:alpine" + command: redis-server --requirepass 123 + ports: + - 6377:6379 diff --git a/requirements-release.txt b/requirements-release.txt new file mode 100644 index 0000000..2f6cb7e --- /dev/null +++ b/requirements-release.txt @@ -0,0 +1,4 @@ +pip>=21.0.1 +wheel>=0.36.2 +setuptools>=51.1.0 +twine>=3.3.0 diff --git a/setup.py b/setup.py index 75229e5..ad9db5f 100644 --- a/setup.py +++ b/setup.py @@ -35,19 +35,18 @@ def get_packages(package): setup( - name="broadcaster", + name="broadcaster-noteable", python_requires=">=3.7", version=get_version("broadcaster"), - url="https://github.com/encode/broadcaster", + url="https://github.com/noteable-io/broadcaster", license="BSD", description="Simple broadcast channels.", long_description=get_long_description(), long_description_content_type="text/markdown", - author="Tom Christie", - author_email="tom@tomchristie.com", + author="Noteable (extending from Tom Christie)", packages=get_packages("broadcaster"), extras_require={ - "redis": ["asyncio-redis"], + "redis": ["aioredis"], "postgres": ["asyncpg"], "kafka": ["aiokafka"] }, diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index 7e0c8c0..6ce711c 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -1,4 +1,3 @@ -import asyncio import pytest from broadcaster import Broadcast @@ -7,6 +6,7 @@ async def test_memory(): async with Broadcast('memory://') as broadcast: async with broadcast.subscribe('chatroom') as subscriber: + await broadcast.publish('foo', 'bar') await broadcast.publish('chatroom', 'hello') event = await subscriber.get() assert event.channel == 'chatroom' @@ -16,6 +16,17 @@ async def test_memory(): @pytest.mark.asyncio async def test_redis(): async with Broadcast('redis://localhost:6379') as broadcast: + async with broadcast.subscribe('chatroom') as subscriber: + await broadcast.publish('foo', 'bar') + await broadcast.publish('chatroom', 'hello') + event = await subscriber.get() + assert event.channel == 'chatroom' + assert event.message == 'hello' + + +@pytest.mark.asyncio +async def test_redis_complex(): + async with Broadcast('redis://:123@localhost:6377/4') as broadcast: async with broadcast.subscribe('chatroom') as subscriber: await broadcast.publish('chatroom', 'hello') event = await subscriber.get()