Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions scripts/perf_ab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ezmsg.util.perf.ab import main


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions src/ezmsg/core/backendprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ async def setup_state():
buf_size=stream.buf_size,
start_paused=True,
force_tcp=stream.force_tcp,
allow_local=stream.allow_local,
),
loop=loop,
).result()
Expand Down
95 changes: 76 additions & 19 deletions src/ezmsg/core/pubclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,36 @@

BACKPRESSURE_WARNING = "EZMSG_DISABLE_BACKPRESSURE_WARNING" not in os.environ
BACKPRESSURE_REFRACTORY = 5.0 # sec
ALLOW_LOCAL_ENV = "EZMSG_ALLOW_LOCAL"
FORCE_TCP_ENV = "EZMSG_FORCE_TCP"


def _process_allow_local_default() -> bool:
value = os.environ.get(ALLOW_LOCAL_ENV, "")
if value == "":
return True
return value.lower() in ("1", "true", "yes", "on")


def _process_force_tcp_default() -> bool:
value = os.environ.get(FORCE_TCP_ENV, "")
if value == "":
return False
return value.lower() in ("1", "true", "yes", "on")


def _resolve_force_tcp(force_tcp: bool | None) -> bool:
if force_tcp is None:
return _process_force_tcp_default()
return force_tcp


def _resolve_allow_local(force_tcp: bool, allow_local: bool | None) -> bool:
resolved = _process_allow_local_default() if allow_local is None else allow_local
if force_tcp and resolved:
logger.info("force_tcp=True disables local delivery for this publisher")
return False
return resolved


# Publisher needs a bit more information about connected channels
Expand Down Expand Up @@ -75,6 +105,7 @@ class Publisher:
_msg_id: int
_shm: SHMContext
_force_tcp: bool
_allow_local: bool
_last_backpressure_event: float

_graph_address: AddressType | None
Expand All @@ -99,7 +130,8 @@ async def create(
buf_size: int = DEFAULT_SHM_SIZE,
num_buffers: int = 32,
start_paused: bool = False,
force_tcp: bool = False,
force_tcp: bool | None = None,
allow_local: bool | None = None,
) -> "Publisher":
"""
Create a new Publisher instance and register it with the graph server.
Expand All @@ -116,6 +148,16 @@ async def create(
:type port: int | None
:param buf_size: Size of shared memory buffers.
:type buf_size: int
:param force_tcp: Whether to force TCP transport instead of shared memory.
If None, inherit the process default from ``EZMSG_FORCE_TCP`` which
defaults to disabled.
:type force_tcp: bool | None
:param allow_local: Whether to allow the in-process fast path when available.
If None, inherit the process default from ``EZMSG_ALLOW_LOCAL`` which
defaults to enabled. Set to False to bypass local delivery and
characterize same-process SHM or TCP. When ``force_tcp=True``, local
delivery is disabled regardless of this value.
:type allow_local: bool | None
:param kwargs: Additional keyword arguments for Publisher constructor.
:return: Initialized and registered Publisher instance.
:rtype: Publisher
Expand All @@ -127,6 +169,8 @@ async def create(
writer.write(Command.PUBLISH.value)
writer.write(encode_str(topic))

resolved_force_tcp = _resolve_force_tcp(force_tcp)

pub_id = UUID(await read_str(reader))
pub = cls(
id=pub_id,
Expand All @@ -135,7 +179,8 @@ async def create(
graph_address=graph_address,
num_buffers=num_buffers,
start_paused=start_paused,
force_tcp=force_tcp,
force_tcp=resolved_force_tcp,
allow_local=allow_local,
_guard=cls._SENTINEL,
)

Expand Down Expand Up @@ -189,7 +234,8 @@ def __init__(
graph_address: AddressType | None = None,
num_buffers: int = 32,
start_paused: bool = False,
force_tcp: bool = False,
force_tcp: bool | None = None,
allow_local: bool | None = None,
_guard = None
) -> None:
"""
Expand All @@ -207,7 +253,12 @@ def __init__(
:param start_paused: Whether to start in paused state.
:type start_paused: bool
:param force_tcp: Whether to force TCP transport instead of shared memory.
:type force_tcp: bool
If None, inherit the process default from ``EZMSG_FORCE_TCP``.
:type force_tcp: bool | None
:param allow_local: Whether to allow the direct in-process fast path when available.
If None, inherit the process default from ``EZMSG_ALLOW_LOCAL``.
When ``force_tcp=True``, local delivery is disabled regardless of this value.
:type allow_local: bool | None
"""
if _guard is not self._SENTINEL:
raise TypeError(
Expand All @@ -227,7 +278,8 @@ def __init__(
self._running.set()
self._num_buffers = num_buffers
self._backpressure = Backpressure(num_buffers)
self._force_tcp = force_tcp
self._force_tcp = _resolve_force_tcp(force_tcp)
self._allow_local = _resolve_allow_local(self._force_tcp, allow_local)
self._last_backpressure_event = -1
self._graph_address = graph_address

Expand Down Expand Up @@ -436,22 +488,18 @@ async def broadcast(self, obj: Any) -> None:
self._last_backpressure_event = time.time()
await self._backpressure.wait(buf_idx)

# Get local channel and put variable there for local tx
self._local_channel.put_local(self._msg_id, obj)
if self._should_use_local_fast_path():
self._local_channel.put_local(self._msg_id, obj)

if self._force_tcp or any(
ch.pid != self.pid or not ch.shm_ok for ch in self._channels.values()
):
if any(not self._can_deliver_locally(ch) for ch in self._channels.values()):
with MessageMarshal.serialize(self._msg_id, obj) as (
total_size,
header,
buffers,
):
total_size_bytes = uint64_to_bytes(total_size)

if not self._force_tcp and any(
ch.pid != self.pid and ch.shm_ok for ch in self._channels.values()
):
if any(self._can_deliver_via_shm(ch) for ch in self._channels.values()):
if self._shm.buf_size < total_size:
new_shm = await GraphService(self._graph_address).create_shm(
self._num_buffers, total_size * 2
Expand All @@ -475,14 +523,10 @@ async def broadcast(self, obj: Any) -> None:
for channel in self._channels.values():
msg: bytes = b""

if self.pid == channel.pid and channel.shm_ok:
if self._can_deliver_locally(channel):
continue # Local transmission handled by channel.put

elif (
(not self._force_tcp)
and self.pid != channel.pid
and channel.shm_ok
):
elif self._can_deliver_via_shm(channel):
msg = (
Command.TX_SHM.value
+ msg_id_bytes
Expand All @@ -509,3 +553,16 @@ async def broadcast(self, obj: Any) -> None:
)

self._msg_id += 1

def _should_use_local_fast_path(self) -> bool:
return any(self._can_deliver_locally(ch) for ch in self._channels.values())

def _can_deliver_locally(self, channel: PubChannelInfo) -> bool:
return self._allow_local and self.pid == channel.pid and channel.shm_ok

def _can_deliver_via_shm(self, channel: PubChannelInfo) -> bool:
return (
(not self._force_tcp)
and channel.shm_ok
and not self._can_deliver_locally(channel)
)
20 changes: 15 additions & 5 deletions src/ezmsg/core/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,20 @@ class OutputStream(Stream):
:type num_buffers: int
:param buf_size: Size of each message buffer in bytes
:type buf_size: int
:param force_tcp: Whether to force TCP transport instead of shared memory
:type force_tcp: bool
:param force_tcp: Whether to force TCP transport instead of shared memory.
If None, inherit the process default from ``EZMSG_FORCE_TCP``.
:type force_tcp: bool | None
:param allow_local: Whether to allow the in-process fast path when available.
If None, inherit the process default from ``EZMSG_ALLOW_LOCAL``.
:type allow_local: bool | None
"""

host: str | None
port: int | None
num_buffers: int
buf_size: int
force_tcp: bool
force_tcp: bool | None
allow_local: bool | None

def __init__(
self,
Expand All @@ -141,15 +146,20 @@ def __init__(
port: int | None = None,
num_buffers: int = 32,
buf_size: int = DEFAULT_SHM_SIZE,
force_tcp: bool = False,
force_tcp: bool | None = None,
allow_local: bool | None = None,
) -> None:
super().__init__(msg_type)
self.host = host
self.port = port
self.num_buffers = num_buffers
self.buf_size = buf_size
self.force_tcp = force_tcp
self.allow_local = allow_local

def __repr__(self) -> str:
preamble = f"Output{super().__repr__()}"
return f"{preamble}({self.num_buffers=}, {self.force_tcp=})"
return (
f"{preamble}({self.num_buffers=}, {self.force_tcp=}, "
f"{self.allow_local=})"
)
Loading
Loading