diff --git a/broadcaster/_base.py b/broadcaster/_base.py index 4de1417..c8dc221 100644 --- a/broadcaster/_base.py +++ b/broadcaster/_base.py @@ -90,12 +90,11 @@ async def subscribe(self, channel: str) -> AsyncIterator["Subscriber"]: self._subscribers[channel].add(queue) yield Subscriber(queue) - + finally: self._subscribers[channel].remove(queue) if not self._subscribers.get(channel): del self._subscribers[channel] await self._backend.unsubscribe(channel) - finally: await queue.put(None) diff --git a/tests/test_unsubscribe.py b/tests/test_unsubscribe.py new file mode 100644 index 0000000..30f928b --- /dev/null +++ b/tests/test_unsubscribe.py @@ -0,0 +1,25 @@ +import pytest +from broadcaster import Broadcast + + +@pytest.mark.asyncio +async def test_unsubscribe(): + """The queue should be removed when the context manager is left.""" + async with Broadcast("memory://") as broadcast: + async with broadcast.subscribe("chatroom"): + pass + + assert "chatroom" not in broadcast._subscribers + + +@pytest.mark.asyncio +async def test_unsubscribe_w_exception(): + """In case an exception is raised inside the context manager, the queue should be removed.""" + async with Broadcast("memory://") as broadcast: + try: + async with broadcast.subscribe("chatroom"): + raise RuntimeError("MyException") + except RuntimeError: + pass + + assert "chatroom" not in broadcast._subscribers