diff --git a/README.md b/README.md index f3b7d93..d11a9ed 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,21 @@ Python 3.8+ * `Broadcast("postgres://localhost:5432/broadcaster")` * `Broadcast("kafka://localhost:9092")` + +### Using custom backends + +You can create your own backend and use it with `broadcaster`. +To do that you need to create a class which extends from `BroadcastBackend` +and pass it to the `broadcaster` via `backend` argument. + +```python +from broadcaster import Broadcaster, BroadcastBackend + +class MyBackend(BroadcastBackend): + ... + +broadcaster = Broadcaster(backend=MyBackend()) + ## Where next? At the moment `broadcaster` is in Alpha, and should be considered a working design document. diff --git a/broadcaster/__init__.py b/broadcaster/__init__.py index edc56d6..b5dd0bf 100644 --- a/broadcaster/__init__.py +++ b/broadcaster/__init__.py @@ -1,4 +1,5 @@ from ._base import Broadcast, Event +from ._backends.base import BroadcastBackend __version__ = "0.2.0" -__all__ = ["Broadcast", "Event"] +__all__ = ["Broadcast", "Event", "BroadcastBackend"] diff --git a/broadcaster/_base.py b/broadcaster/_base.py index 4de1417..4e0c2ef 100644 --- a/broadcaster/_base.py +++ b/broadcaster/_base.py @@ -1,8 +1,19 @@ import asyncio from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, AsyncIterator, Dict, Optional +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + AsyncIterator, + Dict, + Optional, + cast, +) from urllib.parse import urlparse +if TYPE_CHECKING: # pragma: no cover + from broadcaster._backends.base import BroadcastBackend + class Event: def __init__(self, channel: str, message: str) -> None: @@ -25,31 +36,35 @@ class Unsubscribed(Exception): class Broadcast: - def __init__(self, url: str): - from broadcaster._backends.base import BroadcastBackend + def __init__( + self, url: Optional[str] = None, *, backend: Optional["BroadcastBackend"] = None + ) -> None: + assert url or backend, "Either `url` or `backend` must be provided." + self._backend = backend or self._create_backend(cast(str, url)) + self._subscribers: Dict[str, Any] = {} + def _create_backend(self, url: str) -> "BroadcastBackend": parsed_url = urlparse(url) - self._backend: BroadcastBackend - self._subscribers: Dict[str, Any] = {} if parsed_url.scheme in ("redis", "rediss"): from broadcaster._backends.redis import RedisBackend - self._backend = RedisBackend(url) + return RedisBackend(url) elif parsed_url.scheme in ("postgres", "postgresql"): from broadcaster._backends.postgres import PostgresBackend - self._backend = PostgresBackend(url) + return PostgresBackend(url) if parsed_url.scheme == "kafka": from broadcaster._backends.kafka import KafkaBackend - self._backend = KafkaBackend(url) + return KafkaBackend(url) elif parsed_url.scheme == "memory": from broadcaster._backends.memory import MemoryBackend - self._backend = MemoryBackend(url) + return MemoryBackend(url) + raise ValueError(f"Unsupported backend: {parsed_url.scheme}") async def __aenter__(self) -> "Broadcast": await self.connect() diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index e3313bc..4cf9e45 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -1,6 +1,35 @@ import pytest +import typing +import asyncio -from broadcaster import Broadcast +from broadcaster import Broadcast, BroadcastBackend, Event + + +class CustomBackend(BroadcastBackend): + def __init__(self, url: str): + self._subscribed: typing.Set = set() + + async def connect(self) -> None: + self._published: asyncio.Queue = asyncio.Queue() + + async def disconnect(self) -> None: + pass + + async def subscribe(self, channel: str) -> None: + self._subscribed.add(channel) + + async def unsubscribe(self, channel: str) -> None: + self._subscribed.remove(channel) + + async def publish(self, channel: str, message: typing.Any) -> None: + event = Event(channel=channel, message=message) + await self._published.put(event) + + async def next_published(self) -> Event: + while True: + event = await self._published.get() + if event.channel in self._subscribed: + return event @pytest.mark.asyncio @@ -44,3 +73,29 @@ async def test_kafka(): event = await subscriber.get() assert event.channel == "chatroom" assert event.message == "hello" + + +@pytest.mark.asyncio +async def test_custom(): + backend = CustomBackend("") + async with Broadcast(backend=backend) as broadcast: + async with broadcast.subscribe("chatroom") as subscriber: + await broadcast.publish("chatroom", "hello") + event = await subscriber.get() + assert event.channel == "chatroom" + assert event.message == "hello" + + +@pytest.mark.asyncio +async def test_unknown_backend(): + with pytest.raises(ValueError, match="Unsupported backend"): + async with Broadcast(url="unknown://"): + pass + + +@pytest.mark.asyncio +async def test_needs_url_or_backend(): + with pytest.raises( + AssertionError, match="Either `url` or `backend` must be provided." + ): + Broadcast()