From 512560f8de1238636e1c047dbf841f7a5b61a9ff Mon Sep 17 00:00:00 2001 From: "alex.oleshkevich" Date: Thu, 6 Jun 2024 10:48:33 +0200 Subject: [PATCH] kafka improvements --- broadcaster/_backends/kafka.py | 16 ++++++++++++++-- tests/test_broadcast.py | 12 +++++++++++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/broadcaster/_backends/kafka.py b/broadcaster/_backends/kafka.py index e577769..46fba6b 100644 --- a/broadcaster/_backends/kafka.py +++ b/broadcaster/_backends/kafka.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import typing from urllib.parse import urlparse @@ -10,9 +11,11 @@ class KafkaBackend(BroadcastBackend): - def __init__(self, url: str): - self._servers = [urlparse(url).netloc] + def __init__(self, urls: str | list[str]) -> None: + urls = [urls] if isinstance(urls, str) else urls + self._servers = [urlparse(url).netloc for url in urls] self._consumer_channels: set[str] = set() + self._ready = asyncio.Event() async def connect(self) -> None: self._producer = AIOKafkaProducer(bootstrap_servers=self._servers) @@ -27,6 +30,7 @@ async def disconnect(self) -> None: async def subscribe(self, channel: str) -> None: self._consumer_channels.add(channel) self._consumer.subscribe(topics=self._consumer_channels) + await self._wait_for_assignment() async def unsubscribe(self, channel: str) -> None: self._consumer.unsubscribe() @@ -35,5 +39,13 @@ async def publish(self, channel: str, message: typing.Any) -> None: await self._producer.send_and_wait(channel, message.encode("utf8")) async def next_published(self) -> Event: + await self._ready.wait() message = await self._consumer.getone() return Event(channel=message.topic, message=message.value.decode("utf8")) + + async def _wait_for_assignment(self) -> None: + """Wait for the consumer to be assigned to the partition.""" + while not self._consumer.assignment(): + await asyncio.sleep(0.001) + + self._ready.set() diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index b516ee2..deec741 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -6,6 +6,7 @@ import pytest from broadcaster import Broadcast, BroadcastBackend, Event +from broadcaster._backends.kafka import KafkaBackend class CustomBackend(BroadcastBackend): @@ -65,7 +66,6 @@ async def test_postgres(): assert event.message == "hello" -@pytest.mark.skip("Deadlock on `next_published`") @pytest.mark.asyncio async def test_kafka(): async with Broadcast("kafka://localhost:9092") as broadcast: @@ -76,6 +76,16 @@ async def test_kafka(): assert event.message == "hello" +@pytest.mark.asyncio +async def test_kafka_multiple_urls(): + async with Broadcast(backend=KafkaBackend(urls=["kafka://localhost:9092", "kafka://localhost:9092"])) as broadcast: + async with broadcast.subscribe("chatroom") as subscriber: + await broadcast.publish("chatroom", "hello") + event = await subscriber.get() + assert event.channel == "chatroom" + assert event.message == "hello" + + @pytest.mark.asyncio async def test_custom(): backend = CustomBackend("")