Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions broadcaster/_backends/redis.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import typing
from typing import Any
from urllib.parse import urlparse

import asyncio_redis
import redis.asyncio as redis
from redis.asyncio.client import PubSub

from .._base import Event
from .base import BroadcastBackend
import asyncio


class RedisBackend(BroadcastBackend):
Expand All @@ -13,26 +15,47 @@ def __init__(self, url: str):
self._host = parsed_url.hostname or "localhost"
self._port = parsed_url.port or 6379
self._password = parsed_url.password or None
self._ssl = parsed_url.scheme == "rediss"
self.kwargs = {
"host": self._host,
"port": self._port,
"password": self._password,
"ssl": self._ssl,
}

self._sub_conn: PubSub | None = None
self._pub_conn: PubSub | None = None

async def connect(self) -> None:
kwargs = {"host": self._host, "port": self._port, "password": self._password}
self._pub_conn = await asyncio_redis.Connection.create(**kwargs)
self._sub_conn = await asyncio_redis.Connection.create(**kwargs)
self._subscriber = await self._sub_conn.start_subscribe()
self._pub_conn = redis.Redis(**self.kwargs).pubsub()
self._sub_conn = redis.Redis(**self.kwargs).pubsub()

async def disconnect(self) -> None:
self._pub_conn.close()
self._sub_conn.close()
await self._pub_conn.close()
await self._sub_conn.close()

async def subscribe(self, channel: str) -> None:
await self._subscriber.subscribe([channel])
await self._sub_conn.subscribe(channel)

async def unsubscribe(self, channel: str) -> None:
await self._subscriber.unsubscribe([channel])
await self._sub_conn.unsubscribe(channel)

async def publish(self, channel: str, message: typing.Any) -> None:
await self._pub_conn.publish(channel, message)
async def publish(self, channel: str, message: Any) -> None:
try:
await self._pub_conn.execute_command("PUBLISH", channel, message)
except (redis.ConnectionError, redis.TimeoutError):
await asyncio.sleep(1)
self._pub_conn = redis.Redis(**self.kwargs).pubsub()
await self.publish(channel, message)

async def next_published(self) -> Event:
message = await self._subscriber.next_published()
return Event(channel=message.channel, message=message.value)
message = None
while not message:
message = await self._sub_conn.get_message(
ignore_subscribe_messages=True, timeout=None
)
event = Event(
channel=message["channel"].decode(),
message=message["data"].decode(),
)
return event
23 changes: 13 additions & 10 deletions broadcaster/_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, AsyncIterator, Dict, Optional
from typing import Any, AsyncGenerator, Optional, AsyncIterator
from urllib.parse import urlparse


Expand All @@ -26,13 +26,13 @@ class Unsubscribed(Exception):

class Broadcast:
def __init__(self, url: str):
from broadcaster._backends.base import BroadcastBackend
from ._backends.base import BroadcastBackend

parsed_url = urlparse(url)
self._backend: BroadcastBackend
self._subscribers: Dict[str, Any] = {}
self._subscribers: dict[str, set[asyncio.Queue]] = {}
if parsed_url.scheme in ("redis", "rediss"):
from broadcaster._backends.redis import RedisBackend
from ._backends.redis import RedisBackend

self._backend = RedisBackend(url)

Expand All @@ -51,6 +51,8 @@ def __init__(self, url: str):

self._backend = MemoryBackend(url)

self._listener_task: asyncio.Task | None = None

async def __aenter__(self) -> "Broadcast":
await self.connect()
return self
Expand All @@ -60,13 +62,13 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None:

async def connect(self) -> None:
await self._backend.connect()
self._listener_task = asyncio.create_task(self._listener())

async def disconnect(self) -> None:
if self._listener_task.done():
self._listener_task.result()
else:
self._listener_task.cancel()
if self._listener_task:
if self._listener_task.done():
self._listener_task.result()
else:
self._listener_task.cancel()
await self._backend.disconnect()

async def _listener(self) -> None:
Expand All @@ -81,10 +83,11 @@ async def publish(self, channel: str, message: Any) -> None:
@asynccontextmanager
async def subscribe(self, channel: str) -> AsyncIterator["Subscriber"]:
queue: asyncio.Queue = asyncio.Queue()

try:
if not self._subscribers.get(channel):
await self._backend.subscribe(channel)
if not self._listener_task:
self._listener_task = asyncio.create_task(self._listener())
self._subscribers[channel] = set([queue])
else:
self._subscribers[channel].add(queue)
Expand Down