Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions broadcaster/_backends/kafka.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import typing
from urllib.parse import urlparse

Expand All @@ -10,9 +11,11 @@


class KafkaBackend(BroadcastBackend):
def __init__(self, url: str):
self._servers = [urlparse(url).netloc]
def __init__(self, urls: str | list[str]) -> None:
urls = [urls] if isinstance(urls, str) else urls
self._servers = [urlparse(url).netloc for url in urls]
self._consumer_channels: set[str] = set()
self._ready = asyncio.Event()

async def connect(self) -> None:
self._producer = AIOKafkaProducer(bootstrap_servers=self._servers)
Expand All @@ -27,6 +30,7 @@ async def disconnect(self) -> None:
async def subscribe(self, channel: str) -> None:
self._consumer_channels.add(channel)
self._consumer.subscribe(topics=self._consumer_channels)
await self._wait_for_assignment()

async def unsubscribe(self, channel: str) -> None:
self._consumer.unsubscribe()
Expand All @@ -35,5 +39,13 @@ async def publish(self, channel: str, message: typing.Any) -> None:
await self._producer.send_and_wait(channel, message.encode("utf8"))

async def next_published(self) -> Event:
await self._ready.wait()
message = await self._consumer.getone()
return Event(channel=message.topic, message=message.value.decode("utf8"))

async def _wait_for_assignment(self) -> None:
"""Wait for the consumer to be assigned to the partition."""
while not self._consumer.assignment():
await asyncio.sleep(0.001)

self._ready.set()
12 changes: 11 additions & 1 deletion tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from broadcaster import Broadcast, BroadcastBackend, Event
from broadcaster._backends.kafka import KafkaBackend


class CustomBackend(BroadcastBackend):
Expand Down Expand Up @@ -65,7 +66,6 @@ async def test_postgres():
assert event.message == "hello"


@pytest.mark.skip("Deadlock on `next_published`")
@pytest.mark.asyncio
async def test_kafka():
async with Broadcast("kafka://localhost:9092") as broadcast:
Expand All @@ -76,6 +76,16 @@ async def test_kafka():
assert event.message == "hello"


@pytest.mark.asyncio
async def test_kafka_multiple_urls():
async with Broadcast(backend=KafkaBackend(urls=["kafka://localhost:9092", "kafka://localhost:9092"])) as broadcast:
async with broadcast.subscribe("chatroom") as subscriber:
await broadcast.publish("chatroom", "hello")
event = await subscriber.get()
assert event.channel == "chatroom"
assert event.message == "hello"


@pytest.mark.asyncio
async def test_custom():
backend = CustomBackend("")
Expand Down