From 5420577780a5fc1d5a227e4e9a54d3d50710a259 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Fri, 27 Feb 2026 13:56:08 -0500 Subject: [PATCH 01/52] phase 1 refactor --- src/ezmsg/core/backend.py | 142 +++++++++++++++++- src/ezmsg/core/graphcontext.py | 178 ++++++++++++++++++++-- src/ezmsg/core/graphserver.py | 263 +++++++++++++++++++++++++++++++-- src/ezmsg/core/netprotocol.py | 8 + 4 files changed, 567 insertions(+), 24 deletions(-) diff --git a/src/ezmsg/core/backend.py b/src/ezmsg/core/backend.py index 7a0aba1d..9e42005b 100644 --- a/src/ezmsg/core/backend.py +++ b/src/ezmsg/core/backend.py @@ -1,7 +1,9 @@ import asyncio +from dataclasses import fields, is_dataclass from collections.abc import Callable, Mapping, Iterable from collections.abc import Collection as AbstractCollection import enum +import inspect import logging import os import signal @@ -16,8 +18,9 @@ from .collection import Collection, NetworkDefinition from .component import Component -from .stream import Stream -from .unit import Unit, PROCESS_ATTR +from .stream import Stream, InputStream, OutputStream +from .unit import Unit, PROCESS_ATTR, SUBSCRIBES_ATTR, PUBLISHES_ATTR +from .settings import Settings from .graphserver import GraphService from .graphcontext import GraphContext @@ -255,6 +258,132 @@ def processes(self) -> list[BackendProcess]: def running(self) -> bool: return self._started + def _serialize_metadata_value(self, value): + if value is None or isinstance(value, (str, int, float, bool)): + return value + + if is_dataclass(value): + return { + field.name: self._serialize_metadata_value(getattr(value, field.name)) + for field in fields(value) + } + + if isinstance(value, dict): + return { + str(key): self._serialize_metadata_value(item) + for key, item in value.items() + } + + if isinstance(value, (list, tuple, set)): + return [self._serialize_metadata_value(item) for item in value] + + return repr(value) + + def _type_name(self, tp: type) -> str: + return f"{tp.__module__}.{tp.__qualname__}" + + def _component_metadata(self) -> dict[str, object]: + components: dict[str, dict[str, object]] = {} + + def crawl(component: Component) -> list[Component]: + search: list[Component] = [component] + out: list[Component] = [] + while search: + comp = search.pop() + out.append(comp) + search.extend(comp.components.values()) + return out + + for root in self._components.values(): + for comp in crawl(root): + dynamic_settings = { + "enabled": False, + "input_topic": None, + "settings_type": None, + } + input_settings = comp.streams.get("INPUT_SETTINGS") + if ( + isinstance(input_settings, InputStream) + and inspect.isclass(input_settings.msg_type) + and issubclass(input_settings.msg_type, Settings) + ): + dynamic_settings = { + "enabled": True, + "input_topic": input_settings.address, + "settings_type": self._type_name(input_settings.msg_type), + } + + stream_entries: dict[str, dict[str, object]] = {} + for stream_name, stream in comp.streams.items(): + entry: dict[str, object] = { + "name": stream_name, + "address": stream.address, + "msg_type": self._type_name(stream.msg_type) + if inspect.isclass(stream.msg_type) + else repr(stream.msg_type), + } + if isinstance(stream, InputStream): + entry["kind"] = "input" + entry["leaky"] = stream.leaky + entry["max_queue"] = stream.max_queue + elif isinstance(stream, OutputStream): + entry["kind"] = "output" + entry["host"] = stream.host + entry["port"] = stream.port + entry["num_buffers"] = stream.num_buffers + entry["buf_size"] = stream.buf_size + entry["force_tcp"] = stream.force_tcp + else: + entry["kind"] = "stream" + stream_entries[stream_name] = entry + + task_entries: list[dict[str, object]] = [] + for task_name, task in comp.tasks.items(): + task_entry: dict[str, object] = {"name": task_name} + if hasattr(task, SUBSCRIBES_ATTR): + sub_stream = getattr(task, SUBSCRIBES_ATTR) + if hasattr(sub_stream, "name") and sub_stream.name in comp.streams: + task_entry["subscribes"] = comp.streams[sub_stream.name].address + if hasattr(task, PUBLISHES_ATTR): + pub_streams = getattr(task, PUBLISHES_ATTR) + task_entry["publishes"] = [ + comp.streams[stream.name].address + for stream in pub_streams + if hasattr(stream, "name") and stream.name in comp.streams + ] + task_entries.append(task_entry) + + settings_type = getattr(comp.__class__, "__settings_type__", Settings) + metadata_entry: dict[str, object] = { + "address": comp.address, + "name": comp.name, + "component_type": self._type_name(comp.__class__), + "kind": ( + "collection" + if isinstance(comp, Collection) + else "unit" + if isinstance(comp, Unit) + else "component" + ), + "settings_type": self._type_name(settings_type), + "startup_settings": self._serialize_metadata_value(comp.SETTINGS), + "children": sorted(child.address for child in comp.components.values()), + "streams": stream_entries, + "tasks": sorted(task_entries, key=lambda task: str(task["name"])), + "main": comp.main.__name__ if comp.main is not None else None, + "threads": sorted(comp.threads.keys()), + "dynamic_settings": dynamic_settings, + } + components[comp.address] = metadata_entry + + return { + "schema_version": 1, + "root_name": self._root_name, + "components": { + address: components[address] for address in sorted(components.keys()) + }, + } + def start(self) -> None: if self._started: raise RuntimeError("GraphRunner is already running") @@ -360,6 +489,15 @@ async def setup_graph() -> None: asyncio.run_coroutine_threadsafe(setup_graph(), self._loop).result() + metadata = self._component_metadata() + + async def register_graph_metadata() -> None: + await graph_context.register_metadata(metadata) + + asyncio.run_coroutine_threadsafe( + register_graph_metadata(), self._loop + ).result() + if len(self._execution_context.processes) > 1: logger.info( f"Running in {len(self._execution_context.processes)} processes." diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index 1d62c2b2..52474b0d 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -1,14 +1,25 @@ import asyncio import logging +import pickle import typing -from .netprotocol import AddressType +from uuid import UUID +from types import TracebackType + +from .dag import CyclicException +from .netprotocol import ( + AddressType, + Command, + close_stream_writer, + encode_str, + read_int, + read_str, + uint64_to_bytes, +) from .graphserver import GraphServer, GraphService from .pubclient import Publisher from .subclient import Subscriber -from types import TracebackType - logger = logging.getLogger("ezmsg") @@ -40,6 +51,10 @@ class GraphContext: _graph_address: AddressType | None _graph_server: GraphServer | None + _session_id: UUID | None + _session_reader: asyncio.StreamReader | None + _session_writer: asyncio.StreamWriter | None + _session_lock: asyncio.Lock def __init__( self, @@ -51,6 +66,10 @@ def __init__( self._graph_address = graph_address self._graph_server = None self._auto_start = auto_start + self._session_id = None + self._session_reader = None + self._session_writer = None + self._session_lock = asyncio.Lock() @property def graph_address(self) -> AddressType | None: @@ -98,8 +117,18 @@ async def connect(self, from_topic: str, to_topic: str) -> None: :param to_topic: The destination topic name :type to_topic: str """ - - await GraphService(self.graph_address).connect(from_topic, to_topic) + if self._session_writer is not None: + response = await self._session_command( + Command.SESSION_CONNECT, + from_topic, + to_topic, + ) + if response == Command.CYCLIC.value: + raise CyclicException + if response != Command.COMPLETE.value: + raise RuntimeError("Unexpected response to session connect") + else: + await GraphService(self.graph_address).connect(from_topic, to_topic) self._edges.add((from_topic, to_topic)) async def disconnect(self, from_topic: str, to_topic: str) -> None: @@ -111,7 +140,16 @@ async def disconnect(self, from_topic: str, to_topic: str) -> None: :param to_topic: The destination topic name :type to_topic: str """ - await GraphService(self.graph_address).disconnect(from_topic, to_topic) + if self._session_writer is not None: + response = await self._session_command( + Command.SESSION_DISCONNECT, + from_topic, + to_topic, + ) + if response != Command.COMPLETE.value: + raise RuntimeError("Unexpected response to session disconnect") + else: + await GraphService(self.graph_address).disconnect(from_topic, to_topic) self._edges.discard((from_topic, to_topic)) async def sync(self, timeout: float | None = None) -> None: @@ -140,6 +178,103 @@ async def _ensure_servers(self) -> None: auto_start=self._auto_start ) + async def _open_session(self) -> None: + if self._session_writer is not None: + return + + reader, writer = await GraphService(self.graph_address).open_connection() + writer.write(Command.SESSION.value) + await writer.drain() + + session_id = UUID(await read_str(reader)) + response = await reader.read(1) + if response != Command.COMPLETE.value: + await close_stream_writer(writer) + raise RuntimeError("Failed to create GraphContext session") + + self._session_id = session_id + self._session_reader = reader + self._session_writer = writer + + async def _close_session(self) -> None: + writer = self._session_writer + if writer is None: + return + + try: + await self._session_command(Command.SESSION_CLEAR) + except ( + ConnectionRefusedError, + ConnectionResetError, + BrokenPipeError, + asyncio.IncompleteReadError, + ): + pass + + await close_stream_writer(writer) + self._session_id = None + self._session_reader = None + self._session_writer = None + self._edges.clear() + + async def _session_command( + self, + command: Command, + *args: str, + payload: bytes | None = None, + expect_snapshot: bool = False, + ) -> typing.Any: + reader = self._session_reader + writer = self._session_writer + if reader is None or writer is None: + raise RuntimeError("GraphContext session is not active") + + async with self._session_lock: + writer.write(command.value) + for arg in args: + writer.write(encode_str(arg)) + if payload is not None: + writer.write(uint64_to_bytes(len(payload))) + writer.write(payload) + await writer.drain() + + if expect_snapshot: + num_bytes = await read_int(reader) + snapshot_bytes = await reader.readexactly(num_bytes) + response = await reader.read(1) + if response != Command.COMPLETE.value: + raise RuntimeError("Unexpected response to session snapshot") + return pickle.loads(snapshot_bytes) + + return await reader.read(1) + + async def register_metadata(self, metadata: dict[str, typing.Any]) -> None: + if self._session_writer is None: + logger.warning("No active GraphContext session; metadata registration skipped") + return + + response = await self._session_command( + Command.SESSION_REGISTER, payload=pickle.dumps(metadata) + ) + if response != Command.COMPLETE.value: + raise RuntimeError("Unexpected response to session metadata registration") + + async def snapshot(self) -> dict[str, typing.Any]: + if self._session_writer is None: + dag = await GraphService(self.graph_address).dag() + return { + "graph": {node: sorted(conns) for node, conns in dag.graph.items()}, + "edge_owners": [], + "sessions": {}, + } + + snapshot = await self._session_command( + Command.SESSION_SNAPSHOT, expect_snapshot=True + ) + if not isinstance(snapshot, dict): + raise RuntimeError("Session snapshot payload was not a dictionary") + return snapshot + async def _shutdown_servers(self) -> None: if self._graph_server is not None: self._graph_server.stop() @@ -147,6 +282,7 @@ async def _shutdown_servers(self) -> None: async def __aenter__(self) -> "GraphContext": await self._ensure_servers() + await self._open_session() return self async def __aexit__( @@ -156,6 +292,7 @@ async def __aexit__( exc_tb: TracebackType | None, ) -> bool: await self.revert() + await self._close_session() await self._shutdown_servers() return False @@ -174,8 +311,29 @@ async def revert(self) -> None: for future in asyncio.as_completed(wait): await future - for edge in self._edges: + self._clients.clear() + + if self._session_writer is not None: try: - await GraphService(self.graph_address).disconnect(*edge) - except (ConnectionRefusedError, BrokenPipeError, ConnectionResetError) as e: - logger.warn(f"Could not remove edge {edge} from GraphServer: {e}") + response = await self._session_command(Command.SESSION_CLEAR) + if response != Command.COMPLETE.value: + logger.warning("GraphServer returned unexpected response to SESSION_CLEAR") + except ( + ConnectionRefusedError, + BrokenPipeError, + ConnectionResetError, + asyncio.IncompleteReadError, + ) as e: + logger.warning(f"Could not clear GraphContext session state: {e}") + else: + for edge in self._edges: + try: + await GraphService(self.graph_address).disconnect(*edge) + except ( + ConnectionRefusedError, + BrokenPipeError, + ConnectionResetError, + ) as e: + logger.warning(f"Could not remove edge {edge} from GraphServer: {e}") + + self._edges.clear() diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index f4abf8fc..7a0769db 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -4,8 +4,11 @@ import os import socket import threading +import warnings +from dataclasses import dataclass, field from contextlib import suppress from uuid import UUID, uuid1 +from typing import Any from . import __version__ @@ -36,6 +39,16 @@ logger = logging.getLogger("ezmsg") +LEGACY_OWNER = "legacy" + + +@dataclass +class SessionInfo: + id: UUID + writer: asyncio.StreamWriter + edges: set[tuple[str, str]] = field(default_factory=set) + metadata: dict[str, Any] = field(default_factory=dict) + class GraphServer(threading.Thread): """ @@ -61,6 +74,8 @@ class GraphServer(threading.Thread): graph: DAG clients: dict[UUID, ClientInfo] + sessions: dict[UUID, SessionInfo] + edge_owners: dict[tuple[str, str], set[UUID | str]] shms: dict[str, SHMInfo] _client_tasks: dict[UUID, "asyncio.Task[None]"] @@ -77,6 +92,8 @@ def __init__(self, **kwargs) -> None: # graph/server data self.graph = DAG() self.clients = {} + self.sessions = {} + self.edge_owners = {} self._client_tasks = {} self.shms = {} self._address = None @@ -261,6 +278,20 @@ async def api( # to avoid closing writer return + elif req == Command.SESSION.value: + session_id = uuid1() + self.sessions[session_id] = SessionInfo(session_id, writer) + writer.write(encode_str(str(session_id))) + writer.write(Command.COMPLETE.value) + await writer.drain() + self._client_tasks[session_id] = asyncio.create_task( + self._handle_session(session_id, reader, writer) + ) + + # NOTE: Created a session client, must return early + # to avoid closing writer + return + else: # We only want to handle one command at a time async with self._command_lock: @@ -302,23 +333,25 @@ async def api( elif req in [Command.CONNECT.value, Command.DISCONNECT.value]: from_topic = await read_str(reader) to_topic = await read_str(reader) - - cmd = self.graph.add_edge - if req == Command.DISCONNECT.value: - cmd = self.graph.remove_edge + topology_changed = False try: - cmd(from_topic, to_topic) - for sub in self._downstream_subs(to_topic): - await self._notify_subscriber(sub) + if req == Command.CONNECT.value: + topology_changed = self._connect_owner( + from_topic, to_topic, LEGACY_OWNER + ) + else: + topology_changed = self._disconnect_owner( + from_topic, to_topic, LEGACY_OWNER + ) writer.write(Command.COMPLETE.value) except CyclicException: writer.write(Command.CYCLIC.value) - await writer.drain() + if topology_changed: + await self._notify_downstream_for_topic(to_topic) - if req == Command.DISCONNECT.value: - await close_stream_writer(writer) + await writer.drain() elif req == Command.SYNC.value: for pub in self._publishers(): @@ -393,11 +426,205 @@ async def _handle_client( finally: # Ensure any waiter on this client unblocks - # with suppress(Exception): - self.clients[client_id].set_sync() + info = self.clients.get(client_id) + if info is not None: + info.set_sync() self.clients.pop(client_id, None) await close_stream_writer(writer) + async def _handle_session( + self, + session_id: UUID, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> None: + logger.debug(f"Graph Server: Session connected: {session_id}") + + try: + while True: + req = await reader.read(1) + + if not req: + break + + if req in [Command.SESSION_CONNECT.value, Command.SESSION_DISCONNECT.value]: + from_topic = await read_str(reader) + to_topic = await read_str(reader) + + async with self._command_lock: + topology_changed = False + try: + if req == Command.SESSION_CONNECT.value: + topology_changed = self._connect_owner( + from_topic, to_topic, session_id + ) + else: + topology_changed = self._disconnect_owner( + from_topic, to_topic, session_id + ) + writer.write(Command.COMPLETE.value) + except CyclicException: + writer.write(Command.CYCLIC.value) + + if topology_changed: + await self._notify_downstream_for_topic(to_topic) + + await writer.drain() + + elif req == Command.SESSION_CLEAR.value: + async with self._command_lock: + notify_topics = self._clear_session_state(session_id) + writer.write(Command.COMPLETE.value) + for topic in notify_topics: + await self._notify_downstream_for_topic(topic) + + await writer.drain() + + elif req == Command.SESSION_REGISTER.value: + num_bytes = await read_int(reader) + payload = await reader.readexactly(num_bytes) + metadata = pickle.loads(payload) + + async with self._command_lock: + session = self.sessions.get(session_id) + if session is not None: + session.metadata = metadata + writer.write(Command.COMPLETE.value) + + await writer.drain() + + elif req == Command.SESSION_SNAPSHOT.value: + async with self._command_lock: + snapshot = self._snapshot() + snapshot_bytes = pickle.dumps(snapshot) + writer.write(uint64_to_bytes(len(snapshot_bytes))) + writer.write(snapshot_bytes) + writer.write(Command.COMPLETE.value) + + await writer.drain() + + else: + logger.warning( + f"Session {session_id} rx unknown command from GraphServer: {req}" + ) + + except (ConnectionResetError, BrokenPipeError) as e: + logger.debug(f"Session {session_id} disconnected from GraphServer: {e}") + + finally: + async with self._command_lock: + notify_topics = self._drop_session(session_id) + + for topic in notify_topics: + await self._notify_downstream_for_topic(topic) + + self._client_tasks.pop(session_id, None) + await close_stream_writer(writer) + + def _connect_owner( + self, from_topic: str, to_topic: str, owner: UUID | str + ) -> bool: + edge = (from_topic, to_topic) + owners = self.edge_owners.setdefault(edge, set()) + + if owner in owners: + return False + + topology_changed = len(owners) == 0 + if topology_changed: + try: + self.graph.add_edge(from_topic, to_topic) + except CyclicException: + if len(owners) == 0: + self.edge_owners.pop(edge, None) + raise + + owners.add(owner) + + if isinstance(owner, UUID): + session = self.sessions.get(owner) + if session is not None: + session.edges.add(edge) + + return topology_changed + + def _disconnect_owner( + self, from_topic: str, to_topic: str, owner: UUID | str + ) -> bool: + edge = (from_topic, to_topic) + owners = self.edge_owners.get(edge) + if owners is None or owner not in owners: + return False + + owners.remove(owner) + + if isinstance(owner, UUID): + session = self.sessions.get(owner) + if session is not None: + session.edges.discard(edge) + + if len(owners) == 0: + self.edge_owners.pop(edge, None) + self.graph.remove_edge(from_topic, to_topic) + return True + + return False + + def _clear_session_state(self, session_id: UUID) -> set[str]: + notify_topics: set[str] = set() + session = self.sessions.get(session_id) + if session is None: + return notify_topics + + for from_topic, to_topic in list(session.edges): + if self._disconnect_owner(from_topic, to_topic, session_id): + notify_topics.add(to_topic) + + session.metadata.clear() + return notify_topics + + def _drop_session(self, session_id: UUID) -> set[str]: + notify_topics: set[str] = set() + session = self.sessions.pop(session_id, None) + if session is None: + return notify_topics + + for from_topic, to_topic in list(session.edges): + if self._disconnect_owner(from_topic, to_topic, session_id): + notify_topics.add(to_topic) + + session.metadata.clear() + return notify_topics + + def _snapshot(self) -> dict[str, Any]: + graph = {node: sorted(conns) for node, conns in self.graph.graph.items()} + edge_owners = [ + { + "from_topic": from_topic, + "to_topic": to_topic, + "owners": [str(owner) for owner in sorted(owners, key=str)], + } + for (from_topic, to_topic), owners in sorted(self.edge_owners.items()) + ] + sessions = { + str(session_id): { + "edges": sorted( + [ + {"from_topic": from_topic, "to_topic": to_topic} + for from_topic, to_topic in session.edges + ], + key=lambda edge: (edge["from_topic"], edge["to_topic"]), + ), + "metadata": session.metadata, + } + for session_id, session in sorted(self.sessions.items(), key=lambda item: str(item[0])) + } + return {"graph": graph, "edge_owners": edge_owners, "sessions": sessions} + + async def _notify_downstream_for_topic(self, topic: str) -> None: + for sub in self._downstream_subs(topic): + await self._notify_subscriber(sub) + async def _notify_subscriber(self, sub: SubscriberInfo) -> None: try: pub_ids = [str(pub.id) for pub in self._upstream_pubs(sub.topic)] @@ -497,6 +724,12 @@ async def open_connection( return reader, writer async def connect(self, from_topic: str, to_topic: str) -> None: + warnings.warn( + "GraphService.connect is deprecated. Prefer GraphContext.connect " + "to use a session-scoped control plane.", + DeprecationWarning, + stacklevel=2, + ) reader, writer = await self.open_connection() writer.write(Command.CONNECT.value) writer.write(encode_str(from_topic)) @@ -509,6 +742,12 @@ async def connect(self, from_topic: str, to_topic: str) -> None: await close_stream_writer(writer) async def disconnect(self, from_topic: str, to_topic: str) -> None: + warnings.warn( + "GraphService.disconnect is deprecated. Prefer GraphContext.disconnect " + "to use a session-scoped control plane.", + DeprecationWarning, + stacklevel=2, + ) reader, writer = await self.open_connection() writer.write(Command.DISCONNECT.value) writer.write(encode_str(from_topic)) diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index 04ab0839..d640d032 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -300,6 +300,14 @@ def _generate_next_value_(name, start, count, last_values) -> bytes: SHM_OK = enum.auto() SHM_ATTACH_FAILED = enum.auto() + # GraphContext Session Commands (control plane) + SESSION = enum.auto() + SESSION_CONNECT = enum.auto() + SESSION_DISCONNECT = enum.auto() + SESSION_CLEAR = enum.auto() + SESSION_REGISTER = enum.auto() + SESSION_SNAPSHOT = enum.auto() + def create_socket( host: str | None = None, From 13226e1239bffdb3241d63ab829c69abb13db402 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 2 Mar 2026 09:14:04 -0500 Subject: [PATCH 02/52] bugfix --- src/ezmsg/core/graphserver.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 7a0769db..4a6ebcee 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -462,22 +462,24 @@ async def _handle_session( topology_changed = self._disconnect_owner( from_topic, to_topic, session_id ) - writer.write(Command.COMPLETE.value) except CyclicException: writer.write(Command.CYCLIC.value) + await writer.drain() + continue if topology_changed: await self._notify_downstream_for_topic(to_topic) + writer.write(Command.COMPLETE.value) await writer.drain() elif req == Command.SESSION_CLEAR.value: async with self._command_lock: notify_topics = self._clear_session_state(session_id) - writer.write(Command.COMPLETE.value) for topic in notify_topics: await self._notify_downstream_for_topic(topic) + writer.write(Command.COMPLETE.value) await writer.drain() elif req == Command.SESSION_REGISTER.value: From d2853403ef4e2bf0f1904b47d6bd383c6c8d8bc5 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 2 Mar 2026 14:47:00 -0500 Subject: [PATCH 03/52] phase-1 impl. --- src/ezmsg/core/backend.py | 241 +++++++++++++++------------- src/ezmsg/core/dag.py | 81 ++++++++-- src/ezmsg/core/graphcontext.py | 243 ++++++++++++++++++----------- src/ezmsg/core/graphmeta.py | 101 ++++++++++++ src/ezmsg/core/graphserver.py | 277 ++++++++++++++++----------------- src/ezmsg/core/netprotocol.py | 11 ++ tests/test_graph_session.py | 106 +++++++++++++ 7 files changed, 714 insertions(+), 346 deletions(-) create mode 100644 src/ezmsg/core/graphmeta.py create mode 100644 tests/test_graph_session.py diff --git a/src/ezmsg/core/backend.py b/src/ezmsg/core/backend.py index 9e42005b..3787b5be 100644 --- a/src/ezmsg/core/backend.py +++ b/src/ezmsg/core/backend.py @@ -1,11 +1,12 @@ import asyncio -from dataclasses import fields, is_dataclass from collections.abc import Callable, Mapping, Iterable from collections.abc import Collection as AbstractCollection +from dataclasses import asdict, is_dataclass import enum import inspect import logging import os +import pickle import signal from threading import BrokenBarrierError from multiprocessing import Event, Barrier @@ -21,6 +22,19 @@ from .stream import Stream, InputStream, OutputStream from .unit import Unit, PROCESS_ATTR, SUBSCRIBES_ATTR, PUBLISHES_ATTR from .settings import Settings +from .graphmeta import ( + CollectionMetadata, + ComponentMetadata, + ComponentMetadataType, + DynamicSettingsMetadata, + InputStreamMetadata, + OutputStreamMetadata, + StreamMetadataType, + StreamMetadata, + TaskMetadata, + GraphMetadata, + UnitMetadata, +) from .graphserver import GraphService from .graphcontext import GraphContext @@ -36,6 +50,21 @@ logger = logging.getLogger("ezmsg") +def crawl_components( + component: Component, + callback: Callable[[Component], None] | None = None, +) -> list[Component]: + search: list[Component] = [component] + out: list[Component] = [] + while len(search): + comp = search.pop() + out.append(comp) + search += list(comp.components.values()) + if callback is not None: + callback(comp) + return out + + class ExecutionContext: _process_units: list[list[Unit]] _processes: list[BackendProcess] | None @@ -115,15 +144,6 @@ def setup( from_topic = from_topic.name graph_connections.append((from_topic, to_topic)) - def crawl_components( - component: Component, callback: Callable[[Component], None] - ) -> None: - search: list[Component] = [component] - while len(search): - comp = search.pop() - search += list(comp.components.values()) - callback(comp) - def gather_edges(comp: Component): if isinstance(comp, Collection): for from_stream, to_stream in comp.network(): @@ -258,131 +278,142 @@ def processes(self) -> list[BackendProcess]: def running(self) -> bool: return self._started - def _serialize_metadata_value(self, value): - if value is None or isinstance(value, (str, int, float, bool)): - return value - - if is_dataclass(value): - return { - field.name: self._serialize_metadata_value(getattr(value, field.name)) - for field in fields(value) - } + def _type_name(self, tp: type) -> str: + return f"{tp.__module__}.{tp.__qualname__}" - if isinstance(value, dict): - return { - str(key): self._serialize_metadata_value(item) - for key, item in value.items() - } + def _stream_type_name(self, stream_type: object) -> str: + if inspect.isclass(stream_type): + return self._type_name(stream_type) + return repr(stream_type) - if isinstance(value, (list, tuple, set)): - return [self._serialize_metadata_value(item) for item in value] + def _settings_repr(self, value: object) -> dict[str, object] | str: + if is_dataclass(value): + try: + asdict_value = asdict(value) + if isinstance(asdict_value, dict): + return asdict_value + except Exception: + pass return repr(value) - def _type_name(self, tp: type) -> str: - return f"{tp.__module__}.{tp.__qualname__}" - - def _component_metadata(self) -> dict[str, object]: - components: dict[str, dict[str, object]] = {} + def _settings_snapshot(self, value: object) -> tuple[bytes | None, dict[str, object] | str]: + try: + pickled = pickle.dumps(value) + except Exception as exc: + logger.warning(f"Could not pickle settings for metadata: {exc}") + pickled = None + return pickled, self._settings_repr(value) - def crawl(component: Component) -> list[Component]: - search: list[Component] = [component] - out: list[Component] = [] - while search: - comp = search.pop() - out.append(comp) - search.extend(comp.components.values()) - return out + def _component_metadata(self) -> GraphMetadata: + components: dict[str, ComponentMetadataType] = {} for root in self._components.values(): - for comp in crawl(root): - dynamic_settings = { - "enabled": False, - "input_topic": None, - "settings_type": None, - } + for comp in crawl_components(root): input_settings = comp.streams.get("INPUT_SETTINGS") - if ( - isinstance(input_settings, InputStream) - and inspect.isclass(input_settings.msg_type) - and issubclass(input_settings.msg_type, Settings) - ): - dynamic_settings = { - "enabled": True, - "input_topic": input_settings.address, - "settings_type": self._type_name(input_settings.msg_type), - } - - stream_entries: dict[str, dict[str, object]] = {} + dynamic_settings = DynamicSettingsMetadata( + enabled=isinstance(input_settings, InputStream), + input_topic=( + input_settings.address + if isinstance(input_settings, InputStream) + else None + ), + settings_type=( + self._stream_type_name(input_settings.msg_type) + if isinstance(input_settings, InputStream) + else None + ), + ) + + stream_entries: dict[str, StreamMetadataType] = {} for stream_name, stream in comp.streams.items(): - entry: dict[str, object] = { - "name": stream_name, - "address": stream.address, - "msg_type": self._type_name(stream.msg_type) - if inspect.isclass(stream.msg_type) - else repr(stream.msg_type), - } if isinstance(stream, InputStream): - entry["kind"] = "input" - entry["leaky"] = stream.leaky - entry["max_queue"] = stream.max_queue + entry = InputStreamMetadata( + name=stream_name, + address=stream.address, + msg_type=self._stream_type_name(stream.msg_type), + leaky=stream.leaky, + max_queue=stream.max_queue, + ) elif isinstance(stream, OutputStream): - entry["kind"] = "output" - entry["host"] = stream.host - entry["port"] = stream.port - entry["num_buffers"] = stream.num_buffers - entry["buf_size"] = stream.buf_size - entry["force_tcp"] = stream.force_tcp + entry = OutputStreamMetadata( + name=stream_name, + address=stream.address, + msg_type=self._stream_type_name(stream.msg_type), + host=stream.host, + port=stream.port, + num_buffers=stream.num_buffers, + buf_size=stream.buf_size, + force_tcp=stream.force_tcp, + ) else: - entry["kind"] = "stream" + entry = StreamMetadata( + name=stream_name, + address=stream.address, + msg_type=self._stream_type_name(stream.msg_type), + ) stream_entries[stream_name] = entry - task_entries: list[dict[str, object]] = [] + task_entries: list[TaskMetadata] = [] for task_name, task in comp.tasks.items(): - task_entry: dict[str, object] = {"name": task_name} + task_entry = TaskMetadata(name=task_name) + if hasattr(task, SUBSCRIBES_ATTR): sub_stream = getattr(task, SUBSCRIBES_ATTR) if hasattr(sub_stream, "name") and sub_stream.name in comp.streams: - task_entry["subscribes"] = comp.streams[sub_stream.name].address + task_entry.subscribes = comp.streams[sub_stream.name].address + if hasattr(task, PUBLISHES_ATTR): pub_streams = getattr(task, PUBLISHES_ATTR) - task_entry["publishes"] = [ + task_entry.publishes = [ comp.streams[stream.name].address for stream in pub_streams if hasattr(stream, "name") and stream.name in comp.streams ] + task_entries.append(task_entry) settings_type = getattr(comp.__class__, "__settings_type__", Settings) - metadata_entry: dict[str, object] = { - "address": comp.address, - "name": comp.name, - "component_type": self._type_name(comp.__class__), - "kind": ( - "collection" - if isinstance(comp, Collection) - else "unit" - if isinstance(comp, Unit) - else "component" - ), - "settings_type": self._type_name(settings_type), - "startup_settings": self._serialize_metadata_value(comp.SETTINGS), - "children": sorted(child.address for child in comp.components.values()), - "streams": stream_entries, - "tasks": sorted(task_entries, key=lambda task: str(task["name"])), - "main": comp.main.__name__ if comp.main is not None else None, - "threads": sorted(comp.threads.keys()), - "dynamic_settings": dynamic_settings, - } + settings_type_name = ( + self._type_name(settings_type) + if inspect.isclass(settings_type) + else repr(settings_type) + ) + + component_common = dict( + address=comp.address, + name=comp.name, + component_type=self._type_name(comp.__class__), + settings_type=settings_type_name, + initial_settings=self._settings_snapshot(comp.SETTINGS), + streams=stream_entries, + dynamic_settings=dynamic_settings, + ) + + metadata_entry: ComponentMetadataType + if isinstance(comp, Collection): + metadata_entry = CollectionMetadata( + **component_common, + children=sorted( + child.address for child in comp.components.values() + ), + ) + elif isinstance(comp, Unit): + metadata_entry = UnitMetadata( + **component_common, + tasks=sorted(task_entries, key=lambda task: task.name), + main=comp.main.__name__ if comp.main is not None else None, + threads=sorted(comp.threads.keys()), + ) + else: + metadata_entry = ComponentMetadata(**component_common) components[comp.address] = metadata_entry - return { - "schema_version": 1, - "root_name": self._root_name, - "components": { - address: components[address] for address in sorted(components.keys()) - }, - } + return GraphMetadata( + schema_version=1, + root_name=self._root_name, + components={address: components[address] for address in sorted(components)}, + ) def start(self) -> None: if self._started: diff --git a/src/ezmsg/core/dag.py b/src/ezmsg/core/dag.py index d0c9a72a..1ae5290e 100644 --- a/src/ezmsg/core/dag.py +++ b/src/ezmsg/core/dag.py @@ -1,6 +1,7 @@ from collections import defaultdict from copy import deepcopy from dataclasses import dataclass, field +from collections.abc import Hashable class CyclicException(Exception): @@ -15,6 +16,8 @@ class CyclicException(Exception): GraphType = defaultdict[str, set[str]] +EdgeType = tuple[str, str] +OwnerType = Hashable | None @dataclass @@ -28,6 +31,9 @@ class DAG: """ graph: GraphType = field(default_factory=lambda: defaultdict(set), init=False) + edge_owners: dict[EdgeType, set[OwnerType]] = field( + default_factory=dict, init=False + ) @property def nodes(self) -> set[str]: @@ -60,47 +66,94 @@ def invgraph(self) -> GraphType: invgraph[to_node].add(from_node) return invgraph - def add_edge(self, from_node: str, to_node: str) -> None: + def add_edge( + self, from_node: str, to_node: str, owner: OwnerType = None + ) -> bool: """ Ensure an edge exists in the graph. - Adds an edge from from_node to to_node. Does nothing if the edge already exists. + Adds an edge from from_node to to_node for the given owner. + If this is an additional owner for an existing edge, topology does not change. If the edge would make the graph cyclic, raises CyclicException. :param from_node: Source node name :type from_node: str :param to_node: Destination node name :type to_node: str + :param owner: Owner token for this edge; ``None`` is treated as persistent. + :type owner: collections.abc.Hashable | None :raises CyclicException: If adding the edge would create a cycle + :return: True if graph topology changed; False if this only added an owner. + :rtype: bool """ if from_node == to_node: raise CyclicException - test_graph = deepcopy(self.graph) - test_graph[from_node].add(to_node) - test_graph[to_node] + edge = (from_node, to_node) + owners = self.edge_owners.setdefault(edge, set()) + if owner in owners: + return False - if from_node in _bfs(test_graph, from_node): - raise CyclicException + topology_changed = len(owners) == 0 + if topology_changed: + test_graph = deepcopy(self.graph) + test_graph[from_node].add(to_node) + test_graph[to_node] + + if from_node in _bfs(test_graph, from_node): + if len(owners) == 0: + self.edge_owners.pop(edge, None) + raise CyclicException + + # No cycles! Modify referenced data structure + self.graph[from_node].add(to_node) + self.graph[to_node] - # No cycles! Modify referenced data structure - self.graph[from_node].add(to_node) - self.graph[to_node] + owners.add(owner) + return topology_changed - def remove_edge(self, from_node: str, to_node: str) -> None: + def remove_edge( + self, from_node: str, to_node: str, owner: OwnerType = None + ) -> bool: """ Ensure an edge is not present in the graph. - Removes an edge from from_node to to_node. Does nothing if the edge doesn't exist. + Removes ownership of an edge from from_node to to_node. + Topology only changes when the last owner is removed. Automatically prunes unconnected nodes after removal. :param from_node: Source node name :type from_node: str :param to_node: Destination node name :type to_node: str + :param owner: Owner token for this edge; ``None`` targets persistent ownership. + :type owner: collections.abc.Hashable | None + :return: True if graph topology changed; False if owner was absent or still shared. + :rtype: bool """ - self.graph.get(from_node, set()).discard(to_node) - self._prune() + edge = (from_node, to_node) + owners = self.edge_owners.get(edge, None) + if owners is None or owner not in owners: + return False + + owners.remove(owner) + + topology_changed = False + if len(owners) == 0: + self.edge_owners.pop(edge, None) + self.graph.get(from_node, set()).discard(to_node) + self._prune() + topology_changed = True + + return topology_changed + + def remove_owner(self, owner: OwnerType) -> set[EdgeType]: + removed_edges: set[EdgeType] = set() + for edge in list(self.edge_owners.keys()): + if owner in self.edge_owners.get(edge, set()): + if self.remove_edge(*edge, owner=owner): + removed_edges.add(edge) + return removed_edges def downstream(self, from_node: str) -> list[str]: """ diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index 52474b0d..3beddfdf 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -1,10 +1,13 @@ import asyncio import logging -import pickle import typing +import enum +import pickle from uuid import UUID from types import TracebackType +from dataclasses import dataclass +from contextlib import suppress from .dag import CyclicException from .netprotocol import ( @@ -19,10 +22,25 @@ from .graphserver import GraphServer, GraphService from .pubclient import Publisher from .subclient import Subscriber +from .graphmeta import GraphMetadata, GraphSnapshot logger = logging.getLogger("ezmsg") +class _SessionResponseKind(enum.Enum): + BYTE = enum.auto() + SNAPSHOT = enum.auto() + + +@dataclass +class _SessionCommand: + command: Command + args: tuple[str, ...] + payload: bytes | None + response_kind: _SessionResponseKind + response_fut: "asyncio.Future[typing.Any]" + + class GraphContext: """ GraphContext maintains a list of created publishers, subscribers, and connections in the graph. @@ -54,7 +72,8 @@ class GraphContext: _session_id: UUID | None _session_reader: asyncio.StreamReader | None _session_writer: asyncio.StreamWriter | None - _session_lock: asyncio.Lock + _session_task: asyncio.Task[None] | None + _session_commands: asyncio.Queue[_SessionCommand | None] | None def __init__( self, @@ -69,7 +88,8 @@ def __init__( self._session_id = None self._session_reader = None self._session_writer = None - self._session_lock = asyncio.Lock() + self._session_task = None + self._session_commands = None @property def graph_address(self) -> AddressType | None: @@ -117,18 +137,16 @@ async def connect(self, from_topic: str, to_topic: str) -> None: :param to_topic: The destination topic name :type to_topic: str """ - if self._session_writer is not None: - response = await self._session_command( - Command.SESSION_CONNECT, - from_topic, - to_topic, - ) - if response == Command.CYCLIC.value: - raise CyclicException - if response != Command.COMPLETE.value: - raise RuntimeError("Unexpected response to session connect") - else: - await GraphService(self.graph_address).connect(from_topic, to_topic) + response = await self._session_command( + Command.SESSION_CONNECT, + from_topic, + to_topic, + response_kind=_SessionResponseKind.BYTE, + ) + if response == Command.CYCLIC.value: + raise CyclicException + if response != Command.COMPLETE.value: + raise RuntimeError("Unexpected response to session connect") self._edges.add((from_topic, to_topic)) async def disconnect(self, from_topic: str, to_topic: str) -> None: @@ -140,16 +158,14 @@ async def disconnect(self, from_topic: str, to_topic: str) -> None: :param to_topic: The destination topic name :type to_topic: str """ - if self._session_writer is not None: - response = await self._session_command( - Command.SESSION_DISCONNECT, - from_topic, - to_topic, - ) - if response != Command.COMPLETE.value: - raise RuntimeError("Unexpected response to session disconnect") - else: - await GraphService(self.graph_address).disconnect(from_topic, to_topic) + response = await self._session_command( + Command.SESSION_DISCONNECT, + from_topic, + to_topic, + response_kind=_SessionResponseKind.BYTE, + ) + if response != Command.COMPLETE.value: + raise RuntimeError("Unexpected response to session disconnect") self._edges.discard((from_topic, to_topic)) async def sync(self, timeout: float | None = None) -> None: @@ -195,26 +211,97 @@ async def _open_session(self) -> None: self._session_id = session_id self._session_reader = reader self._session_writer = writer + self._session_commands = asyncio.Queue() + self._session_task = asyncio.create_task( + self._session_io_loop(), + name=f"graphctx-session-{session_id}", + ) + + def _require_session(self) -> tuple[asyncio.Queue[_SessionCommand | None], asyncio.Task[None]]: + if self._session_commands is None or self._session_task is None: + raise RuntimeError( + "GraphContext session is not active. Use GraphContext as an async context manager." + ) + return self._session_commands, self._session_task + + async def _session_io_loop(self) -> None: + reader = self._session_reader + writer = self._session_writer + commands = self._session_commands + if reader is None or writer is None or commands is None: + return + + try: + while True: + cmd = await commands.get() + if cmd is None: + break + + writer.write(cmd.command.value) + for arg in cmd.args: + writer.write(encode_str(arg)) + if cmd.payload is not None: + writer.write(uint64_to_bytes(len(cmd.payload))) + writer.write(cmd.payload) + await writer.drain() + + if cmd.response_kind == _SessionResponseKind.BYTE: + response = await reader.read(1) + + elif cmd.response_kind == _SessionResponseKind.SNAPSHOT: + num_bytes = await read_int(reader) + snapshot_bytes = await reader.readexactly(num_bytes) + complete = await reader.read(1) + if complete != Command.COMPLETE.value: + raise RuntimeError("Unexpected response to session snapshot") + response = pickle.loads(snapshot_bytes) + + else: + raise RuntimeError(f"Unsupported response kind: {cmd.response_kind}") + + if not cmd.response_fut.done(): + cmd.response_fut.set_result(response) + + except Exception as exc: + while True: + try: + pending = commands.get_nowait() + except asyncio.QueueEmpty: + break + + if pending is not None and not pending.response_fut.done(): + pending.response_fut.set_exception(exc) + finally: + while True: + try: + pending = commands.get_nowait() + except asyncio.QueueEmpty: + break + + if pending is not None and not pending.response_fut.done(): + pending.response_fut.set_exception( + RuntimeError("GraphContext session closed") + ) async def _close_session(self) -> None: + commands = self._session_commands + task = self._session_task writer = self._session_writer if writer is None: return - try: - await self._session_command(Command.SESSION_CLEAR) - except ( - ConnectionRefusedError, - ConnectionResetError, - BrokenPipeError, - asyncio.IncompleteReadError, - ): - pass + if commands is not None: + await commands.put(None) + if task is not None: + with suppress(asyncio.CancelledError): + await task await close_stream_writer(writer) self._session_id = None self._session_reader = None self._session_writer = None + self._session_task = None + self._session_commands = None self._edges.clear() async def _session_command( @@ -222,57 +309,41 @@ async def _session_command( command: Command, *args: str, payload: bytes | None = None, - expect_snapshot: bool = False, + response_kind: _SessionResponseKind = _SessionResponseKind.BYTE, ) -> typing.Any: - reader = self._session_reader - writer = self._session_writer - if reader is None or writer is None: - raise RuntimeError("GraphContext session is not active") - - async with self._session_lock: - writer.write(command.value) - for arg in args: - writer.write(encode_str(arg)) - if payload is not None: - writer.write(uint64_to_bytes(len(payload))) - writer.write(payload) - await writer.drain() - - if expect_snapshot: - num_bytes = await read_int(reader) - snapshot_bytes = await reader.readexactly(num_bytes) - response = await reader.read(1) - if response != Command.COMPLETE.value: - raise RuntimeError("Unexpected response to session snapshot") - return pickle.loads(snapshot_bytes) - - return await reader.read(1) - - async def register_metadata(self, metadata: dict[str, typing.Any]) -> None: - if self._session_writer is None: - logger.warning("No active GraphContext session; metadata registration skipped") - return + commands, task = self._require_session() + if task.done(): + raise RuntimeError("GraphContext session task is not running") + + response_fut: asyncio.Future[typing.Any] = asyncio.get_running_loop().create_future() + await commands.put( + _SessionCommand( + command=command, + args=tuple(args), + payload=payload, + response_kind=response_kind, + response_fut=response_fut, + ) + ) + return await response_fut + async def register_metadata(self, metadata: GraphMetadata) -> None: + payload = pickle.dumps(metadata) response = await self._session_command( - Command.SESSION_REGISTER, payload=pickle.dumps(metadata) + Command.SESSION_REGISTER, + payload=payload, + response_kind=_SessionResponseKind.BYTE, ) if response != Command.COMPLETE.value: raise RuntimeError("Unexpected response to session metadata registration") - async def snapshot(self) -> dict[str, typing.Any]: - if self._session_writer is None: - dag = await GraphService(self.graph_address).dag() - return { - "graph": {node: sorted(conns) for node, conns in dag.graph.items()}, - "edge_owners": [], - "sessions": {}, - } - + async def snapshot(self) -> GraphSnapshot: snapshot = await self._session_command( - Command.SESSION_SNAPSHOT, expect_snapshot=True + Command.SESSION_SNAPSHOT, + response_kind=_SessionResponseKind.SNAPSHOT, ) - if not isinstance(snapshot, dict): - raise RuntimeError("Session snapshot payload was not a dictionary") + if not isinstance(snapshot, GraphSnapshot): + raise RuntimeError("Session snapshot payload was not a GraphSnapshot") return snapshot async def _shutdown_servers(self) -> None: @@ -315,25 +386,21 @@ async def revert(self) -> None: if self._session_writer is not None: try: - response = await self._session_command(Command.SESSION_CLEAR) + response = await self._session_command( + Command.SESSION_CLEAR, + response_kind=_SessionResponseKind.BYTE, + ) if response != Command.COMPLETE.value: - logger.warning("GraphServer returned unexpected response to SESSION_CLEAR") + logger.warning( + "GraphServer returned unexpected response to SESSION_CLEAR" + ) except ( ConnectionRefusedError, BrokenPipeError, ConnectionResetError, asyncio.IncompleteReadError, + RuntimeError, ) as e: logger.warning(f"Could not clear GraphContext session state: {e}") - else: - for edge in self._edges: - try: - await GraphService(self.graph_address).disconnect(*edge) - except ( - ConnectionRefusedError, - BrokenPipeError, - ConnectionResetError, - ) as e: - logger.warning(f"Could not remove edge {edge} from GraphServer: {e}") self._edges.clear() diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py new file mode 100644 index 00000000..c02f67f6 --- /dev/null +++ b/src/ezmsg/core/graphmeta.py @@ -0,0 +1,101 @@ +from dataclasses import dataclass, field +from typing import Any, TypeAlias, NamedTuple + + +@dataclass +class DynamicSettingsMetadata: + enabled: bool + input_topic: str | None + settings_type: str | None + + +@dataclass +class StreamMetadata: + name: str + address: str + msg_type: str + + +@dataclass +class InputStreamMetadata(StreamMetadata): + leaky: bool = False + max_queue: int | None = None + + +@dataclass +class OutputStreamMetadata(StreamMetadata): + host: str | None = None + port: int | None = None + num_buffers: int | None = None + buf_size: int | None = None + force_tcp: bool | None = None + + +StreamMetadataType: TypeAlias = ( + StreamMetadata | InputStreamMetadata | OutputStreamMetadata +) + + +@dataclass +class TaskMetadata: + name: str + subscribes: str | None = None + publishes: list[str] = field(default_factory=list) + + +SettingsReprType: TypeAlias = dict[str, Any] | str +SerializedSettingsType: TypeAlias = bytes | None +InitialSettingsType: TypeAlias = tuple[SerializedSettingsType, SettingsReprType] + + +@dataclass +class ComponentMetadata: + address: str + name: str + component_type: str + settings_type: str + initial_settings: InitialSettingsType + streams: dict[str, StreamMetadataType] + dynamic_settings: DynamicSettingsMetadata + + +@dataclass +class CollectionMetadata(ComponentMetadata): + children: list[str] + + +@dataclass +class UnitMetadata(ComponentMetadata): + tasks: list[TaskMetadata] + main: str | None + threads: list[str] + + +ComponentMetadataType: TypeAlias = ( + ComponentMetadata | CollectionMetadata | UnitMetadata +) + + +@dataclass +class GraphMetadata: + schema_version: int + root_name: str | None + components: dict[str, ComponentMetadataType] + + +class Edge(NamedTuple): + from_topic: str + to_topic: str + + +@dataclass +class SnapshotSession: + edges: list[Edge] + metadata: GraphMetadata | None + + +@dataclass +class GraphSnapshot: + graph: dict[str, list[str]] + edge_owners: dict[Edge, list[str]] + sessions: dict[str, SnapshotSession] diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 4a6ebcee..873f120b 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -4,20 +4,24 @@ import os import socket import threading -import warnings -from dataclasses import dataclass, field from contextlib import suppress from uuid import UUID, uuid1 -from typing import Any from . import __version__ from .dag import DAG, CyclicException from .graph_util import get_compactified_graph, graph_string, prune_graph_connections +from .graphmeta import ( + Edge, + GraphMetadata, + GraphSnapshot, + SnapshotSession, +) from .netprotocol import ( Address, Command, ClientInfo, + SessionInfo, SubscriberInfo, PublisherInfo, ChannelInfo, @@ -38,16 +42,7 @@ from .shm import SHMContext, SHMInfo logger = logging.getLogger("ezmsg") - -LEGACY_OWNER = "legacy" - - -@dataclass -class SessionInfo: - id: UUID - writer: asyncio.StreamWriter - edges: set[tuple[str, str]] = field(default_factory=set) - metadata: dict[str, Any] = field(default_factory=dict) +PERSISTENT_EDGE_OWNER = None class GraphServer(threading.Thread): @@ -74,8 +69,6 @@ class GraphServer(threading.Thread): graph: DAG clients: dict[UUID, ClientInfo] - sessions: dict[UUID, SessionInfo] - edge_owners: dict[tuple[str, str], set[UUID | str]] shms: dict[str, SHMInfo] _client_tasks: dict[UUID, "asyncio.Task[None]"] @@ -92,8 +85,6 @@ def __init__(self, **kwargs) -> None: # graph/server data self.graph = DAG() self.clients = {} - self.sessions = {} - self.edge_owners = {} self._client_tasks = {} self.shms = {} self._address = None @@ -280,7 +271,7 @@ async def api( elif req == Command.SESSION.value: session_id = uuid1() - self.sessions[session_id] = SessionInfo(session_id, writer) + self.clients[session_id] = SessionInfo(session_id, writer) writer.write(encode_str(str(session_id))) writer.write(Command.COMPLETE.value) await writer.drain() @@ -338,11 +329,11 @@ async def api( try: if req == Command.CONNECT.value: topology_changed = self._connect_owner( - from_topic, to_topic, LEGACY_OWNER + from_topic, to_topic, PERSISTENT_EDGE_OWNER ) else: topology_changed = self._disconnect_owner( - from_topic, to_topic, LEGACY_OWNER + from_topic, to_topic, PERSISTENT_EDGE_OWNER ) writer.write(Command.COMPLETE.value) except CyclicException: @@ -426,9 +417,7 @@ async def _handle_client( finally: # Ensure any waiter on this client unblocks - info = self.clients.get(client_id) - if info is not None: - info.set_sync() + self.clients[client_id].set_sync() self.clients.pop(client_id, None) await close_stream_writer(writer) @@ -447,62 +436,30 @@ async def _handle_session( if not req: break - if req in [Command.SESSION_CONNECT.value, Command.SESSION_DISCONNECT.value]: - from_topic = await read_str(reader) - to_topic = await read_str(reader) - - async with self._command_lock: - topology_changed = False - try: - if req == Command.SESSION_CONNECT.value: - topology_changed = self._connect_owner( - from_topic, to_topic, session_id - ) - else: - topology_changed = self._disconnect_owner( - from_topic, to_topic, session_id - ) - except CyclicException: - writer.write(Command.CYCLIC.value) - await writer.drain() - continue - - if topology_changed: - await self._notify_downstream_for_topic(to_topic) - - writer.write(Command.COMPLETE.value) + if req in [ + Command.SESSION_CONNECT.value, + Command.SESSION_DISCONNECT.value, + ]: + response = await self._handle_session_edge_request( + session_id, req, reader + ) + writer.write(response) await writer.drain() elif req == Command.SESSION_CLEAR.value: - async with self._command_lock: - notify_topics = self._clear_session_state(session_id) - for topic in notify_topics: - await self._notify_downstream_for_topic(topic) - - writer.write(Command.COMPLETE.value) + response = await self._handle_session_clear_request(session_id) + writer.write(response) await writer.drain() elif req == Command.SESSION_REGISTER.value: - num_bytes = await read_int(reader) - payload = await reader.readexactly(num_bytes) - metadata = pickle.loads(payload) - - async with self._command_lock: - session = self.sessions.get(session_id) - if session is not None: - session.metadata = metadata - writer.write(Command.COMPLETE.value) - + response = await self._handle_session_register_request( + session_id, reader + ) + writer.write(response) await writer.drain() elif req == Command.SESSION_SNAPSHOT.value: - async with self._command_lock: - snapshot = self._snapshot() - snapshot_bytes = pickle.dumps(snapshot) - writer.write(uint64_to_bytes(len(snapshot_bytes))) - writer.write(snapshot_bytes) - writer.write(Command.COMPLETE.value) - + await self._handle_session_snapshot_request(writer) await writer.drain() else: @@ -524,57 +481,102 @@ async def _handle_session( await close_stream_writer(writer) def _connect_owner( - self, from_topic: str, to_topic: str, owner: UUID | str + self, from_topic: str, to_topic: str, owner: UUID | str | None ) -> bool: - edge = (from_topic, to_topic) - owners = self.edge_owners.setdefault(edge, set()) - - if owner in owners: - return False - - topology_changed = len(owners) == 0 - if topology_changed: - try: - self.graph.add_edge(from_topic, to_topic) - except CyclicException: - if len(owners) == 0: - self.edge_owners.pop(edge, None) - raise - - owners.add(owner) - + topology_changed = self.graph.add_edge(from_topic, to_topic, owner=owner) if isinstance(owner, UUID): - session = self.sessions.get(owner) + session = self._session_info(owner) if session is not None: - session.edges.add(edge) - + session.edges.add((from_topic, to_topic)) return topology_changed def _disconnect_owner( - self, from_topic: str, to_topic: str, owner: UUID | str + self, from_topic: str, to_topic: str, owner: UUID | str | None ) -> bool: - edge = (from_topic, to_topic) - owners = self.edge_owners.get(edge) - if owners is None or owner not in owners: - return False - - owners.remove(owner) - + topology_changed = self.graph.remove_edge(from_topic, to_topic, owner=owner) if isinstance(owner, UUID): - session = self.sessions.get(owner) + session = self._session_info(owner) if session is not None: - session.edges.discard(edge) + session.edges.discard((from_topic, to_topic)) + return topology_changed - if len(owners) == 0: - self.edge_owners.pop(edge, None) - self.graph.remove_edge(from_topic, to_topic) - return True + def _session_info(self, session_id: UUID) -> SessionInfo | None: + info = self.clients.get(session_id) + if isinstance(info, SessionInfo): + return info + return None - return False + async def _handle_session_edge_request( + self, + session_id: UUID, + req: bytes, + reader: asyncio.StreamReader, + ) -> bytes: + from_topic = await read_str(reader) + to_topic = await read_str(reader) + + async with self._command_lock: + try: + if req == Command.SESSION_CONNECT.value: + topology_changed = self._connect_owner( + from_topic, to_topic, session_id + ) + else: + topology_changed = self._disconnect_owner( + from_topic, to_topic, session_id + ) + except CyclicException: + return Command.CYCLIC.value + + if topology_changed: + await self._notify_downstream_for_topic(to_topic) + + return Command.COMPLETE.value + + async def _handle_session_clear_request(self, session_id: UUID) -> bytes: + async with self._command_lock: + notify_topics = self._clear_session_state(session_id) + for topic in notify_topics: + await self._notify_downstream_for_topic(topic) + return Command.COMPLETE.value + + async def _handle_session_register_request( + self, session_id: UUID, reader: asyncio.StreamReader + ) -> bytes: + num_bytes = await read_int(reader) + payload = await reader.readexactly(num_bytes) + metadata: GraphMetadata | None = None + try: + metadata_obj = pickle.loads(payload) + if isinstance(metadata_obj, GraphMetadata): + metadata = metadata_obj + else: + raise RuntimeError("metadata payload was not GraphMetadata") + except Exception as exc: + logger.warning( + f"Session {session_id} metadata parse failed; ignoring payload: {exc}" + ) + + async with self._command_lock: + session = self._session_info(session_id) + if session is not None and metadata is not None: + session.metadata = metadata + + return Command.COMPLETE.value + + async def _handle_session_snapshot_request( + self, writer: asyncio.StreamWriter + ) -> None: + async with self._command_lock: + snapshot = self._snapshot() + snapshot_bytes = pickle.dumps(snapshot) + writer.write(uint64_to_bytes(len(snapshot_bytes))) + writer.write(snapshot_bytes) + writer.write(Command.COMPLETE.value) def _clear_session_state(self, session_id: UUID) -> set[str]: notify_topics: set[str] = set() - session = self.sessions.get(session_id) + session = self._session_info(session_id) if session is None: return notify_topics @@ -582,12 +584,12 @@ def _clear_session_state(self, session_id: UUID) -> set[str]: if self._disconnect_owner(from_topic, to_topic, session_id): notify_topics.add(to_topic) - session.metadata.clear() + session.metadata = None return notify_topics def _drop_session(self, session_id: UUID) -> set[str]: notify_topics: set[str] = set() - session = self.sessions.pop(session_id, None) + session = self._session_info(session_id) if session is None: return notify_topics @@ -595,33 +597,42 @@ def _drop_session(self, session_id: UUID) -> set[str]: if self._disconnect_owner(from_topic, to_topic, session_id): notify_topics.add(to_topic) - session.metadata.clear() + session.metadata = None + self.clients.pop(session_id, None) return notify_topics - def _snapshot(self) -> dict[str, Any]: + def _snapshot(self) -> GraphSnapshot: graph = {node: sorted(conns) for node, conns in self.graph.graph.items()} - edge_owners = [ - { - "from_topic": from_topic, - "to_topic": to_topic, - "owners": [str(owner) for owner in sorted(owners, key=str)], - } - for (from_topic, to_topic), owners in sorted(self.edge_owners.items()) - ] + edge_owners = { + Edge(from_topic=from_topic, to_topic=to_topic): [ + "persistent" if owner is None else str(owner) + for owner in sorted( + owners, key=lambda owner: "" if owner is None else str(owner) + ) + ] + for (from_topic, to_topic), owners in sorted(self.graph.edge_owners.items()) + } sessions = { - str(session_id): { - "edges": sorted( + str(session_id): SnapshotSession( + edges=sorted( [ - {"from_topic": from_topic, "to_topic": to_topic} + Edge(from_topic=from_topic, to_topic=to_topic) for from_topic, to_topic in session.edges ], - key=lambda edge: (edge["from_topic"], edge["to_topic"]), + key=lambda edge: (edge.from_topic, edge.to_topic), ), - "metadata": session.metadata, - } - for session_id, session in sorted(self.sessions.items(), key=lambda item: str(item[0])) + metadata=session.metadata, + ) + for session_id, session in sorted( + [ + (client_id, info) + for client_id, info in self.clients.items() + if isinstance(info, SessionInfo) + ], + key=lambda item: str(item[0]), + ) } - return {"graph": graph, "edge_owners": edge_owners, "sessions": sessions} + return GraphSnapshot(graph=graph, edge_owners=edge_owners, sessions=sessions) async def _notify_downstream_for_topic(self, topic: str) -> None: for sub in self._downstream_subs(topic): @@ -726,12 +737,6 @@ async def open_connection( return reader, writer async def connect(self, from_topic: str, to_topic: str) -> None: - warnings.warn( - "GraphService.connect is deprecated. Prefer GraphContext.connect " - "to use a session-scoped control plane.", - DeprecationWarning, - stacklevel=2, - ) reader, writer = await self.open_connection() writer.write(Command.CONNECT.value) writer.write(encode_str(from_topic)) @@ -744,12 +749,6 @@ async def connect(self, from_topic: str, to_topic: str) -> None: await close_stream_writer(writer) async def disconnect(self, from_topic: str, to_topic: str) -> None: - warnings.warn( - "GraphService.disconnect is deprecated. Prefer GraphContext.disconnect " - "to use a session-scoped control plane.", - DeprecationWarning, - stacklevel=2, - ) reader, writer = await self.open_connection() writer.write(Command.DISCONNECT.value) writer.write(encode_str(from_topic)) diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index d640d032..dca58583 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -9,6 +9,7 @@ from dataclasses import field, dataclass from contextlib import asynccontextmanager from asyncio.base_events import Server +from .graphmeta import GraphMetadata VERSION = b"1" UINT64_SIZE = 8 @@ -165,6 +166,16 @@ class ChannelInfo(ClientInfo): pub_id: UUID +@dataclass +class SessionInfo(ClientInfo): + """ + Session-scoped control-plane client information. + """ + + edges: set[tuple[str, str]] = field(default_factory=set) + metadata: GraphMetadata | None = None + + def uint64_to_bytes(i: int) -> bytes: """ Convert a 64-bit unsigned integer to bytes. diff --git a/tests/test_graph_session.py b/tests/test_graph_session.py new file mode 100644 index 00000000..f3969d32 --- /dev/null +++ b/tests/test_graph_session.py @@ -0,0 +1,106 @@ +import asyncio + +import pytest + +from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.graphmeta import GraphMetadata, GraphSnapshot +from ezmsg.core.graphserver import GraphService + + +def _edge_exists(snapshot: GraphSnapshot, from_topic: str, to_topic: str) -> bool: + return to_topic in snapshot.graph.get(from_topic, []) + + +@pytest.mark.asyncio +async def test_session_drop_cleans_owned_edges(): + graph_server = GraphService().create_server() + address = graph_server.address + + owner = GraphContext(address, auto_start=False) + observer = GraphContext(address, auto_start=False) + + await owner.__aenter__() + await observer.__aenter__() + + try: + await owner.connect("SRC", "DST") + + snapshot = await observer.snapshot() + assert _edge_exists(snapshot, "SRC", "DST") + + await owner._close_session() + await asyncio.sleep(0.05) + + snapshot = await observer.snapshot() + assert not _edge_exists(snapshot, "SRC", "DST") + finally: + await owner.__aexit__(None, None, None) + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_shared_edge_survives_until_last_session_drops(): + graph_server = GraphService().create_server() + address = graph_server.address + + owner_a = GraphContext(address, auto_start=False) + owner_b = GraphContext(address, auto_start=False) + + await owner_a.__aenter__() + await owner_b.__aenter__() + + try: + await owner_a.connect("SRC", "DST") + await owner_b.connect("SRC", "DST") + + await owner_a._close_session() + await asyncio.sleep(0.05) + + snapshot = await owner_b.snapshot() + assert _edge_exists(snapshot, "SRC", "DST") + + await owner_b._close_session() + await asyncio.sleep(0.05) + + dag = await GraphService(address).dag() + assert "DST" not in dag.graph.get("SRC", set()) + finally: + await owner_a.__aexit__(None, None, None) + await owner_b.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_session_metadata_drops_with_session(): + graph_server = GraphService().create_server() + address = graph_server.address + + owner = GraphContext(address, auto_start=False) + observer = GraphContext(address, auto_start=False) + + await owner.__aenter__() + await observer.__aenter__() + + try: + metadata = GraphMetadata( + schema_version=1, + root_name="TEST", + components={}, + ) + await owner.register_metadata(metadata) + + owner_session_id = str(owner._session_id) + snapshot = await observer.snapshot() + assert owner_session_id in snapshot.sessions + assert snapshot.sessions[owner_session_id].metadata == metadata + + await owner._close_session() + await asyncio.sleep(0.05) + + snapshot = await observer.snapshot() + assert owner_session_id not in snapshot.sessions + finally: + await owner.__aexit__(None, None, None) + await observer.__aexit__(None, None, None) + graph_server.stop() From cfe283c4f71d34bc5f7e1a1bb4fee0e201dbe0a2 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 2 Mar 2026 16:11:35 -0500 Subject: [PATCH 04/52] topic and relay implementation --- src/ezmsg/core/__init__.py | 15 +++- src/ezmsg/core/backend.py | 160 ++++++++++++++++++++++++++++++----- src/ezmsg/core/collection.py | 13 ++- src/ezmsg/core/relay.py | 42 +++++++++ src/ezmsg/core/stream.py | 104 +++++++++++++++++++++++ src/ezmsg/core/unit.py | 8 +- tests/test_topics.py | 156 ++++++++++++++++++++++++++++++++++ 7 files changed, 476 insertions(+), 22 deletions(-) create mode 100644 src/ezmsg/core/relay.py create mode 100644 tests/test_topics.py diff --git a/src/ezmsg/core/__init__.py b/src/ezmsg/core/__init__.py index bc56a3b9..aaf36899 100644 --- a/src/ezmsg/core/__init__.py +++ b/src/ezmsg/core/__init__.py @@ -14,6 +14,11 @@ "Settings", "Collection", "NetworkDefinition", + "Topic", + "InputTopic", + "OutputTopic", + "InputRelay", + "OutputRelay", "InputStream", "OutputStream", "Unit", @@ -44,7 +49,15 @@ from .settings import Settings from .collection import Collection, NetworkDefinition from .unit import Unit, task, publisher, subscriber, main, timeit, process, thread -from .stream import InputStream, OutputStream +from .stream import ( + Topic, + InputTopic, + OutputTopic, + InputRelay, + OutputRelay, + InputStream, + OutputStream, +) from .backend import run, GraphRunner, GraphRunnerStartError from .backendprocess import Complete, NormalTermination from .graphserver import GraphServer diff --git a/src/ezmsg/core/backend.py b/src/ezmsg/core/backend.py index 7a0aba1d..910f81a8 100644 --- a/src/ezmsg/core/backend.py +++ b/src/ezmsg/core/backend.py @@ -5,6 +5,7 @@ import logging import os import signal +from dataclasses import dataclass from threading import BrokenBarrierError from multiprocessing import Event, Barrier from multiprocessing.synchronize import Event as EventType @@ -16,8 +17,9 @@ from .collection import Collection, NetworkDefinition from .component import Component -from .stream import Stream +from .stream import Stream, InputRelay, OutputRelay from .unit import Unit, PROCESS_ATTR +from .relay import _CollectionRelayUnit, _RelaySettings from .graphserver import GraphService from .graphcontext import GraphContext @@ -33,6 +35,16 @@ logger = logging.getLogger("ezmsg") +@dataclass +class _RelayBinding: + kind: str # "input" or "output" + endpoint_topic: str + relay_in_topic: str + relay_out_topic: str + endpoint: InputRelay | OutputRelay + relay_unit: _CollectionRelayUnit + + class ExecutionContext: _process_units: list[list[Unit]] _processes: list[BackendProcess] | None @@ -95,22 +107,32 @@ def setup( start_participant: bool = False, ) -> "ExecutionContext | None": graph_connections: list[tuple[str, str]] = [] + relay_bindings: dict[str, _RelayBinding] = {} for name, component in components.items(): component._set_name(name) component._set_location([root_name] if root_name is not None else []) + def normalize_topic(endpoint: Stream | str | enum.Enum, where: str) -> str: + if isinstance(endpoint, Stream): + return endpoint.address + if isinstance(endpoint, enum.Enum): + return endpoint.name + if isinstance(endpoint, str): + return endpoint + raise TypeError( + f"Invalid endpoint type in {where}: {type(endpoint)}. " + "Expected Stream, str, or Enum." + ) + if connections is not None: for from_topic, to_topic in connections: - if isinstance(from_topic, Stream): - from_topic = from_topic.address - if isinstance(to_topic, Stream): - to_topic = to_topic.address - if isinstance(to_topic, enum.Enum): - to_topic = to_topic.name - if isinstance(from_topic, enum.Enum): - from_topic = from_topic.name - graph_connections.append((from_topic, to_topic)) + graph_connections.append( + ( + normalize_topic(from_topic, "connections"), + normalize_topic(to_topic, "connections"), + ) + ) def crawl_components( component: Component, callback: Callable[[Component], None] @@ -121,23 +143,115 @@ def crawl_components( search += list(comp.components.values()) callback(comp) + def input_relay_settings(relay: InputRelay) -> _RelaySettings: + return _RelaySettings( + leaky=relay.leaky, + max_queue=relay.max_queue, + copy_on_forward=relay.copy_on_forward, + ) + + def output_relay_settings(relay: OutputRelay) -> _RelaySettings: + return _RelaySettings( + host=relay.host, + port=relay.port, + num_buffers=relay.num_buffers, + buf_size=relay.buf_size, + force_tcp=relay.force_tcp, + copy_on_forward=relay.copy_on_forward, + ) + + def add_collection_relay_units(comp: Component) -> None: + if not isinstance(comp, Collection): + return + + for endpoint_name, endpoint in comp.streams.items(): + if isinstance(endpoint, InputRelay): + relay_name = f"__relay_in_{endpoint_name}" + if relay_name in comp.components: + raise ValueError( + f"{comp.address} already defines component '{relay_name}'." + ) + + relay_unit = _CollectionRelayUnit(input_relay_settings(endpoint)) + relay_unit._set_name(relay_name) + relay_unit._set_location(comp.location + [comp.name]) + comp.components[relay_name] = relay_unit + setattr(comp, relay_name, relay_unit) + + relay_bindings[endpoint.address] = _RelayBinding( + kind="input", + endpoint_topic=endpoint.address, + relay_in_topic=relay_unit.INPUT.address, + relay_out_topic=relay_unit.OUTPUT.address, + endpoint=endpoint, + relay_unit=relay_unit, + ) + + elif isinstance(endpoint, OutputRelay): + relay_name = f"__relay_out_{endpoint_name}" + if relay_name in comp.components: + raise ValueError( + f"{comp.address} already defines component '{relay_name}'." + ) + + relay_unit = _CollectionRelayUnit(output_relay_settings(endpoint)) + relay_unit._set_name(relay_name) + relay_unit._set_location(comp.location + [comp.name]) + comp.components[relay_name] = relay_unit + setattr(comp, relay_name, relay_unit) + + relay_bindings[endpoint.address] = _RelayBinding( + kind="output", + endpoint_topic=endpoint.address, + relay_in_topic=relay_unit.INPUT.address, + relay_out_topic=relay_unit.OUTPUT.address, + endpoint=endpoint, + relay_unit=relay_unit, + ) + + for component in components.values(): + if isinstance(component, Collection): + crawl_components(component, add_collection_relay_units) + def gather_edges(comp: Component): if isinstance(comp, Collection): for from_stream, to_stream in comp.network(): - if isinstance(from_stream, Stream): - from_stream = from_stream.address - if isinstance(to_stream, Stream): - to_stream = to_stream.address - if isinstance(to_stream, enum.Enum): - to_stream = to_stream.name - if isinstance(from_stream, enum.Enum): - from_stream = from_stream.name - graph_connections.append((from_stream, to_stream)) + graph_connections.append( + ( + normalize_topic(from_stream, f"{comp.address}.network"), + normalize_topic(to_stream, f"{comp.address}.network"), + ) + ) for component in components.values(): if isinstance(component, Collection): crawl_components(component, gather_edges) + if relay_bindings: + rewritten_connections: list[tuple[str, str]] = [] + for from_topic, to_topic in graph_connections: + to_binding = relay_bindings.get(to_topic, None) + if to_binding is not None and to_binding.kind == "output": + to_topic = to_binding.relay_in_topic + + from_binding = relay_bindings.get(from_topic, None) + if from_binding is not None and from_binding.kind == "input": + from_topic = from_binding.relay_out_topic + + rewritten_connections.append((from_topic, to_topic)) + + for binding in relay_bindings.values(): + if binding.kind == "input": + rewritten_connections.append( + (binding.endpoint_topic, binding.relay_in_topic) + ) + else: + rewritten_connections.append( + (binding.relay_out_topic, binding.endpoint_topic) + ) + + graph_connections = rewritten_connections + processes = collect_processes(components.values(), process_components) for component in components.values(): @@ -149,6 +263,14 @@ def configure_collections(comp: Component): crawl_components(component, configure_collections) + for binding in relay_bindings.values(): + if isinstance(binding.endpoint, InputRelay): + binding.relay_unit.apply_settings(input_relay_settings(binding.endpoint)) + elif isinstance(binding.endpoint, OutputRelay): + binding.relay_unit.apply_settings( + output_relay_settings(binding.endpoint) + ) + if force_single_process: processes = [[u for pu in processes for u in pu]] diff --git a/src/ezmsg/core/collection.py b/src/ezmsg/core/collection.py index dd8f665d..1f1f51e6 100644 --- a/src/ezmsg/core/collection.py +++ b/src/ezmsg/core/collection.py @@ -2,8 +2,9 @@ from collections.abc import Collection as AbstractCollection import typing from copy import deepcopy +import warnings -from .stream import Stream +from .stream import Stream, InputStream, OutputStream from .component import ComponentMeta, Component from .settings import Settings @@ -34,6 +35,16 @@ def __init__( if isinstance(field_value, Component): field_value._set_name(field_name) cls.__components__[field_name] = field_value + elif isinstance(field_value, (InputStream, OutputStream)): + warnings.warn( + f"{name}.{field_name} uses {type(field_value).__name__} as a " + "Collection boundary endpoint. This behavior is deprecated and " + "will change in a future release. Use InputTopic / OutputTopic " + "for zero-overhead topic shortcuts, or InputRelay / OutputRelay " + "for explicit boundary republishers.", + FutureWarning, + stacklevel=2, + ) class Collection(Component, metaclass=CollectionMeta): diff --git a/src/ezmsg/core/relay.py b/src/ezmsg/core/relay.py new file mode 100644 index 00000000..830b7bb6 --- /dev/null +++ b/src/ezmsg/core/relay.py @@ -0,0 +1,42 @@ +from copy import deepcopy +from collections.abc import AsyncGenerator +from typing import Any + +from .settings import Settings +from .netprotocol import DEFAULT_SHM_SIZE +from .stream import InputStream, OutputStream +from .unit import Unit, publisher, subscriber + + +class _RelaySettings(Settings): + leaky: bool = False + max_queue: int | None = None + host: str | None = None + port: int | None = None + num_buffers: int = 32 + buf_size: int = DEFAULT_SHM_SIZE + force_tcp: bool = False + copy_on_forward: bool = True + + +class _CollectionRelayUnit(Unit): + SETTINGS = _RelaySettings + + INPUT = InputStream(Any) + OUTPUT = OutputStream(Any) + + async def initialize(self) -> None: + self.INPUT.leaky = self.SETTINGS.leaky + self.INPUT.max_queue = self.SETTINGS.max_queue + self.OUTPUT.host = self.SETTINGS.host + self.OUTPUT.port = self.SETTINGS.port + self.OUTPUT.num_buffers = self.SETTINGS.num_buffers + self.OUTPUT.buf_size = self.SETTINGS.buf_size + self.OUTPUT.force_tcp = self.SETTINGS.force_tcp + + @subscriber(INPUT) + @publisher(OUTPUT) + async def relay(self, msg: Any) -> AsyncGenerator: + if self.SETTINGS.copy_on_forward: + msg = deepcopy(msg) + yield self.OUTPUT, msg diff --git a/src/ezmsg/core/stream.py b/src/ezmsg/core/stream.py index c719c92a..26be9dfc 100644 --- a/src/ezmsg/core/stream.py +++ b/src/ezmsg/core/stream.py @@ -26,6 +26,110 @@ def __repr__(self) -> str: return f"Stream:{_addr}[{self.msg_type.__name__}]" +class Topic(Stream): + """ + Graph endpoint metadata for Collection boundaries and graph wiring. + + Topics represent named DAG nodes only. Unlike InputStream / OutputStream, + they do not directly configure Subscriber / Publisher transport behavior. + """ + + def __repr__(self) -> str: + return f"Topic{super().__repr__()}()" + + +class InputTopic(Topic): + """ + Directional alias for a Collection input topic. + """ + + def __repr__(self) -> str: + return f"Input{super().__repr__()}" + + +class OutputTopic(Topic): + """ + Directional alias for a Collection output topic. + """ + + def __repr__(self) -> str: + return f"Output{super().__repr__()}" + + +class InputRelay(InputTopic): + """ + Collection input boundary that materializes an internal relay subscriber/publisher. + + This enables subscriber-side behavior (e.g., leaky reception) on the boundary. + """ + + leaky: bool + max_queue: int | None + copy_on_forward: bool + + def __init__( + self, + msg_type: Any, + leaky: bool = False, + max_queue: int | None = None, + copy_on_forward: bool = True, + ) -> None: + super().__init__(msg_type) + if max_queue is not None and max_queue <= 0: + raise ValueError("max_queue must be positive") + self.leaky = leaky + self.max_queue = max_queue + self.copy_on_forward = copy_on_forward + + def __repr__(self) -> str: + base = f"InputRelay{Stream.__repr__(self)}" + return ( + f"{base}(leaky={self.leaky}, max_queue={self.max_queue}, " + f"copy_on_forward={self.copy_on_forward})" + ) + + +class OutputRelay(OutputTopic): + """ + Collection output boundary that materializes an internal relay subscriber/publisher. + + This enables publisher-side behavior (e.g., custom transport buffer settings) + on the boundary. + """ + + host: str | None + port: int | None + num_buffers: int + buf_size: int + force_tcp: bool + copy_on_forward: bool + + def __init__( + self, + msg_type: Any, + host: str | None = None, + port: int | None = None, + num_buffers: int = 32, + buf_size: int = DEFAULT_SHM_SIZE, + force_tcp: bool = False, + copy_on_forward: bool = True, + ) -> None: + super().__init__(msg_type) + self.host = host + self.port = port + self.num_buffers = num_buffers + self.buf_size = buf_size + self.force_tcp = force_tcp + self.copy_on_forward = copy_on_forward + + def __repr__(self) -> str: + base = f"OutputRelay{Stream.__repr__(self)}" + return ( + f"{base}(num_buffers={self.num_buffers}, force_tcp={self.force_tcp}, " + f"copy_on_forward={self.copy_on_forward})" + ) + + class InputStream(Stream): """ Can be added to any Component as a member variable. Methods may subscribe to it. diff --git a/src/ezmsg/core/unit.py b/src/ezmsg/core/unit.py index 8527e71d..0fc8b66b 100644 --- a/src/ezmsg/core/unit.py +++ b/src/ezmsg/core/unit.py @@ -2,7 +2,7 @@ import inspect import functools import warnings -from .stream import InputStream, OutputStream +from .stream import InputStream, OutputStream, Topic from .component import ComponentMeta, Component from .settings import Settings @@ -55,6 +55,12 @@ def __init__( cls.__threads__[thread_name] = thread for field_name, field_value in fields.items(): + if isinstance(field_value, Topic): + raise TypeError( + f"{name}.{field_name} is a {type(field_value).__name__}. " + "Units may only declare InputStream / OutputStream endpoints. " + "Use Topic / Relay endpoints on Collections only." + ) if callable(field_value): if hasattr(field_value, TASK_ATTR): cls.__tasks__[field_name] = field_value diff --git a/tests/test_topics.py b/tests/test_topics.py new file mode 100644 index 00000000..2ca691ca --- /dev/null +++ b/tests/test_topics.py @@ -0,0 +1,156 @@ +import pytest + +import ezmsg.core as ez + +from ezmsg.core.backend import ExecutionContext + + +@pytest.mark.parametrize( + "endpoint_factory", + [ + lambda: ez.Topic(int), + lambda: ez.InputTopic(int), + lambda: ez.OutputTopic(int), + lambda: ez.InputRelay(int), + lambda: ez.OutputRelay(int), + ], +) +def test_unit_rejects_topic_endpoints(endpoint_factory): + with pytest.raises(TypeError, match="Units may only declare InputStream"): + + class BadUnit(ez.Unit): + ENDPOINT = endpoint_factory() + + +def test_collection_stream_endpoint_warns_futurewarning(): + with pytest.warns(FutureWarning, match="deprecated"): + + class LegacyCollection(ez.Collection): + INPUT = ez.InputStream(int) + + +class _Source(ez.Unit): + OUTPUT = ez.OutputStream(int) + + +class _Sink(ez.Unit): + INPUT = ez.InputStream(int) + + +class _TopicPassthrough(ez.Collection): + IN = ez.InputTopic(int) + OUT = ez.OutputTopic(int) + + def network(self) -> ez.NetworkDefinition: + return ((self.IN, self.OUT),) + + +class _RelayInputPassthrough(ez.Collection): + IN = ez.InputRelay(int, leaky=False, max_queue=None, copy_on_forward=True) + OUT = ez.OutputTopic(int) + + def configure(self) -> None: + self.IN.leaky = True + self.IN.max_queue = 7 + + def network(self) -> ez.NetworkDefinition: + return ((self.IN, self.OUT),) + + +class _RelayOutputPassthrough(ez.Collection): + IN = ez.InputTopic(int) + OUT = ez.OutputRelay(int, num_buffers=16, force_tcp=True, copy_on_forward=False) + + def configure(self) -> None: + self.OUT.num_buffers = 8 + + def network(self) -> ez.NetworkDefinition: + return ((self.IN, self.OUT),) + + +class _TopicSystem(ez.Collection): + SOURCE = _Source() + PASSTHROUGH = _TopicPassthrough() + SINK = _Sink() + + def network(self) -> ez.NetworkDefinition: + return ( + (self.SOURCE.OUTPUT, self.PASSTHROUGH.IN), + (self.PASSTHROUGH.OUT, self.SINK.INPUT), + ) + + +class _InputRelaySystem(ez.Collection): + SOURCE = _Source() + PASSTHROUGH = _RelayInputPassthrough() + SINK = _Sink() + + def network(self) -> ez.NetworkDefinition: + return ( + (self.SOURCE.OUTPUT, self.PASSTHROUGH.IN), + (self.PASSTHROUGH.OUT, self.SINK.INPUT), + ) + + +class _OutputRelaySystem(ez.Collection): + SOURCE = _Source() + PASSTHROUGH = _RelayOutputPassthrough() + SINK = _Sink() + + def network(self) -> ez.NetworkDefinition: + return ( + (self.SOURCE.OUTPUT, self.PASSTHROUGH.IN), + (self.PASSTHROUGH.OUT, self.SINK.INPUT), + ) + + +def test_input_output_topics_behave_as_shortcuts(): + system = _TopicSystem() + ctx = ExecutionContext.setup({"SYSTEM": system}) + assert ctx is not None + assert (system.SOURCE.OUTPUT.address, system.PASSTHROUGH.IN.address) in ctx.connections + assert (system.PASSTHROUGH.IN.address, system.PASSTHROUGH.OUT.address) in ctx.connections + assert (system.PASSTHROUGH.OUT.address, system.SINK.INPUT.address) in ctx.connections + + +def test_input_relay_rewrites_edges_and_syncs_settings(): + system = _InputRelaySystem() + ctx = ExecutionContext.setup({"SYSTEM": system}) + assert ctx is not None + + relay = system.PASSTHROUGH.components["__relay_in_IN"] + source = system.SOURCE.OUTPUT.address + endpoint_in = system.PASSTHROUGH.IN.address + endpoint_out = system.PASSTHROUGH.OUT.address + sink = system.SINK.INPUT.address + + assert (source, endpoint_in) in ctx.connections + assert (endpoint_in, relay.INPUT.address) in ctx.connections + assert (relay.OUTPUT.address, endpoint_out) in ctx.connections + assert (endpoint_out, sink) in ctx.connections + assert (endpoint_in, endpoint_out) not in ctx.connections + + assert relay.SETTINGS.leaky is True + assert relay.SETTINGS.max_queue == 7 + assert relay.SETTINGS.copy_on_forward is True + + +def test_output_relay_rewrites_edges_and_syncs_settings(): + system = _OutputRelaySystem() + ctx = ExecutionContext.setup({"SYSTEM": system}) + assert ctx is not None + + relay = system.PASSTHROUGH.components["__relay_out_OUT"] + source = system.SOURCE.OUTPUT.address + endpoint_in = system.PASSTHROUGH.IN.address + endpoint_out = system.PASSTHROUGH.OUT.address + sink = system.SINK.INPUT.address + + assert (source, endpoint_in) in ctx.connections + assert (endpoint_in, relay.INPUT.address) in ctx.connections + assert (relay.OUTPUT.address, endpoint_out) in ctx.connections + assert (endpoint_out, sink) in ctx.connections + + assert relay.SETTINGS.num_buffers == 8 + assert relay.SETTINGS.force_tcp is True + assert relay.SETTINGS.copy_on_forward is False From 70c0101b73ee885ffa832eb83831b50ed263e0a8 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 2 Mar 2026 16:18:11 -0500 Subject: [PATCH 05/52] adjust examples --- examples/ezmsg_configs.py | 12 ++++++------ examples/ezmsg_toy.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/ezmsg_configs.py b/examples/ezmsg_configs.py index 113295eb..4991acd8 100644 --- a/examples/ezmsg_configs.py +++ b/examples/ezmsg_configs.py @@ -82,8 +82,8 @@ async def listen(self, msg: int) -> None: class PassthroughCollection(ez.Collection): - INPUT = ez.InputStream(int) - OUTPUT = ez.OutputStream(int) + INPUT = ez.InputTopic(int) + OUTPUT = ez.OutputTopic(int) def network(self) -> ez.NetworkDefinition: return ((self.INPUT, self.OUTPUT),) @@ -136,7 +136,7 @@ def configure(self) -> None: class PubNoSubCollection(ez.Collection): - OUTPUT = ez.OutputStream(int) + OUTPUT = ez.OutputTopic(int) GENERATE = Generator() LOG = DebugLog() @@ -148,7 +148,7 @@ def network(self) -> ez.NetworkDefinition: class SubNoPubCollection(ez.Collection): - INPUT = ez.InputStream(int) + INPUT = ez.InputTopic(int) LISTEN = Listener() def network(self) -> ez.NetworkDefinition: @@ -175,7 +175,7 @@ class PubNoSubPassthroughCollection(ez.Collection): COLLECTION = PubNoSubCollection() PASSTHROUGH = PassthroughCollection() - OUTPUT = ez.OutputStream(int) + OUTPUT = ez.OutputTopic(int) def network(self) -> ez.NetworkDefinition: return ( @@ -188,7 +188,7 @@ class SubNoPubPassthroughCollection(ez.Collection): COLLECTION = SubNoPubCollection() PASSTHROUGH = PassthroughCollection() - INPUT = ez.InputStream(int) + INPUT = ez.InputTopic(int) def network(self) -> ez.NetworkDefinition: return ( diff --git a/examples/ezmsg_toy.py b/examples/ezmsg_toy.py index a5c5c772..08f51f5f 100644 --- a/examples/ezmsg_toy.py +++ b/examples/ezmsg_toy.py @@ -123,8 +123,8 @@ class ModifierCollection(ez.Collection): """This collection will subscribe to messages and append the most recent LFO output""" - INPUT = ez.InputStream(str) - OUTPUT = ez.OutputStream(str) + INPUT = ez.InputTopic(str) + OUTPUT = ez.OutputTopic(str) SIN = LFO() # SIN2 = LFO() From 2393c08c831277ef2789249a0ca0f647744e12f0 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 2 Mar 2026 16:50:29 -0500 Subject: [PATCH 06/52] merged #228 --- src/ezmsg/core/backend.py | 105 +++++++++++++++++++++++++++++++----- src/ezmsg/core/graphmeta.py | 42 ++++++++++++++- tests/test_topics.py | 33 ++++++++++++ 3 files changed, 167 insertions(+), 13 deletions(-) diff --git a/src/ezmsg/core/backend.py b/src/ezmsg/core/backend.py index 1c0b604d..4318f1e2 100644 --- a/src/ezmsg/core/backend.py +++ b/src/ezmsg/core/backend.py @@ -20,7 +20,16 @@ from .collection import Collection, NetworkDefinition from .component import Component -from .stream import Stream, InputStream, OutputStream, InputRelay, OutputRelay +from .stream import ( + Stream, + InputStream, + OutputStream, + Topic, + InputTopic, + OutputTopic, + InputRelay, + OutputRelay, +) from .unit import Unit, PROCESS_ATTR, SUBSCRIBES_ATTR, PUBLISHES_ATTR from .settings import Settings from .graphmeta import ( @@ -28,10 +37,17 @@ ComponentMetadata, ComponentMetadataType, DynamicSettingsMetadata, + InputRelayMetadata, InputStreamMetadata, + InputTopicMetadata, + OutputRelayMetadata, OutputStreamMetadata, + OutputTopicMetadata, + RelayMetadataType, StreamMetadataType, StreamMetadata, + TopicMetadata, + TopicMetadataType, TaskMetadata, GraphMetadata, UnitMetadata, @@ -431,6 +447,7 @@ def _component_metadata(self) -> GraphMetadata: for root in self._components.values(): for comp in crawl_components(root): + is_collection = isinstance(comp, Collection) input_settings = comp.streams.get("INPUT_SETTINGS") dynamic_settings = DynamicSettingsMetadata( enabled=isinstance(input_settings, InputStream), @@ -447,33 +464,95 @@ def _component_metadata(self) -> GraphMetadata: ) stream_entries: dict[str, StreamMetadataType] = {} + topic_entries: dict[str, TopicMetadataType] = {} + relay_entries: dict[str, RelayMetadataType] = {} for stream_name, stream in comp.streams.items(): - if isinstance(stream, InputStream): - entry = InputStreamMetadata( + msg_type = self._stream_type_name(stream.msg_type) + if isinstance(stream, InputRelay): + relay_entries[stream_name] = InputRelayMetadata( name=stream_name, address=stream.address, - msg_type=self._stream_type_name(stream.msg_type), + msg_type=msg_type, leaky=stream.leaky, max_queue=stream.max_queue, + copy_on_forward=stream.copy_on_forward, ) - elif isinstance(stream, OutputStream): - entry = OutputStreamMetadata( + elif isinstance(stream, OutputRelay): + relay_entries[stream_name] = OutputRelayMetadata( name=stream_name, address=stream.address, - msg_type=self._stream_type_name(stream.msg_type), + msg_type=msg_type, host=stream.host, port=stream.port, num_buffers=stream.num_buffers, buf_size=stream.buf_size, force_tcp=stream.force_tcp, + copy_on_forward=stream.copy_on_forward, ) - else: - entry = StreamMetadata( + elif isinstance(stream, InputTopic): + topic_entries[stream_name] = InputTopicMetadata( + name=stream_name, + address=stream.address, + msg_type=msg_type, + ) + elif isinstance(stream, OutputTopic): + topic_entries[stream_name] = OutputTopicMetadata( name=stream_name, address=stream.address, - msg_type=self._stream_type_name(stream.msg_type), + msg_type=msg_type, ) - stream_entries[stream_name] = entry + elif isinstance(stream, Topic): + topic_entries[stream_name] = TopicMetadata( + name=stream_name, + address=stream.address, + msg_type=msg_type, + ) + elif isinstance(stream, InputStream): + if is_collection: + topic_entries[stream_name] = InputTopicMetadata( + name=stream_name, + address=stream.address, + msg_type=msg_type, + ) + else: + stream_entries[stream_name] = InputStreamMetadata( + name=stream_name, + address=stream.address, + msg_type=msg_type, + leaky=stream.leaky, + max_queue=stream.max_queue, + ) + elif isinstance(stream, OutputStream): + if is_collection: + topic_entries[stream_name] = OutputTopicMetadata( + name=stream_name, + address=stream.address, + msg_type=msg_type, + ) + else: + stream_entries[stream_name] = OutputStreamMetadata( + name=stream_name, + address=stream.address, + msg_type=msg_type, + host=stream.host, + port=stream.port, + num_buffers=stream.num_buffers, + buf_size=stream.buf_size, + force_tcp=stream.force_tcp, + ) + else: + if is_collection: + topic_entries[stream_name] = TopicMetadata( + name=stream_name, + address=stream.address, + msg_type=msg_type, + ) + else: + stream_entries[stream_name] = StreamMetadata( + name=stream_name, + address=stream.address, + msg_type=msg_type, + ) task_entries: list[TaskMetadata] = [] for task_name, task in comp.tasks.items(): @@ -507,7 +586,6 @@ def _component_metadata(self) -> GraphMetadata: component_type=self._type_name(comp.__class__), settings_type=settings_type_name, initial_settings=self._settings_snapshot(comp.SETTINGS), - streams=stream_entries, dynamic_settings=dynamic_settings, ) @@ -515,6 +593,8 @@ def _component_metadata(self) -> GraphMetadata: if isinstance(comp, Collection): metadata_entry = CollectionMetadata( **component_common, + topics=topic_entries, + relays=relay_entries, children=sorted( child.address for child in comp.components.values() ), @@ -522,6 +602,7 @@ def _component_metadata(self) -> GraphMetadata: elif isinstance(comp, Unit): metadata_entry = UnitMetadata( **component_common, + streams=stream_entries, tasks=sorted(task_entries, key=lambda task: task.name), main=comp.main.__name__ if comp.main is not None else None, threads=sorted(comp.threads.keys()), diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index c02f67f6..a9b439e9 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -36,6 +36,44 @@ class OutputStreamMetadata(StreamMetadata): ) +@dataclass +class TopicMetadata: + name: str + address: str + msg_type: str + + +@dataclass +class InputTopicMetadata(TopicMetadata): ... + + +@dataclass +class OutputTopicMetadata(TopicMetadata): ... + + +TopicMetadataType: TypeAlias = TopicMetadata | InputTopicMetadata | OutputTopicMetadata + + +@dataclass +class InputRelayMetadata(InputTopicMetadata): + leaky: bool = False + max_queue: int | None = None + copy_on_forward: bool = True + + +@dataclass +class OutputRelayMetadata(OutputTopicMetadata): + host: str | None = None + port: int | None = None + num_buffers: int | None = None + buf_size: int | None = None + force_tcp: bool | None = None + copy_on_forward: bool = True + + +RelayMetadataType: TypeAlias = InputRelayMetadata | OutputRelayMetadata + + @dataclass class TaskMetadata: name: str @@ -55,17 +93,19 @@ class ComponentMetadata: component_type: str settings_type: str initial_settings: InitialSettingsType - streams: dict[str, StreamMetadataType] dynamic_settings: DynamicSettingsMetadata @dataclass class CollectionMetadata(ComponentMetadata): + topics: dict[str, TopicMetadataType] + relays: dict[str, RelayMetadataType] children: list[str] @dataclass class UnitMetadata(ComponentMetadata): + streams: dict[str, StreamMetadataType] tasks: list[TaskMetadata] main: str | None threads: list[str] diff --git a/tests/test_topics.py b/tests/test_topics.py index 2ca691ca..78531bf1 100644 --- a/tests/test_topics.py +++ b/tests/test_topics.py @@ -3,6 +3,14 @@ import ezmsg.core as ez from ezmsg.core.backend import ExecutionContext +from ezmsg.core.graphmeta import ( + CollectionMetadata, + InputRelayMetadata, + InputStreamMetadata, + OutputStreamMetadata, + OutputTopicMetadata, + UnitMetadata, +) @pytest.mark.parametrize( @@ -154,3 +162,28 @@ def test_output_relay_rewrites_edges_and_syncs_settings(): assert relay.SETTINGS.num_buffers == 8 assert relay.SETTINGS.force_tcp is True assert relay.SETTINGS.copy_on_forward is False + + +def test_metadata_separates_collection_topics_relays_and_unit_streams(): + system = _InputRelaySystem() + ctx = ExecutionContext.setup({"SYSTEM": system}) + assert ctx is not None + + runner = ez.GraphRunner(components={"SYSTEM": system}) + metadata = runner._component_metadata() + + passthrough_meta = metadata.components[system.PASSTHROUGH.address] + assert isinstance(passthrough_meta, CollectionMetadata) + assert "IN" in passthrough_meta.relays + assert isinstance(passthrough_meta.relays["IN"], InputRelayMetadata) + assert passthrough_meta.relays["IN"].leaky is True + assert passthrough_meta.relays["IN"].max_queue == 7 + assert "OUT" in passthrough_meta.topics + assert isinstance(passthrough_meta.topics["OUT"], OutputTopicMetadata) + + source_meta = metadata.components[system.SOURCE.address] + sink_meta = metadata.components[system.SINK.address] + assert isinstance(source_meta, UnitMetadata) + assert isinstance(source_meta.streams["OUTPUT"], OutputStreamMetadata) + assert isinstance(sink_meta, UnitMetadata) + assert isinstance(sink_meta.streams["INPUT"], InputStreamMetadata) From 34ff8f42ab58743582d2a74c622c68f30954532e Mon Sep 17 00:00:00 2001 From: Konrad Pilch Date: Thu, 5 Mar 2026 16:44:59 -0700 Subject: [PATCH 07/52] Fix fast_replace annotation --- src/ezmsg/util/messages/util.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/ezmsg/util/messages/util.py b/src/ezmsg/util/messages/util.py index a632f69e..cd4b492d 100644 --- a/src/ezmsg/util/messages/util.py +++ b/src/ezmsg/util/messages/util.py @@ -1,12 +1,12 @@ from dataclasses import replace as slow_replace import os -import typing +from typing import Any, TypeVar -T = typing.TypeVar("T") +T = TypeVar("T") -def fast_replace(arr: typing.Generic[T], **kwargs) -> T: +def fast_replace(arr: T, **kwargs: Any) -> T: """ Fast replacement of dataclass fields with reduced safety. @@ -14,19 +14,23 @@ def fast_replace(arr: typing.Generic[T], **kwargs) -> T: nor does it check that the passed in fields are valid fields for the dataclass and not flagged as init=False. + BEWARE: This function is not type safe and may lead to runtime errors if + used incorrectly. It implicitly assumes arr has a __dict__ attribute and + that kwargs are valid init parameters for the dataclass of arr. + User code may choose to use this replace or the legacy replace according to their needs. To force ezmsg to use the legacy replace, set the environment variable: EZMSG_DISABLE_FAST_REPLACE Unset the variable to use this replace function. :param arr: The dataclass instance to create a modified copy of. - :type arr: typing.Generic[T] + :type arr: T :param kwargs: Field values to update in the new instance. :return: A new instance of the same type with updated field values. :rtype: T """ out_kwargs = arr.__dict__.copy() # Shallow copy - out_kwargs.update(**kwargs) + out_kwargs.update(kwargs) return arr.__class__(**out_kwargs) From c548977a60eacd59401aaa1ffabaf991374d8cb6 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sat, 14 Mar 2026 10:58:01 -0400 Subject: [PATCH 08/52] registration/snapshot foundation --- src/ezmsg/core/backendprocess.py | 23 ++++ src/ezmsg/core/graphmeta.py | 24 +++++ src/ezmsg/core/graphserver.py | 176 ++++++++++++++++++++++++++++++- src/ezmsg/core/netprotocol.py | 17 +++ src/ezmsg/core/processclient.py | 116 ++++++++++++++++++++ tests/test_process_control.py | 66 ++++++++++++ 6 files changed, 421 insertions(+), 1 deletion(-) create mode 100644 src/ezmsg/core/processclient.py create mode 100644 tests/test_process_control.py diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 8ef22a72..9f001dbd 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -26,6 +26,7 @@ from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR from .graphcontext import GraphContext +from .processclient import ProcessControlClient from .pubclient import Publisher from .subclient import Subscriber from .netprotocol import AddressType @@ -223,6 +224,8 @@ class DefaultBackendProcess(BackendProcess): def process(self, loop: asyncio.AbstractEventLoop) -> None: main_func = None context = GraphContext(self.graph_address) + process_client = ProcessControlClient(self.graph_address) + process_register_future: concurrent.futures.Future[None] | None = None coro_callables: dict[str, Callable[[], Coroutine[Any, Any, None]]] = dict() self._shutdown_errors = False @@ -315,6 +318,17 @@ async def setup_state(): logger.debug("Waiting at start barrier!") self.start_barrier.wait() + async def register_process_control() -> None: + try: + await process_client.register([unit.address for unit in self.units]) + except Exception as exc: + logger.warning(f"Process control registration failed: {exc}") + + process_register_future = asyncio.run_coroutine_threadsafe( + register_process_control(), + loop, + ) + for unit in self.units: for thread_fn in unit.threads.values(): loop.run_in_executor(None, thread_fn, unit) @@ -407,6 +421,15 @@ async def shutdown_units() -> None: except TimeoutError: logger.warning("Timed out waiting for retry on context revert") + process_close_future = asyncio.run_coroutine_threadsafe( + process_client.close(), + loop=loop, + ) + with suppress(Exception): + if process_register_future is not None: + process_register_future.result(timeout=0.5) + process_close_future.result() + logger.debug(f"Remaining tasks in event loop = {asyncio.all_tasks(loop)}") if self.task_finished_ev is not None: diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index a9b439e9..94db0517 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -123,6 +123,21 @@ class GraphMetadata: components: dict[str, ComponentMetadataType] +@dataclass +class ProcessHello: + process_id: str + pid: int + host: str + units: list[str] + + +@dataclass +class ProcessOwnershipUpdate: + process_id: str + added_units: list[str] = field(default_factory=list) + removed_units: list[str] = field(default_factory=list) + + class Edge(NamedTuple): from_topic: str to_topic: str @@ -134,8 +149,17 @@ class SnapshotSession: metadata: GraphMetadata | None +@dataclass +class SnapshotProcess: + process_id: str + pid: int | None + host: str | None + units: list[str] + + @dataclass class GraphSnapshot: graph: dict[str, list[str]] edge_owners: dict[Edge, list[str]] sessions: dict[str, SnapshotSession] + processes: dict[str, SnapshotProcess] = field(default_factory=dict) diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 873f120b..94cbd726 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -15,12 +15,16 @@ Edge, GraphMetadata, GraphSnapshot, + ProcessHello, + ProcessOwnershipUpdate, + SnapshotProcess, SnapshotSession, ) from .netprotocol import ( Address, Command, ClientInfo, + ProcessInfo, SessionInfo, SubscriberInfo, PublisherInfo, @@ -283,6 +287,20 @@ async def api( # to avoid closing writer return + elif req == Command.PROCESS.value: + process_client_id = uuid1() + self.clients[process_client_id] = ProcessInfo(process_client_id, writer) + writer.write(encode_str(str(process_client_id))) + writer.write(Command.COMPLETE.value) + await writer.drain() + self._client_tasks[process_client_id] = asyncio.create_task( + self._handle_process(process_client_id, reader, writer) + ) + + # NOTE: Created a process control client, must return early + # to avoid closing writer + return + else: # We only want to handle one command at a time async with self._command_lock: @@ -480,6 +498,137 @@ async def _handle_session( self._client_tasks.pop(session_id, None) await close_stream_writer(writer) + def _process_info(self, process_client_id: UUID) -> ProcessInfo | None: + info = self.clients.get(process_client_id) + if isinstance(info, ProcessInfo): + return info + return None + + async def _handle_process( + self, + process_client_id: UUID, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> None: + logger.debug(f"Graph Server: Process control connected: {process_client_id}") + + try: + while True: + req = await reader.read(1) + + if not req: + break + + if req == Command.PROCESS_REGISTER.value: + response = await self._handle_process_register_request( + process_client_id, reader + ) + writer.write(response) + await writer.drain() + + elif req == Command.PROCESS_UPDATE_OWNERSHIP.value: + response = await self._handle_process_update_ownership_request( + process_client_id, reader + ) + writer.write(response) + await writer.drain() + + else: + logger.warning( + f"Process control {process_client_id} rx unknown command: {req}" + ) + + except (ConnectionResetError, BrokenPipeError) as e: + logger.debug( + f"Process control {process_client_id} disconnected from GraphServer: {e}" + ) + + finally: + self.clients.pop(process_client_id, None) + self._client_tasks.pop(process_client_id, None) + await close_stream_writer(writer) + + async def _handle_process_register_request( + self, process_client_id: UUID, reader: asyncio.StreamReader + ) -> bytes: + num_bytes = await read_int(reader) + payload = await reader.readexactly(num_bytes) + hello: ProcessHello | None = None + try: + hello_obj = pickle.loads(payload) + if isinstance(hello_obj, ProcessHello): + hello = hello_obj + else: + raise RuntimeError("process registration payload was not ProcessHello") + except Exception as exc: + logger.warning( + "Process control %s registration parse failed; ignoring payload: %s", + process_client_id, + exc, + ) + + if hello is None: + return Command.COMPLETE.value + + async with self._command_lock: + process_info = self._process_info(process_client_id) + if process_info is None: + return Command.COMPLETE.value + + process_info.process_id = hello.process_id + process_info.pid = hello.pid + process_info.host = hello.host + process_info.units = set(hello.units) + + return Command.COMPLETE.value + + async def _handle_process_update_ownership_request( + self, process_client_id: UUID, reader: asyncio.StreamReader + ) -> bytes: + num_bytes = await read_int(reader) + payload = await reader.readexactly(num_bytes) + update: ProcessOwnershipUpdate | None = None + try: + update_obj = pickle.loads(payload) + if isinstance(update_obj, ProcessOwnershipUpdate): + update = update_obj + else: + raise RuntimeError( + "process ownership payload was not ProcessOwnershipUpdate" + ) + except Exception as exc: + logger.warning( + "Process control %s ownership update parse failed; ignoring payload: %s", + process_client_id, + exc, + ) + + if update is None: + return Command.COMPLETE.value + + async with self._command_lock: + process_info = self._process_info(process_client_id) + if process_info is None: + return Command.COMPLETE.value + + if ( + process_info.process_id is not None + and process_info.process_id != update.process_id + ): + logger.warning( + "Process control %s process_id mismatch: %s != %s", + process_client_id, + process_info.process_id, + update.process_id, + ) + elif process_info.process_id is None: + process_info.process_id = update.process_id + + process_info.units.update(update.added_units) + process_info.units.difference_update(update.removed_units) + + return Command.COMPLETE.value + def _connect_owner( self, from_topic: str, to_topic: str, owner: UUID | str | None ) -> bool: @@ -632,7 +781,32 @@ def _snapshot(self) -> GraphSnapshot: key=lambda item: str(item[0]), ) } - return GraphSnapshot(graph=graph, edge_owners=edge_owners, sessions=sessions) + processes = { + str(client_id): SnapshotProcess( + process_id=( + process.process_id + if process.process_id is not None + else str(client_id) + ), + pid=process.pid, + host=process.host, + units=sorted(process.units), + ) + for client_id, process in sorted( + [ + (client_id, info) + for client_id, info in self.clients.items() + if isinstance(info, ProcessInfo) + ], + key=lambda item: str(item[0]), + ) + } + return GraphSnapshot( + graph=graph, + edge_owners=edge_owners, + sessions=sessions, + processes=processes, + ) async def _notify_downstream_for_topic(self, topic: str) -> None: for sub in self._downstream_subs(topic): diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index dca58583..1f717861 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -176,6 +176,18 @@ class SessionInfo(ClientInfo): metadata: GraphMetadata | None = None +@dataclass +class ProcessInfo(ClientInfo): + """ + Process-scoped control-plane client information. + """ + + process_id: str | None = None + pid: int | None = None + host: str | None = None + units: set[str] = field(default_factory=set) + + def uint64_to_bytes(i: int) -> bytes: """ Convert a 64-bit unsigned integer to bytes. @@ -319,6 +331,11 @@ def _generate_next_value_(name, start, count, last_values) -> bytes: SESSION_REGISTER = enum.auto() SESSION_SNAPSHOT = enum.auto() + # Backend Process Control Commands + PROCESS = enum.auto() + PROCESS_REGISTER = enum.auto() + PROCESS_UPDATE_OWNERSHIP = enum.auto() + def create_socket( host: str | None = None, diff --git a/src/ezmsg/core/processclient.py b/src/ezmsg/core/processclient.py new file mode 100644 index 00000000..8d24bb92 --- /dev/null +++ b/src/ezmsg/core/processclient.py @@ -0,0 +1,116 @@ +import asyncio +import logging +import os +import pickle +import socket + +from uuid import UUID, uuid1 + +from .graphmeta import ProcessHello, ProcessOwnershipUpdate +from .graphserver import GraphService +from .netprotocol import ( + AddressType, + Command, + close_stream_writer, + read_str, + uint64_to_bytes, +) + +logger = logging.getLogger("ezmsg") + + +class ProcessControlClient: + _graph_address: AddressType | None + _process_id: str + _client_id: UUID | None + _reader: asyncio.StreamReader | None + _writer: asyncio.StreamWriter | None + _lock: asyncio.Lock + + def __init__( + self, graph_address: AddressType | None = None, process_id: str | None = None + ) -> None: + self._graph_address = graph_address + self._process_id = process_id if process_id is not None else str(uuid1()) + self._client_id = None + self._reader = None + self._writer = None + self._lock = asyncio.Lock() + + @property + def process_id(self) -> str: + return self._process_id + + @property + def client_id(self) -> UUID | None: + return self._client_id + + async def connect(self) -> None: + if self._writer is not None: + return + + reader, writer = await GraphService(self._graph_address).open_connection() + writer.write(Command.PROCESS.value) + await writer.drain() + + client_id = UUID(await read_str(reader)) + response = await reader.read(1) + if response != Command.COMPLETE.value: + await close_stream_writer(writer) + raise RuntimeError("Failed to create process control connection") + + self._client_id = client_id + self._reader = reader + self._writer = writer + + async def register(self, units: list[str]) -> None: + await self.connect() + payload = ProcessHello( + process_id=self._process_id, + pid=os.getpid(), + host=socket.gethostname(), + units=sorted(set(units)), + ) + await self._payload_command(Command.PROCESS_REGISTER, payload) + + async def update_ownership( + self, + added_units: list[str] | None = None, + removed_units: list[str] | None = None, + ) -> None: + await self.connect() + payload = ProcessOwnershipUpdate( + process_id=self._process_id, + added_units=sorted(set(added_units or [])), + removed_units=sorted(set(removed_units or [])), + ) + await self._payload_command(Command.PROCESS_UPDATE_OWNERSHIP, payload) + + async def close(self) -> None: + writer = self._writer + if writer is None: + return + + self._reader = None + self._writer = None + self._client_id = None + await close_stream_writer(writer) + + async def _payload_command(self, command: Command, payload_obj: object) -> None: + reader = self._reader + writer = self._writer + if reader is None or writer is None: + raise RuntimeError("Process control connection is not active") + + payload = pickle.dumps(payload_obj) + async with self._lock: + writer.write(command.value) + writer.write(uint64_to_bytes(len(payload))) + writer.write(payload) + await writer.drain() + + response = await reader.read(1) + if response != Command.COMPLETE.value: + raise RuntimeError( + f"Unexpected response to process control command: {command.name}" + ) diff --git a/tests/test_process_control.py b/tests/test_process_control.py new file mode 100644 index 00000000..dbf70a2d --- /dev/null +++ b/tests/test_process_control.py @@ -0,0 +1,66 @@ +import asyncio + +import pytest + +from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.processclient import ProcessControlClient +from ezmsg.core.graphserver import GraphService + + +@pytest.mark.asyncio +async def test_process_registration_visible_in_snapshot(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + process = ProcessControlClient(address, process_id="proc-A") + await process.connect() + + try: + await process.register(["SYS/U1", "SYS/U2"]) + await process.update_ownership(added_units=["SYS/U3"], removed_units=["SYS/U1"]) + + snapshot = await observer.snapshot() + assert len(snapshot.processes) == 1 + + process_entry = next(iter(snapshot.processes.values())) + assert process_entry.process_id == "proc-A" + assert process_entry.pid is not None + assert process_entry.host is not None + assert process_entry.units == ["SYS/U2", "SYS/U3"] + + finally: + await process.close() + await asyncio.sleep(0.05) + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_snapshot_entry_drops_on_disconnect(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + process = ProcessControlClient(address, process_id="proc-B") + await process.connect() + + try: + await process.register(["SYS/U1"]) + snapshot = await observer.snapshot() + assert len(snapshot.processes) == 1 + + await process.close() + await asyncio.sleep(0.05) + + snapshot = await observer.snapshot() + assert len(snapshot.processes) == 0 + + finally: + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop() From eb820561459f9a7ce7822dc7181deaee3ca9eac7 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sat, 14 Mar 2026 11:26:30 -0400 Subject: [PATCH 09/52] deprecate @ez.thread decorator --- ISSUE_deprecate_ez_thread.md | 27 +++++++++++++++++++ PR_deprecate_ez_thread.md | 25 +++++++++++++++++ .../reference/API/functiondecorators.rst | 4 +++ src/ezmsg/core/unit.py | 11 ++++++++ tests/test_unit_deprecations.py | 18 +++++++++++++ 5 files changed, 85 insertions(+) create mode 100644 ISSUE_deprecate_ez_thread.md create mode 100644 PR_deprecate_ez_thread.md create mode 100644 tests/test_unit_deprecations.py diff --git a/ISSUE_deprecate_ez_thread.md b/ISSUE_deprecate_ez_thread.md new file mode 100644 index 00000000..3dcea33e --- /dev/null +++ b/ISSUE_deprecate_ez_thread.md @@ -0,0 +1,27 @@ +# Deprecate `@ez.thread` Decorator + +## Summary +Deprecate `@ez.thread` in `ezmsg.core` and guide users to explicit background execution patterns (`loop.run_in_executor(...)` / explicit task lifecycle management). + +## Motivation +- `@ez.thread` has no cooperative termination contract and does not integrate cleanly with unit lifecycle shutdown. +- Equivalent behavior is already available with explicit executor usage. +- Keeping `@ez.thread` adds an extra concurrency model surface area without strong adoption. + +## Proposal +1. Mark `@ez.thread` as deprecated by emitting `DeprecationWarning` when the decorator is applied. +2. Update docs to indicate deprecation and migration guidance. +3. Add tests to ensure the warning is emitted and existing decorator behavior remains intact for compatibility. + +## Non-Goals +- Removing `@ez.thread` in this issue. +- Changing runtime behavior of existing `@ez.thread`-decorated functions beyond warning emission. + +## Acceptance Criteria +- Calling `ez.thread(...)` emits `DeprecationWarning`. +- API docs clearly mark `@ez.thread` as deprecated with migration guidance. +- Test coverage exists for warning behavior and attribute tagging. + +## Follow-Up +- Remove `@ez.thread` in a future major release after deprecation window. +- Add release note/migration note before removal. diff --git a/PR_deprecate_ez_thread.md b/PR_deprecate_ez_thread.md new file mode 100644 index 00000000..8fb3295d --- /dev/null +++ b/PR_deprecate_ez_thread.md @@ -0,0 +1,25 @@ +# Deprecate `@ez.thread` + +## Summary +This changeset deprecates the `@ez.thread` decorator and documents the preferred replacement (`loop.run_in_executor(...)` / explicit task lifecycle management). + +## Changes +- Emit `DeprecationWarning` from `ezmsg.core.unit.thread`. +- Add deprecation note to function decorator docs. +- Add a unit test verifying warning emission and backward-compatible decorator tagging. + +## Files +- `src/ezmsg/core/unit.py` +- `docs/source/reference/API/functiondecorators.rst` +- `tests/test_unit_deprecations.py` +- `ISSUE_deprecate_ez_thread.md` + +## Testing +- `PYTHONPYCACHEPREFIX=/tmp/pycache .venv/bin/pytest tests/test_unit_deprecations.py -q` + +## Backward Compatibility +- Existing `@ez.thread` usage continues to function in this release. +- Users now receive a deprecation warning at decorator application time. + +## Future Work +- Remove `@ez.thread` in a future major release after migration window and release-note notice. diff --git a/docs/source/reference/API/functiondecorators.rst b/docs/source/reference/API/functiondecorators.rst index 6b474e2e..66d6d555 100644 --- a/docs/source/reference/API/functiondecorators.rst +++ b/docs/source/reference/API/functiondecorators.rst @@ -11,6 +11,10 @@ These function decorators can be added to member functions of an ezmsg ``Unit`` .. autodecorator:: ezmsg.core.thread +.. note:: + ``@ez.thread`` is deprecated and will be removed in a future release. + Prefer explicit background work via ``loop.run_in_executor(...)``. + .. autodecorator:: ezmsg.core.task .. autodecorator:: ezmsg.core.process diff --git a/src/ezmsg/core/unit.py b/src/ezmsg/core/unit.py index 8527e71d..aa6723a0 100644 --- a/src/ezmsg/core/unit.py +++ b/src/ezmsg/core/unit.py @@ -274,11 +274,22 @@ def thread(func: Callable): Thread functions run concurrently with the main message processing and can be used for background tasks, monitoring, or other concurrent operations. + .. deprecated:: + ``@thread`` is deprecated and will be removed in a future release. + Prefer explicit background work using ``loop.run_in_executor(...)`` or + explicit task management in ``initialize()``/``shutdown()``. + :param func: The function to run as a background thread :type func: collections.abc.Callable :return: The decorated function :rtype: collections.abc.Callable """ + warnings.warn( + "`@ez.thread` is deprecated and will be removed in a future release. " + "Prefer explicit background work via `loop.run_in_executor(...)`.", + DeprecationWarning, + stacklevel=2, + ) setattr(func, THREAD_ATTR, True) return func diff --git a/tests/test_unit_deprecations.py b/tests/test_unit_deprecations.py new file mode 100644 index 00000000..5e4bac54 --- /dev/null +++ b/tests/test_unit_deprecations.py @@ -0,0 +1,18 @@ +import pytest + +import ezmsg.core as ez +from ezmsg.core.unit import THREAD_ATTR + + +def test_thread_decorator_warns_and_sets_attribute(): + def fn(_): + return None + + with pytest.warns( + DeprecationWarning, + match=r"`@ez\.thread` is deprecated and will be removed in a future release", + ): + decorated = ez.thread(fn) + + assert decorated is fn + assert hasattr(decorated, THREAD_ATTR) From 7d0f36d10f2130fc7f09383ad2c11267e589f466 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sat, 14 Mar 2026 11:34:42 -0400 Subject: [PATCH 10/52] removed markdown noise --- ISSUE_deprecate_ez_thread.md | 27 --------------------------- PR_deprecate_ez_thread.md | 25 ------------------------- 2 files changed, 52 deletions(-) delete mode 100644 ISSUE_deprecate_ez_thread.md delete mode 100644 PR_deprecate_ez_thread.md diff --git a/ISSUE_deprecate_ez_thread.md b/ISSUE_deprecate_ez_thread.md deleted file mode 100644 index 3dcea33e..00000000 --- a/ISSUE_deprecate_ez_thread.md +++ /dev/null @@ -1,27 +0,0 @@ -# Deprecate `@ez.thread` Decorator - -## Summary -Deprecate `@ez.thread` in `ezmsg.core` and guide users to explicit background execution patterns (`loop.run_in_executor(...)` / explicit task lifecycle management). - -## Motivation -- `@ez.thread` has no cooperative termination contract and does not integrate cleanly with unit lifecycle shutdown. -- Equivalent behavior is already available with explicit executor usage. -- Keeping `@ez.thread` adds an extra concurrency model surface area without strong adoption. - -## Proposal -1. Mark `@ez.thread` as deprecated by emitting `DeprecationWarning` when the decorator is applied. -2. Update docs to indicate deprecation and migration guidance. -3. Add tests to ensure the warning is emitted and existing decorator behavior remains intact for compatibility. - -## Non-Goals -- Removing `@ez.thread` in this issue. -- Changing runtime behavior of existing `@ez.thread`-decorated functions beyond warning emission. - -## Acceptance Criteria -- Calling `ez.thread(...)` emits `DeprecationWarning`. -- API docs clearly mark `@ez.thread` as deprecated with migration guidance. -- Test coverage exists for warning behavior and attribute tagging. - -## Follow-Up -- Remove `@ez.thread` in a future major release after deprecation window. -- Add release note/migration note before removal. diff --git a/PR_deprecate_ez_thread.md b/PR_deprecate_ez_thread.md deleted file mode 100644 index 8fb3295d..00000000 --- a/PR_deprecate_ez_thread.md +++ /dev/null @@ -1,25 +0,0 @@ -# Deprecate `@ez.thread` - -## Summary -This changeset deprecates the `@ez.thread` decorator and documents the preferred replacement (`loop.run_in_executor(...)` / explicit task lifecycle management). - -## Changes -- Emit `DeprecationWarning` from `ezmsg.core.unit.thread`. -- Add deprecation note to function decorator docs. -- Add a unit test verifying warning emission and backward-compatible decorator tagging. - -## Files -- `src/ezmsg/core/unit.py` -- `docs/source/reference/API/functiondecorators.rst` -- `tests/test_unit_deprecations.py` -- `ISSUE_deprecate_ez_thread.md` - -## Testing -- `PYTHONPYCACHEPREFIX=/tmp/pycache .venv/bin/pytest tests/test_unit_deprecations.py -q` - -## Backward Compatibility -- Existing `@ez.thread` usage continues to function in this release. -- Users now receive a deprecation warning at decorator application time. - -## Future Work -- Remove `@ez.thread` in a future major release after migration window and release-note notice. From cd30af92315f6951fbef86d5f9f9be34a1a9201c Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sat, 14 Mar 2026 11:37:18 -0400 Subject: [PATCH 11/52] removed test --- tests/test_unit_deprecations.py | 18 ------------------ 1 file changed, 18 deletions(-) delete mode 100644 tests/test_unit_deprecations.py diff --git a/tests/test_unit_deprecations.py b/tests/test_unit_deprecations.py deleted file mode 100644 index 5e4bac54..00000000 --- a/tests/test_unit_deprecations.py +++ /dev/null @@ -1,18 +0,0 @@ -import pytest - -import ezmsg.core as ez -from ezmsg.core.unit import THREAD_ATTR - - -def test_thread_decorator_warns_and_sets_attribute(): - def fn(_): - return None - - with pytest.warns( - DeprecationWarning, - match=r"`@ez\.thread` is deprecated and will be removed in a future release", - ): - decorated = ez.thread(fn) - - assert decorated is fn - assert hasattr(decorated, THREAD_ATTR) From 0ffd196cf36553ef72af3918fe0331a04fd432e0 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sun, 15 Mar 2026 10:20:35 -0400 Subject: [PATCH 12/52] graphserver is ASOT for settings --- examples/settings_tui.py | 389 +++++++++++++++++++++++++++++++ src/ezmsg/core/backendprocess.py | 64 ++++- src/ezmsg/core/graphcontext.py | 73 +++++- src/ezmsg/core/graphmeta.py | 34 ++- src/ezmsg/core/graphserver.py | 271 ++++++++++++++++++++- src/ezmsg/core/netprotocol.py | 4 + src/ezmsg/core/processclient.py | 25 +- tests/test_settings_api.py | 201 ++++++++++++++++ 8 files changed, 1036 insertions(+), 25 deletions(-) create mode 100644 examples/settings_tui.py create mode 100644 tests/test_settings_api.py diff --git a/examples/settings_tui.py b/examples/settings_tui.py new file mode 100644 index 00000000..83acf623 --- /dev/null +++ b/examples/settings_tui.py @@ -0,0 +1,389 @@ +#!/usr/bin/env python3 +""" +Simple settings TUI for ezmsg GraphServer. + +Features: +- Live settings view (push updates via GraphContext.subscribe_settings_events) +- Inspect component metadata and current settings snapshot +- Publish patched settings to components with dynamic INPUT_SETTINGS + +Usage: + .venv/bin/python examples/settings_tui.py --host 127.0.0.1 --port 25978 + +Commands: + help + refresh + inspect + set {"field": 123, "nested": {"gain": 0.5}} + quit + +Notes: +- Updates are sent over normal pub/sub to the component's INPUT_SETTINGS topic. +- For safe updates, the script expects pickled current settings to be available + and unpickleable in this environment. +""" + +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import json +import pickle +from dataclasses import dataclass, is_dataclass, replace +from typing import Any + +from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.graphmeta import ( + ComponentMetadataType, + GraphMetadata, + SettingsChangedEvent, + SettingsSnapshotValue, +) +from ezmsg.core.netprotocol import DEFAULT_HOST, GRAPHSERVER_PORT_DEFAULT +from ezmsg.core.pubclient import Publisher + + +def _truncate(text: str, width: int) -> str: + if width <= 3: + return text[:width] + if len(text) <= width: + return text + return text[: width - 3] + "..." + + +def _format_settings(value: SettingsSnapshotValue | None, width: int = 72) -> str: + if value is None: + return "-" + return _truncate(repr(value.repr_value), width) + + +def _deep_merge_dict(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]: + merged = dict(base) + for key, patch_value in patch.items(): + base_value = merged.get(key) + if isinstance(base_value, dict) and isinstance(patch_value, dict): + merged[key] = _deep_merge_dict(base_value, patch_value) + else: + merged[key] = patch_value + return merged + + +def _patch_dataclass(obj: Any, patch: dict[str, Any]) -> Any: + updates: dict[str, Any] = {} + for key, patch_value in patch.items(): + if not hasattr(obj, key): + raise KeyError(f"Settings object has no field '{key}'") + current = getattr(obj, key) + if is_dataclass(current) and isinstance(patch_value, dict): + updates[key] = _patch_dataclass(current, patch_value) + elif isinstance(current, dict) and isinstance(patch_value, dict): + updates[key] = _deep_merge_dict(current, patch_value) + else: + updates[key] = patch_value + return replace(obj, **updates) + + +def _patch_value(value: Any, patch: dict[str, Any]) -> Any: + if is_dataclass(value): + return _patch_dataclass(value, patch) + if isinstance(value, dict): + return _deep_merge_dict(value, patch) + raise TypeError( + f"Cannot patch settings value of type {type(value).__name__}. " + "Only dataclass/dict settings are supported by this script." + ) + + +def _components_from_metadata( + metadata: GraphMetadata | None, +) -> dict[str, ComponentMetadataType]: + if metadata is None: + return {} + return dict(metadata.components) + + +@dataclass +class ComponentRow: + address: str + name: str + component_type: str + settings_type: str + dynamic_enabled: bool + input_topic: str | None + + +class SettingsTUI: + def __init__(self, ctx: GraphContext): + self.ctx = ctx + self.settings: dict[str, SettingsSnapshotValue] = {} + self.components: dict[str, ComponentRow] = {} + self.row_addresses: list[str] = [] + self.last_seq = 0 + self.publishers: dict[str, Publisher] = {} + self._event_queue: asyncio.Queue[SettingsChangedEvent] = asyncio.Queue() + self._watch_task: asyncio.Task[None] | None = None + + async def initialize(self) -> None: + await self.refresh() + events = await self.ctx.settings_events(after_seq=0) + for event in events: + self.settings[event.component_address] = event.value + self.last_seq = max(self.last_seq, event.seq) + self._watch_task = asyncio.create_task(self._watch_settings_events()) + + async def close(self) -> None: + if self._watch_task is not None: + self._watch_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._watch_task + + async def _watch_settings_events(self) -> None: + async for event in self.ctx.subscribe_settings_events(after_seq=self.last_seq): + await self._event_queue.put(event) + + async def refresh(self) -> None: + snapshot = await self.ctx.snapshot() + settings = await self.ctx.settings_snapshot() + + components: dict[str, ComponentRow] = {} + for session in snapshot.sessions.values(): + for address, comp in _components_from_metadata(session.metadata).items(): + components[address] = ComponentRow( + address=address, + name=comp.name, + component_type=comp.component_type, + settings_type=comp.settings_type, + dynamic_enabled=comp.dynamic_settings.enabled, + input_topic=comp.dynamic_settings.input_topic, + ) + + self.components = components + self.settings = settings + + async def drain_events(self) -> int: + count = 0 + while True: + try: + event = self._event_queue.get_nowait() + except asyncio.QueueEmpty: + break + self.settings[event.component_address] = event.value + self.last_seq = max(self.last_seq, event.seq) + count += 1 + return count + + def render(self, pending_updates: int = 0) -> None: + print("\x1bc", end="") + print("ezmsg settings tui") + print( + "Commands: help, refresh, inspect , " + "set , quit" + ) + if pending_updates > 0: + print(f"Applied {pending_updates} new settings event(s).") + + all_addresses = sorted(set(self.settings) | set(self.components)) + self.row_addresses = all_addresses + + header = ( + f"{'Row':<4} {'Component':<36} {'Dyn':<4} " + f"{'INPUT_SETTINGS Topic':<42} {'Current Settings':<72}" + ) + print() + print(header) + print("-" * len(header)) + + for idx, address in enumerate(all_addresses, start=1): + comp = self.components.get(address) + settings = self.settings.get(address) + + dynamic = "yes" if comp is not None and comp.dynamic_enabled else "no" + input_topic = ( + comp.input_topic if comp is not None and comp.input_topic is not None else "-" + ) + print( + f"{idx:<4} " + f"{_truncate(address, 36):<36} " + f"{dynamic:<4} " + f"{_truncate(input_topic, 42):<42} " + f"{_format_settings(settings):<72}" + ) + + def resolve_target(self, token: str) -> str: + if token.isdigit(): + idx = int(token) - 1 + if idx < 0 or idx >= len(self.row_addresses): + raise ValueError(f"Row index out of range: {token}") + return self.row_addresses[idx] + return token + + async def inspect(self, token: str) -> None: + address = self.resolve_target(token) + comp = self.components.get(address) + settings = self.settings.get(address) + print("\n--- inspect ---") + print(f"address: {address}") + if comp is None: + print("metadata: ") + else: + print(f"name: {comp.name}") + print(f"component_type: {comp.component_type}") + print(f"settings_type: {comp.settings_type}") + print(f"dynamic_settings.enabled: {comp.dynamic_enabled}") + print(f"dynamic_settings.input_topic: {comp.input_topic}") + if settings is None: + print("current_settings: ") + else: + print(f"repr: {settings.repr_value!r}") + print(f"has_pickled_payload: {settings.serialized is not None}") + if settings.serialized is not None: + try: + obj = pickle.loads(settings.serialized) + print(f"unpickled_type: {type(obj).__module__}.{type(obj).__name__}") + except Exception as exc: + print(f"unpickled_type: ") + + async def set_settings(self, token: str, patch: dict[str, Any]) -> str: + address = self.resolve_target(token) + comp = self.components.get(address) + if comp is None: + raise ValueError(f"No component metadata available for '{address}'") + if not comp.dynamic_enabled or comp.input_topic is None: + raise ValueError( + f"Component '{address}' is not dynamic-settings enabled or has no INPUT_SETTINGS topic" + ) + + current = self.settings.get(address) + if current is None: + raise ValueError(f"No current settings snapshot for '{address}'") + if current.serialized is None: + raise ValueError( + f"No serialized settings for '{address}'. Cannot safely build updated object." + ) + + try: + current_obj = pickle.loads(current.serialized) + except Exception as exc: + raise ValueError( + f"Could not unpickle current settings for '{address}': {exc}" + ) from exc + + updated_obj = _patch_value(current_obj, patch) + publisher = self.publishers.get(comp.input_topic) + if publisher is None: + publisher = await self.ctx.publisher(comp.input_topic) + self.publishers[comp.input_topic] = publisher + + await publisher.broadcast(updated_obj) + return f"Published settings update to {comp.input_topic}" + + +def _parse_patch(json_text: str) -> dict[str, Any]: + try: + patch = json.loads(json_text) + except json.JSONDecodeError as exc: + raise ValueError(f"Invalid JSON patch: {exc}") from exc + if not isinstance(patch, dict): + raise ValueError("Patch must be a JSON object") + return patch + + +def _parse_address(host: str, port: int) -> tuple[str, int]: + return (host, port) + + +async def _run_tui(host: str, port: int, auto_start: bool) -> None: + address = _parse_address(host, port) + + async with GraphContext(address, auto_start=auto_start) as ctx: + tui = SettingsTUI(ctx) + await tui.initialize() + try: + while True: + pending = await tui.drain_events() + tui.render(pending_updates=pending) + cmdline = (await asyncio.to_thread(input, "\nsettings-tui> ")).strip() + if not cmdline: + continue + + cmd, *rest = cmdline.split(" ", 1) + if cmd in {"q", "quit", "exit"}: + break + + if cmd in {"h", "help"}: + print( + "\nhelp:\n" + " refresh\n" + " inspect \n" + " set \n" + " quit\n" + ) + await asyncio.to_thread(input, "Press Enter to continue...") + continue + + if cmd == "refresh": + await tui.refresh() + continue + + if cmd == "inspect": + if not rest: + print("Usage: inspect ") + else: + await tui.inspect(rest[0].strip()) + await asyncio.to_thread(input, "Press Enter to continue...") + continue + + if cmd == "set": + if not rest: + print("Usage: set ") + await asyncio.to_thread(input, "Press Enter to continue...") + continue + + target_and_patch = rest[0].strip() + if " " not in target_and_patch: + print("Usage: set ") + await asyncio.to_thread(input, "Press Enter to continue...") + continue + + target, patch_text = target_and_patch.split(" ", 1) + try: + patch = _parse_patch(patch_text.strip()) + result = await tui.set_settings(target.strip(), patch) + print(result) + except Exception as exc: + print(f"set failed: {exc}") + await asyncio.to_thread(input, "Press Enter to continue...") + continue + + print(f"Unknown command: {cmd}") + await asyncio.to_thread(input, "Press Enter to continue...") + finally: + await tui.close() + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="ezmsg settings TUI") + parser.add_argument("--host", default=DEFAULT_HOST, help="GraphServer host") + parser.add_argument( + "--port", + type=int, + default=GRAPHSERVER_PORT_DEFAULT, + help="GraphServer port", + ) + parser.add_argument( + "--auto-start", + action="store_true", + help="Allow GraphContext to auto-start GraphServer if unavailable", + ) + return parser + + +def main() -> None: + parser = _build_parser() + args = parser.parse_args() + asyncio.run(_run_tui(args.host, args.port, args.auto_start)) + + +if __name__ == "__main__": + main() diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 9f001dbd..1f193ac4 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -3,15 +3,16 @@ import logging import inspect import os +import pickle import time import traceback import threading import weakref from abc import abstractmethod -from dataclasses import dataclass +from dataclasses import asdict, dataclass, is_dataclass from collections import defaultdict -from collections.abc import Callable, Coroutine, Generator, Sequence +from collections.abc import Awaitable, Callable, Coroutine, Generator, Sequence from functools import wraps, partial from concurrent.futures import ThreadPoolExecutor from concurrent.futures.thread import _worker @@ -26,6 +27,7 @@ from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR from .graphcontext import GraphContext +from .graphmeta import SettingsSnapshotValue from .processclient import ProcessControlClient from .pubclient import Publisher from .subclient import Subscriber @@ -221,6 +223,22 @@ class DefaultBackendProcess(BackendProcess): pubs: dict[str, Publisher] _shutdown_errors: bool + def _settings_snapshot_value(self, value: object) -> SettingsSnapshotValue: + try: + serialized = pickle.dumps(value) + except Exception: + serialized = None + + if is_dataclass(value): + try: + repr_value = asdict(value) + if isinstance(repr_value, dict): + return SettingsSnapshotValue(serialized=serialized, repr_value=repr_value) + except Exception: + pass + + return SettingsSnapshotValue(serialized=serialized, repr_value=repr(value)) + def process(self, loop: asyncio.AbstractEventLoop) -> None: main_func = None context = GraphContext(self.graph_address) @@ -287,8 +305,30 @@ async def setup_state(): loop, ).result() task_name = f"SUBSCRIBER|{stream.address}" + report_settings_update: ( + Callable[[object], Awaitable[None]] | None + ) = None + if stream.name == "INPUT_SETTINGS": + component_address = unit.address + + async def report_settings_update_cb( + msg: object, + *, + _component_address: str = component_address, + ) -> None: + value = self._settings_snapshot_value(msg) + await process_client.report_settings_update( + component_address=_component_address, + value=value, + ) + + report_settings_update = report_settings_update_cb + coro_callables[task_name] = partial( - handle_subscriber, sub, sub_callables[stream.address] + handle_subscriber, + sub, + sub_callables[stream.address], + on_message=report_settings_update, ) elif isinstance(stream, OutputStream): @@ -519,7 +559,9 @@ async def wrapped_task(msg: Any = None) -> None: async def handle_subscriber( - sub: Subscriber, callables: set[Callable[..., Coroutine[Any, Any, None]]] + sub: Subscriber, + callables: set[Callable[..., Coroutine[Any, Any, None]]], + on_message: Callable[[Any], Awaitable[None]] | None = None, ): """ Handle incoming messages from a subscriber and distribute to callables. @@ -547,6 +589,13 @@ async def handle_subscriber( if sub.leaky: msg = await sub.recv() try: + if on_message is not None: + try: + await on_message(msg) + except Exception as exc: + logger.warning( + f"Failed to report subscriber message metadata: {exc}" + ) for callable in list(callables): try: await callable(msg) @@ -557,6 +606,13 @@ async def handle_subscriber( else: async with sub.recv_zero_copy() as msg: try: + if on_message is not None: + try: + await on_message(msg) + except Exception as exc: + logger.warning( + f"Failed to report subscriber message metadata: {exc}" + ) for callable in list(callables): try: await callable(msg) diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index 3beddfdf..45877af4 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -22,14 +22,19 @@ from .graphserver import GraphServer, GraphService from .pubclient import Publisher from .subclient import Subscriber -from .graphmeta import GraphMetadata, GraphSnapshot +from .graphmeta import ( + GraphMetadata, + GraphSnapshot, + SettingsChangedEvent, + SettingsSnapshotValue, +) logger = logging.getLogger("ezmsg") class _SessionResponseKind(enum.Enum): BYTE = enum.auto() - SNAPSHOT = enum.auto() + PICKLED = enum.auto() @dataclass @@ -248,13 +253,13 @@ async def _session_io_loop(self) -> None: if cmd.response_kind == _SessionResponseKind.BYTE: response = await reader.read(1) - elif cmd.response_kind == _SessionResponseKind.SNAPSHOT: + elif cmd.response_kind == _SessionResponseKind.PICKLED: num_bytes = await read_int(reader) - snapshot_bytes = await reader.readexactly(num_bytes) + payload_bytes = await reader.readexactly(num_bytes) complete = await reader.read(1) if complete != Command.COMPLETE.value: - raise RuntimeError("Unexpected response to session snapshot") - response = pickle.loads(snapshot_bytes) + raise RuntimeError("Unexpected pickled response from session") + response = pickle.loads(payload_bytes) else: raise RuntimeError(f"Unsupported response kind: {cmd.response_kind}") @@ -340,12 +345,66 @@ async def register_metadata(self, metadata: GraphMetadata) -> None: async def snapshot(self) -> GraphSnapshot: snapshot = await self._session_command( Command.SESSION_SNAPSHOT, - response_kind=_SessionResponseKind.SNAPSHOT, + response_kind=_SessionResponseKind.PICKLED, ) if not isinstance(snapshot, GraphSnapshot): raise RuntimeError("Session snapshot payload was not a GraphSnapshot") return snapshot + async def settings_snapshot(self) -> dict[str, SettingsSnapshotValue]: + snapshot = await self._session_command( + Command.SESSION_SETTINGS_SNAPSHOT, + response_kind=_SessionResponseKind.PICKLED, + ) + if not isinstance(snapshot, dict): + raise RuntimeError("Settings snapshot payload was not a dictionary") + if not all(isinstance(value, SettingsSnapshotValue) for value in snapshot.values()): + raise RuntimeError("Settings snapshot payload contained invalid values") + return snapshot + + async def settings_events(self, after_seq: int = 0) -> list[SettingsChangedEvent]: + events = await self._session_command( + Command.SESSION_SETTINGS_EVENTS, + str(after_seq), + response_kind=_SessionResponseKind.PICKLED, + ) + if not isinstance(events, list): + raise RuntimeError("Settings event payload was not a list") + if not all(isinstance(event, SettingsChangedEvent) for event in events): + raise RuntimeError("Settings event payload contained invalid entries") + return events + + async def subscribe_settings_events( + self, + *, + after_seq: int = 0, + ) -> typing.AsyncIterator[SettingsChangedEvent]: + reader, writer = await GraphService(self.graph_address).open_connection() + writer.write(Command.SESSION_SETTINGS_SUBSCRIBE.value) + writer.write(encode_str(str(after_seq))) + await writer.drain() + + _subscriber_id = UUID(await read_str(reader)) + response = await reader.read(1) + if response != Command.COMPLETE.value: + await close_stream_writer(writer) + raise RuntimeError("Failed to subscribe to settings events") + + try: + while True: + payload_size = await read_int(reader) + payload = await reader.readexactly(payload_size) + event = pickle.loads(payload) + if not isinstance(event, SettingsChangedEvent): + raise RuntimeError( + "Settings subscription received invalid event payload" + ) + yield event + except asyncio.IncompleteReadError: + return + finally: + await close_stream_writer(writer) + async def _shutdown_servers(self) -> None: if self._graph_server is not None: self._graph_server.stop() diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index 94db0517..c54c1853 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -1,3 +1,5 @@ +import enum + from dataclasses import dataclass, field from typing import Any, TypeAlias, NamedTuple @@ -124,7 +126,7 @@ class GraphMetadata: @dataclass -class ProcessHello: +class ProcessRegistration: process_id: str pid: int host: str @@ -138,6 +140,36 @@ class ProcessOwnershipUpdate: removed_units: list[str] = field(default_factory=list) +@dataclass +class SettingsSnapshotValue: + serialized: bytes | None + repr_value: dict[str, Any] | str + + +class SettingsEventType(enum.Enum): + INITIAL_SETTINGS = "INITIAL_SETTINGS" + SETTINGS_UPDATED = "SETTINGS_UPDATED" + + +@dataclass +class SettingsChangedEvent: + seq: int + event_type: SettingsEventType + component_address: str + timestamp: float + source_session_id: str | None + source_process_id: str | None + value: SettingsSnapshotValue + + +@dataclass +class ProcessSettingsUpdate: + process_id: str + component_address: str + value: SettingsSnapshotValue + timestamp: float + + class Edge(NamedTuple): from_topic: str to_topic: str diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 94cbd726..d9979805 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -4,6 +4,7 @@ import os import socket import threading +import time from contextlib import suppress from uuid import UUID, uuid1 @@ -15,8 +16,12 @@ Edge, GraphMetadata, GraphSnapshot, - ProcessHello, + ProcessRegistration, ProcessOwnershipUpdate, + ProcessSettingsUpdate, + SettingsChangedEvent, + SettingsEventType, + SettingsSnapshotValue, SnapshotProcess, SnapshotSession, ) @@ -77,6 +82,12 @@ class GraphServer(threading.Thread): _client_tasks: dict[UUID, "asyncio.Task[None]"] _command_lock: asyncio.Lock + _settings_current: dict[str, SettingsSnapshotValue] + _settings_source_session: dict[str, UUID | None] + _settings_events: list[SettingsChangedEvent] + _settings_event_seq: int + _settings_owned_by_session: dict[UUID, set[str]] + _settings_subscribers: dict[UUID, asyncio.Queue[SettingsChangedEvent]] def __init__(self, **kwargs) -> None: super().__init__( @@ -92,6 +103,12 @@ def __init__(self, **kwargs) -> None: self._client_tasks = {} self.shms = {} self._address = None + self._settings_current = {} + self._settings_source_session = {} + self._settings_events = [] + self._settings_event_seq = 0 + self._settings_owned_by_session = {} + self._settings_subscribers = {} @property def address(self) -> Address: @@ -287,6 +304,22 @@ async def api( # to avoid closing writer return + elif req == Command.SESSION_SETTINGS_SUBSCRIBE.value: + subscriber_id = uuid1() + after_seq = int(await read_str(reader)) + writer.write(encode_str(str(subscriber_id))) + writer.write(Command.COMPLETE.value) + await writer.drain() + self._client_tasks[subscriber_id] = asyncio.create_task( + self._handle_settings_subscriber( + subscriber_id, after_seq, reader, writer + ) + ) + + # NOTE: Created a stream client, must return early + # to avoid closing writer + return + elif req == Command.PROCESS.value: process_client_id = uuid1() self.clients[process_client_id] = ProcessInfo(process_client_id, writer) @@ -480,6 +513,17 @@ async def _handle_session( await self._handle_session_snapshot_request(writer) await writer.drain() + elif req == Command.SESSION_SETTINGS_SNAPSHOT.value: + await self._handle_session_settings_snapshot_request(writer) + await writer.drain() + + elif req == Command.SESSION_SETTINGS_EVENTS.value: + after_seq = int(await read_str(reader)) + await self._handle_session_settings_events_request( + writer, after_seq + ) + await writer.drain() + else: logger.warning( f"Session {session_id} rx unknown command from GraphServer: {req}" @@ -533,6 +577,13 @@ async def _handle_process( writer.write(response) await writer.drain() + elif req == Command.PROCESS_SETTINGS_UPDATE.value: + response = await self._handle_process_settings_update_request( + process_client_id, reader + ) + writer.write(response) + await writer.drain() + else: logger.warning( f"Process control {process_client_id} rx unknown command: {req}" @@ -548,18 +599,86 @@ async def _handle_process( self._client_tasks.pop(process_client_id, None) await close_stream_writer(writer) + def _queue_settings_event( + self, queue: asyncio.Queue[SettingsChangedEvent], event: SettingsChangedEvent + ) -> None: + try: + queue.put_nowait(event) + except asyncio.QueueFull: + # Keep most recent samples under backpressure. + with suppress(asyncio.QueueEmpty): + queue.get_nowait() + with suppress(asyncio.QueueFull): + queue.put_nowait(event) + + async def _settings_sender( + self, + subscriber_id: UUID, + queue: asyncio.Queue[SettingsChangedEvent], + writer: asyncio.StreamWriter, + ) -> None: + try: + while True: + event = await queue.get() + payload = pickle.dumps(event) + writer.write(uint64_to_bytes(len(payload))) + writer.write(payload) + await writer.drain() + except (ConnectionResetError, BrokenPipeError): + logger.debug(f"Settings subscriber {subscriber_id} disconnected on send") + except asyncio.CancelledError: + raise + + async def _handle_settings_subscriber( + self, + subscriber_id: UUID, + after_seq: int, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> None: + queue: asyncio.Queue[SettingsChangedEvent] = asyncio.Queue(maxsize=1024) + + async with self._command_lock: + self._settings_subscribers[subscriber_id] = queue + for event in self._settings_events: + if event.seq > after_seq: + self._queue_settings_event(queue, event) + + sender_task = asyncio.create_task( + self._settings_sender(subscriber_id, queue, writer), + name=f"settings-sender-{subscriber_id}", + ) + + try: + while True: + req = await reader.read(1) + if not req: + break + except (ConnectionResetError, BrokenPipeError) as e: + logger.debug(f"Settings subscriber {subscriber_id} disconnected: {e}") + finally: + async with self._command_lock: + self._settings_subscribers.pop(subscriber_id, None) + self._client_tasks.pop(subscriber_id, None) + sender_task.cancel() + with suppress(asyncio.CancelledError): + await sender_task + await close_stream_writer(writer) + async def _handle_process_register_request( self, process_client_id: UUID, reader: asyncio.StreamReader ) -> bytes: num_bytes = await read_int(reader) payload = await reader.readexactly(num_bytes) - hello: ProcessHello | None = None + registration: ProcessRegistration | None = None try: - hello_obj = pickle.loads(payload) - if isinstance(hello_obj, ProcessHello): - hello = hello_obj + payload_obj = pickle.loads(payload) + if isinstance(payload_obj, ProcessRegistration): + registration = payload_obj else: - raise RuntimeError("process registration payload was not ProcessHello") + raise RuntimeError( + "process registration payload was not ProcessRegistration" + ) except Exception as exc: logger.warning( "Process control %s registration parse failed; ignoring payload: %s", @@ -567,7 +686,7 @@ async def _handle_process_register_request( exc, ) - if hello is None: + if registration is None: return Command.COMPLETE.value async with self._command_lock: @@ -575,10 +694,10 @@ async def _handle_process_register_request( if process_info is None: return Command.COMPLETE.value - process_info.process_id = hello.process_id - process_info.pid = hello.pid - process_info.host = hello.host - process_info.units = set(hello.units) + process_info.process_id = registration.process_id + process_info.pid = registration.pid + process_info.host = registration.host + process_info.units = set(registration.units) return Command.COMPLETE.value @@ -629,6 +748,51 @@ async def _handle_process_update_ownership_request( return Command.COMPLETE.value + async def _handle_process_settings_update_request( + self, process_client_id: UUID, reader: asyncio.StreamReader + ) -> bytes: + num_bytes = await read_int(reader) + payload = await reader.readexactly(num_bytes) + update: ProcessSettingsUpdate | None = None + try: + update_obj = pickle.loads(payload) + if isinstance(update_obj, ProcessSettingsUpdate): + update = update_obj + else: + raise RuntimeError( + "process settings payload was not ProcessSettingsUpdate" + ) + except Exception as exc: + logger.warning( + "Process control %s settings update parse failed; ignoring payload: %s", + process_client_id, + exc, + ) + + if update is None: + return Command.COMPLETE.value + + async with self._command_lock: + process_info = self._process_info(process_client_id) + if process_info is None: + return Command.COMPLETE.value + + if process_info.process_id is None: + process_info.process_id = update.process_id + + self._settings_current[update.component_address] = update.value + self._settings_source_session[update.component_address] = None + self._append_settings_event_locked( + event_type=SettingsEventType.SETTINGS_UPDATED, + component_address=update.component_address, + value=update.value, + source_session_id=None, + source_process_id=update.process_id, + timestamp=update.timestamp, + ) + + return Command.COMPLETE.value + def _connect_owner( self, from_topic: str, to_topic: str, owner: UUID | str | None ) -> bool: @@ -655,6 +819,64 @@ def _session_info(self, session_id: UUID) -> SessionInfo | None: return info return None + def _append_settings_event_locked( + self, + event_type: SettingsEventType, + component_address: str, + value: SettingsSnapshotValue, + source_session_id: str | None, + source_process_id: str | None, + timestamp: float | None = None, + ) -> None: + self._settings_event_seq += 1 + event = SettingsChangedEvent( + seq=self._settings_event_seq, + event_type=event_type, + component_address=component_address, + timestamp=timestamp if timestamp is not None else time.time(), + source_session_id=source_session_id, + source_process_id=source_process_id, + value=value, + ) + self._settings_events.append(event) + + for queue in self._settings_subscribers.values(): + self._queue_settings_event(queue, event) + + # Bound memory growth for long-lived servers. + max_events = 10_000 + if len(self._settings_events) > max_events: + del self._settings_events[0 : len(self._settings_events) - max_events] + + def _remove_settings_for_session_locked(self, session_id: UUID) -> None: + component_addresses = self._settings_owned_by_session.pop(session_id, set()) + for component_address in component_addresses: + if self._settings_source_session.get(component_address) == session_id: + self._settings_current.pop(component_address, None) + self._settings_source_session.pop(component_address, None) + + def _apply_session_metadata_settings_locked( + self, session_id: UUID, metadata: GraphMetadata + ) -> None: + session_components: set[str] = set() + for component in metadata.components.values(): + value = SettingsSnapshotValue( + serialized=component.initial_settings[0], + repr_value=component.initial_settings[1], + ) + self._settings_current[component.address] = value + self._settings_source_session[component.address] = session_id + session_components.add(component.address) + self._append_settings_event_locked( + event_type=SettingsEventType.INITIAL_SETTINGS, + component_address=component.address, + value=value, + source_session_id=str(session_id), + source_process_id=None, + ) + + self._settings_owned_by_session[session_id] = session_components + async def _handle_session_edge_request( self, session_id: UUID, @@ -709,7 +931,9 @@ async def _handle_session_register_request( async with self._command_lock: session = self._session_info(session_id) if session is not None and metadata is not None: + self._remove_settings_for_session_locked(session_id) session.metadata = metadata + self._apply_session_metadata_settings_locked(session_id, metadata) return Command.COMPLETE.value @@ -723,6 +947,29 @@ async def _handle_session_snapshot_request( writer.write(snapshot_bytes) writer.write(Command.COMPLETE.value) + async def _handle_session_settings_snapshot_request( + self, writer: asyncio.StreamWriter + ) -> None: + async with self._command_lock: + snapshot = { + component_address: self._settings_current[component_address] + for component_address in sorted(self._settings_current) + } + snapshot_bytes = pickle.dumps(snapshot) + writer.write(uint64_to_bytes(len(snapshot_bytes))) + writer.write(snapshot_bytes) + writer.write(Command.COMPLETE.value) + + async def _handle_session_settings_events_request( + self, writer: asyncio.StreamWriter, after_seq: int + ) -> None: + async with self._command_lock: + events = [event for event in self._settings_events if event.seq > after_seq] + event_bytes = pickle.dumps(events) + writer.write(uint64_to_bytes(len(event_bytes))) + writer.write(event_bytes) + writer.write(Command.COMPLETE.value) + def _clear_session_state(self, session_id: UUID) -> set[str]: notify_topics: set[str] = set() session = self._session_info(session_id) @@ -733,6 +980,7 @@ def _clear_session_state(self, session_id: UUID) -> set[str]: if self._disconnect_owner(from_topic, to_topic, session_id): notify_topics.add(to_topic) + self._remove_settings_for_session_locked(session_id) session.metadata = None return notify_topics @@ -746,6 +994,7 @@ def _drop_session(self, session_id: UUID) -> set[str]: if self._disconnect_owner(from_topic, to_topic, session_id): notify_topics.add(to_topic) + self._remove_settings_for_session_locked(session_id) session.metadata = None self.clients.pop(session_id, None) return notify_topics diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index 1f717861..cba388de 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -330,11 +330,15 @@ def _generate_next_value_(name, start, count, last_values) -> bytes: SESSION_CLEAR = enum.auto() SESSION_REGISTER = enum.auto() SESSION_SNAPSHOT = enum.auto() + SESSION_SETTINGS_SNAPSHOT = enum.auto() + SESSION_SETTINGS_EVENTS = enum.auto() + SESSION_SETTINGS_SUBSCRIBE = enum.auto() # Backend Process Control Commands PROCESS = enum.auto() PROCESS_REGISTER = enum.auto() PROCESS_UPDATE_OWNERSHIP = enum.auto() + PROCESS_SETTINGS_UPDATE = enum.auto() def create_socket( diff --git a/src/ezmsg/core/processclient.py b/src/ezmsg/core/processclient.py index 8d24bb92..b7466078 100644 --- a/src/ezmsg/core/processclient.py +++ b/src/ezmsg/core/processclient.py @@ -3,10 +3,16 @@ import os import pickle import socket +import time from uuid import UUID, uuid1 -from .graphmeta import ProcessHello, ProcessOwnershipUpdate +from .graphmeta import ( + ProcessRegistration, + ProcessOwnershipUpdate, + ProcessSettingsUpdate, + SettingsSnapshotValue, +) from .graphserver import GraphService from .netprotocol import ( AddressType, @@ -65,7 +71,7 @@ async def connect(self) -> None: async def register(self, units: list[str]) -> None: await self.connect() - payload = ProcessHello( + payload = ProcessRegistration( process_id=self._process_id, pid=os.getpid(), host=socket.gethostname(), @@ -86,6 +92,21 @@ async def update_ownership( ) await self._payload_command(Command.PROCESS_UPDATE_OWNERSHIP, payload) + async def report_settings_update( + self, + component_address: str, + value: SettingsSnapshotValue, + timestamp: float | None = None, + ) -> None: + await self.connect() + payload = ProcessSettingsUpdate( + process_id=self._process_id, + component_address=component_address, + value=value, + timestamp=timestamp if timestamp is not None else time.time(), + ) + await self._payload_command(Command.PROCESS_SETTINGS_UPDATE, payload) + async def close(self) -> None: writer = self._writer if writer is None: diff --git a/tests/test_settings_api.py b/tests/test_settings_api.py new file mode 100644 index 00000000..b949a526 --- /dev/null +++ b/tests/test_settings_api.py @@ -0,0 +1,201 @@ +import asyncio +from dataclasses import dataclass + +import pytest + +import ezmsg.core as ez +from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.graphmeta import ( + ComponentMetadata, + DynamicSettingsMetadata, + GraphMetadata, + SettingsEventType, + SettingsSnapshotValue, +) +from ezmsg.core.graphserver import GraphService +from ezmsg.core.processclient import ProcessControlClient + + +def _metadata_with_component(component_address: str) -> GraphMetadata: + return GraphMetadata( + schema_version=1, + root_name="SYS", + components={ + component_address: ComponentMetadata( + address=component_address, + name="UNIT", + component_type="example.Unit", + settings_type="example.Settings", + initial_settings=(None, {"alpha": 1}), + dynamic_settings=DynamicSettingsMetadata( + enabled=True, + input_topic=f"{component_address}/INPUT_SETTINGS", + settings_type="example.Settings", + ), + ) + }, + ) + + +@pytest.mark.asyncio +async def test_settings_snapshot_and_events_from_metadata_registration(): + graph_server = GraphService().create_server() + address = graph_server.address + + owner = GraphContext(address, auto_start=False) + observer = GraphContext(address, auto_start=False) + + await owner.__aenter__() + await observer.__aenter__() + + try: + component_address = "SYS/UNIT_A" + await owner.register_metadata(_metadata_with_component(component_address)) + + settings = await observer.settings_snapshot() + assert component_address in settings + assert settings[component_address].repr_value == {"alpha": 1} + + events = await observer.settings_events(after_seq=0) + matching = [ + event + for event in events + if event.component_address == component_address + and event.event_type == SettingsEventType.INITIAL_SETTINGS + ] + assert matching + assert matching[-1].source_session_id == str(owner._session_id) + finally: + await owner.__aexit__(None, None, None) + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@dataclass +class _SettingsMsg: + gain: int + + +class _SettingsSource(ez.Unit): + OUTPUT = ez.OutputStream(_SettingsMsg) + + @ez.publisher(OUTPUT) + async def emit(self): + yield self.OUTPUT, _SettingsMsg(gain=7) + raise ez.Complete + + +class _SettingsSink(ez.Unit): + INPUT_SETTINGS = ez.InputStream(_SettingsMsg) + + @ez.subscriber(INPUT_SETTINGS) + async def on_settings(self, msg: _SettingsMsg) -> None: + raise ez.NormalTermination + + +class _SettingsSystem(ez.Collection): + SRC = _SettingsSource() + SINK = _SettingsSink() + + def network(self) -> ez.NetworkDefinition: + return ((self.SRC.OUTPUT, self.SINK.INPUT_SETTINGS),) + + +def test_input_settings_hook_reports_to_graphserver(): + graph_server = GraphService().create_server() + address = graph_server.address + try: + ez.run(components={"SYS": _SettingsSystem()}, graph_address=address, force_single_process=True) + + async def observe() -> None: + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + try: + settings = await observer.settings_snapshot() + sink_address = "SYS/SINK" + assert sink_address in settings + assert settings[sink_address].repr_value == {"gain": 7} + + events = await observer.settings_events(after_seq=0) + matching = [ + event + for event in events + if event.component_address == sink_address + and event.event_type == SettingsEventType.SETTINGS_UPDATED + ] + assert matching + finally: + await observer.__aexit__(None, None, None) + + asyncio.run(observe()) + finally: + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_reported_settings_update_visible_in_snapshot_and_events(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + process = ProcessControlClient(address, process_id="proc-settings") + await process.connect() + + try: + await process.register(["SYS/UNIT_B"]) + await process.report_settings_update( + component_address="SYS/UNIT_B", + value=SettingsSnapshotValue(serialized=None, repr_value={"gain": 2}), + ) + + settings = await observer.settings_snapshot() + assert settings["SYS/UNIT_B"].repr_value == {"gain": 2} + + events = await observer.settings_events(after_seq=0) + matching = [ + event + for event in events + if event.component_address == "SYS/UNIT_B" + and event.event_type == SettingsEventType.SETTINGS_UPDATED + ] + assert matching + assert matching[-1].source_process_id == "proc-settings" + + stream = observer.subscribe_settings_events(after_seq=0) + streamed = await asyncio.wait_for(anext(stream), timeout=1.0) + assert streamed.component_address == "SYS/UNIT_B" + await stream.aclose() + finally: + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_session_owned_settings_removed_when_session_drops(): + graph_server = GraphService().create_server() + address = graph_server.address + + owner = GraphContext(address, auto_start=False) + observer = GraphContext(address, auto_start=False) + + await owner.__aenter__() + await observer.__aenter__() + + try: + component_address = "SYS/UNIT_C" + await owner.register_metadata(_metadata_with_component(component_address)) + settings = await observer.settings_snapshot() + assert component_address in settings + + await owner._close_session() + await asyncio.sleep(0.05) + + settings = await observer.settings_snapshot() + assert component_address not in settings + finally: + await owner.__aexit__(None, None, None) + await observer.__aexit__(None, None, None) + graph_server.stop() From c7c2160d74d0af367341dd3d3708673503723e59 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sun, 15 Mar 2026 10:41:32 -0400 Subject: [PATCH 13/52] process control routing --- src/ezmsg/core/graphcontext.py | 21 ++++ src/ezmsg/core/graphmeta.py | 17 +++ src/ezmsg/core/graphserver.py | 217 +++++++++++++++++++++++++++++++- src/ezmsg/core/netprotocol.py | 4 + src/ezmsg/core/processclient.py | 172 +++++++++++++++++++++++-- tests/test_process_routing.py | 106 ++++++++++++++++ 6 files changed, 524 insertions(+), 13 deletions(-) create mode 100644 tests/test_process_routing.py diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index 45877af4..72f26235 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -25,6 +25,7 @@ from .graphmeta import ( GraphMetadata, GraphSnapshot, + ProcessControlResponse, SettingsChangedEvent, SettingsSnapshotValue, ) @@ -405,6 +406,26 @@ async def subscribe_settings_events( finally: await close_stream_writer(writer) + async def process_request( + self, + unit_address: str, + operation: str, + *, + payload: bytes | None = None, + timeout: float = 2.0, + ) -> ProcessControlResponse: + response = await self._session_command( + Command.SESSION_PROCESS_REQUEST, + unit_address, + operation, + str(timeout), + payload=payload if payload is not None else b"", + response_kind=_SessionResponseKind.PICKLED, + ) + if not isinstance(response, ProcessControlResponse): + raise RuntimeError("Session process request payload was not ProcessControlResponse") + return response + async def _shutdown_servers(self) -> None: if self._graph_server is not None: self._graph_server.stop() diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index c54c1853..eff42612 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -170,6 +170,23 @@ class ProcessSettingsUpdate: timestamp: float +@dataclass +class ProcessControlRequest: + request_id: str + unit_address: str + operation: str + payload: bytes | None = None + + +@dataclass +class ProcessControlResponse: + request_id: str + ok: bool + payload: bytes | None = None + error: str | None = None + process_id: str | None = None + + class Edge(NamedTuple): from_topic: str to_topic: str diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index d9979805..331c98a4 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -16,6 +16,8 @@ Edge, GraphMetadata, GraphSnapshot, + ProcessControlRequest, + ProcessControlResponse, ProcessRegistration, ProcessOwnershipUpdate, ProcessSettingsUpdate, @@ -88,6 +90,9 @@ class GraphServer(threading.Thread): _settings_event_seq: int _settings_owned_by_session: dict[UUID, set[str]] _settings_subscribers: dict[UUID, asyncio.Queue[SettingsChangedEvent]] + _pending_process_requests: dict[ + str, tuple[UUID, "asyncio.Future[ProcessControlResponse]"] + ] def __init__(self, **kwargs) -> None: super().__init__( @@ -109,6 +114,7 @@ def __init__(self, **kwargs) -> None: self._settings_event_seq = 0 self._settings_owned_by_session = {} self._settings_subscribers = {} + self._pending_process_requests = {} @property def address(self) -> Address: @@ -524,6 +530,10 @@ async def _handle_session( ) await writer.drain() + elif req == Command.SESSION_PROCESS_REQUEST.value: + await self._handle_session_process_request(writer, reader) + await writer.drain() + else: logger.warning( f"Session {session_id} rx unknown command from GraphServer: {req}" @@ -567,22 +577,30 @@ async def _handle_process( response = await self._handle_process_register_request( process_client_id, reader ) - writer.write(response) - await writer.drain() + await self._write_process_response( + process_client_id, writer, response + ) elif req == Command.PROCESS_UPDATE_OWNERSHIP.value: response = await self._handle_process_update_ownership_request( process_client_id, reader ) - writer.write(response) - await writer.drain() + await self._write_process_response( + process_client_id, writer, response + ) elif req == Command.PROCESS_SETTINGS_UPDATE.value: response = await self._handle_process_settings_update_request( process_client_id, reader ) - writer.write(response) - await writer.drain() + await self._write_process_response( + process_client_id, writer, response + ) + + elif req == Command.PROCESS_ROUTE_RESPONSE.value: + await self._handle_process_route_response_request( + process_client_id, reader + ) else: logger.warning( @@ -595,10 +613,52 @@ async def _handle_process( ) finally: + process_info = self._process_info(process_client_id) + + async with self._command_lock: + request_ids = [ + request_id + for request_id, (owner_process_id, _) in self._pending_process_requests.items() + if owner_process_id == process_client_id + ] + for request_id in request_ids: + pending = self._pending_process_requests.pop(request_id, None) + if pending is None: + continue + _, response_fut = pending + if not response_fut.done(): + response_fut.set_result( + ProcessControlResponse( + request_id=request_id, + ok=False, + error="Owning process disconnected before response", + process_id=( + process_info.process_id if process_info is not None else None + ), + ) + ) + self.clients.pop(process_client_id, None) self._client_tasks.pop(process_client_id, None) await close_stream_writer(writer) + async def _write_process_response( + self, + process_client_id: UUID, + fallback_writer: asyncio.StreamWriter, + response: bytes, + ) -> None: + process_info = self._process_info(process_client_id) + if process_info is None: + fallback_writer.write(response) + await fallback_writer.drain() + return + + async with process_info.write_lock: + writer = process_info.writer + writer.write(response) + await writer.drain() + def _queue_settings_event( self, queue: asyncio.Queue[SettingsChangedEvent], event: SettingsChangedEvent ) -> None: @@ -793,6 +853,151 @@ async def _handle_process_settings_update_request( return Command.COMPLETE.value + async def _handle_process_route_response_request( + self, process_client_id: UUID, reader: asyncio.StreamReader + ) -> None: + num_bytes = await read_int(reader) + payload = await reader.readexactly(num_bytes) + response: ProcessControlResponse | None = None + try: + response_obj = pickle.loads(payload) + if isinstance(response_obj, ProcessControlResponse): + response = response_obj + else: + raise RuntimeError( + "process route response payload was not ProcessControlResponse" + ) + except Exception as exc: + logger.warning( + "Process control %s route response parse failed; ignoring payload: %s", + process_client_id, + exc, + ) + + if response is None: + return + + async with self._command_lock: + pending = self._pending_process_requests.pop(response.request_id, None) + + if pending is None: + logger.warning( + "Process control %s returned unknown request_id: %s", + process_client_id, + response.request_id, + ) + return + + owner_process_id, response_fut = pending + if owner_process_id != process_client_id: + if not response_fut.done(): + response_fut.set_result( + ProcessControlResponse( + request_id=response.request_id, + ok=False, + error=( + "Received response from unexpected process " + f"{process_client_id}; expected {owner_process_id}" + ), + process_id=response.process_id, + ) + ) + return + + if not response_fut.done(): + response_fut.set_result(response) + + def _process_for_unit(self, unit_address: str) -> ProcessInfo | None: + for info in self.clients.values(): + if isinstance(info, ProcessInfo) and unit_address in info.units: + return info + return None + + async def _route_process_request( + self, + unit_address: str, + operation: str, + payload: bytes | None, + timeout: float, + ) -> ProcessControlResponse: + request_id = str(uuid1()) + response_fut: asyncio.Future[ProcessControlResponse] = ( + asyncio.get_running_loop().create_future() + ) + request = ProcessControlRequest( + request_id=request_id, + unit_address=unit_address, + operation=operation, + payload=payload, + ) + + async with self._command_lock: + process_info = self._process_for_unit(unit_address) + if process_info is None: + return ProcessControlResponse( + request_id=request_id, + ok=False, + error=f"No process owns unit '{unit_address}'", + ) + + self._pending_process_requests[request_id] = (process_info.id, response_fut) + + try: + async with process_info.write_lock: + process_writer = process_info.writer + request_bytes = pickle.dumps(request) + process_writer.write(Command.PROCESS_ROUTE_REQUEST.value) + process_writer.write(uint64_to_bytes(len(request_bytes))) + process_writer.write(request_bytes) + await process_writer.drain() + except Exception as exc: + self._pending_process_requests.pop(request_id, None) + return ProcessControlResponse( + request_id=request_id, + ok=False, + error=f"Failed to route request to owning process: {exc}", + process_id=process_info.process_id, + ) + + try: + return await asyncio.wait_for(response_fut, timeout=timeout) + except asyncio.TimeoutError: + async with self._command_lock: + self._pending_process_requests.pop(request_id, None) + return ProcessControlResponse( + request_id=request_id, + ok=False, + error=( + f"Timed out waiting for process response " + f"(unit={unit_address}, operation={operation}, timeout={timeout}s)" + ), + process_id=process_info.process_id, + ) + + async def _handle_session_process_request( + self, + writer: asyncio.StreamWriter, + reader: asyncio.StreamReader, + ) -> None: + unit_address = await read_str(reader) + operation = await read_str(reader) + timeout = float(await read_str(reader)) + payload_size = await read_int(reader) + payload: bytes | None = None + if payload_size > 0: + payload = await reader.readexactly(payload_size) + + response = await self._route_process_request( + unit_address=unit_address, + operation=operation, + payload=payload, + timeout=timeout, + ) + response_bytes = pickle.dumps(response) + writer.write(uint64_to_bytes(len(response_bytes))) + writer.write(response_bytes) + writer.write(Command.COMPLETE.value) + def _connect_owner( self, from_topic: str, to_topic: str, owner: UUID | str | None ) -> bool: diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index cba388de..b1332140 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -186,6 +186,7 @@ class ProcessInfo(ClientInfo): pid: int | None = None host: str | None = None units: set[str] = field(default_factory=set) + write_lock: asyncio.Lock = field(default_factory=asyncio.Lock, init=False) def uint64_to_bytes(i: int) -> bytes: @@ -333,12 +334,15 @@ def _generate_next_value_(name, start, count, last_values) -> bytes: SESSION_SETTINGS_SNAPSHOT = enum.auto() SESSION_SETTINGS_EVENTS = enum.auto() SESSION_SETTINGS_SUBSCRIBE = enum.auto() + SESSION_PROCESS_REQUEST = enum.auto() # Backend Process Control Commands PROCESS = enum.auto() PROCESS_REGISTER = enum.auto() PROCESS_UPDATE_OWNERSHIP = enum.auto() PROCESS_SETTINGS_UPDATE = enum.auto() + PROCESS_ROUTE_REQUEST = enum.auto() + PROCESS_ROUTE_RESPONSE = enum.auto() def create_socket( diff --git a/src/ezmsg/core/processclient.py b/src/ezmsg/core/processclient.py index b7466078..70cd534f 100644 --- a/src/ezmsg/core/processclient.py +++ b/src/ezmsg/core/processclient.py @@ -6,8 +6,12 @@ import time from uuid import UUID, uuid1 +from contextlib import suppress +from collections.abc import Awaitable, Callable from .graphmeta import ( + ProcessControlRequest, + ProcessControlResponse, ProcessRegistration, ProcessOwnershipUpdate, ProcessSettingsUpdate, @@ -18,6 +22,7 @@ AddressType, Command, close_stream_writer, + read_int, read_str, uint64_to_bytes, ) @@ -31,7 +36,12 @@ class ProcessControlClient: _client_id: UUID | None _reader: asyncio.StreamReader | None _writer: asyncio.StreamWriter | None - _lock: asyncio.Lock + _write_lock: asyncio.Lock + _ack_queue: asyncio.Queue[bytes] + _io_task: asyncio.Task[None] | None + _request_handler: Callable[ + [ProcessControlRequest], ProcessControlResponse | Awaitable[ProcessControlResponse] + ] | None def __init__( self, graph_address: AddressType | None = None, process_id: str | None = None @@ -41,7 +51,10 @@ def __init__( self._client_id = None self._reader = None self._writer = None - self._lock = asyncio.Lock() + self._write_lock = asyncio.Lock() + self._ack_queue = asyncio.Queue() + self._io_task = None + self._request_handler = None @property def process_id(self) -> str: @@ -68,6 +81,19 @@ async def connect(self) -> None: self._client_id = client_id self._reader = reader self._writer = writer + self._io_task = asyncio.create_task( + self._io_loop(), + name=f"process-control-{client_id}", + ) + + def set_request_handler( + self, + handler: Callable[ + [ProcessControlRequest], ProcessControlResponse | Awaitable[ProcessControlResponse] + ] + | None, + ) -> None: + self._request_handler = handler async def register(self, units: list[str]) -> None: await self.connect() @@ -112,26 +138,158 @@ async def close(self) -> None: if writer is None: return + io_task = self._io_task + self._io_task = None + if io_task is not None: + io_task.cancel() + with suppress(asyncio.CancelledError): + await io_task + self._reader = None self._writer = None self._client_id = None await close_stream_writer(writer) async def _payload_command(self, command: Command, payload_obj: object) -> None: + await self._write_payload(command, payload_obj, expect_complete=True) + + async def _write_payload( + self, + command: Command, + payload_obj: object, + *, + expect_complete: bool, + ) -> None: reader = self._reader writer = self._writer if reader is None or writer is None: raise RuntimeError("Process control connection is not active") payload = pickle.dumps(payload_obj) - async with self._lock: + async with self._write_lock: writer.write(command.value) writer.write(uint64_to_bytes(len(payload))) writer.write(payload) await writer.drain() - response = await reader.read(1) - if response != Command.COMPLETE.value: - raise RuntimeError( - f"Unexpected response to process control command: {command.name}" + if not expect_complete: + return + + try: + response = await asyncio.wait_for(self._ack_queue.get(), timeout=5.0) + except asyncio.TimeoutError as exc: + raise RuntimeError( + f"Timed out waiting for response to process control command: {command.name}" + ) from exc + + if response != Command.COMPLETE.value: + raise RuntimeError( + f"Unexpected response to process control command: {command.name}" + ) + + async def _io_loop(self) -> None: + reader = self._reader + writer = self._writer + if reader is None or writer is None: + return + + try: + while True: + req = await reader.read(1) + if not req: + break + + if req == Command.COMPLETE.value: + self._ack_queue.put_nowait(req) + continue + + if req != Command.PROCESS_ROUTE_REQUEST.value: + logger.warning( + "Process control %s received unknown command: %s", + self._client_id, + req, + ) + continue + + payload_size = await read_int(reader) + payload = await reader.readexactly(payload_size) + request: ProcessControlRequest | None = None + try: + request_obj = pickle.loads(payload) + if isinstance(request_obj, ProcessControlRequest): + request = request_obj + else: + raise RuntimeError( + "process route request payload was not ProcessControlRequest" + ) + except Exception as exc: + logger.warning( + "Process control %s failed to parse route request: %s", + self._client_id, + exc, + ) + + if request is None: + continue + + response = await self._handle_route_request(request) + await self._write_payload( + Command.PROCESS_ROUTE_RESPONSE, + response, + expect_complete=False, ) + + except asyncio.CancelledError: + raise + except (ConnectionResetError, BrokenPipeError) as exc: + logger.debug(f"Process control {self._client_id} disconnected: {exc}") + + async def _handle_route_request( + self, request: ProcessControlRequest + ) -> ProcessControlResponse: + if self._request_handler is None: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error="process request handler is not configured", + process_id=self._process_id, + ) + + try: + result = self._request_handler(request) + if asyncio.iscoroutine(result): + result = await result + except Exception as exc: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=f"process request handler failed: {exc}", + process_id=self._process_id, + ) + + if not isinstance(result, ProcessControlResponse): + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=( + "process request handler returned invalid response type: " + f"{type(result).__name__}" + ), + process_id=self._process_id, + ) + + if result.request_id != request.request_id: + result = ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=( + "process request handler returned mismatched request_id: " + f"{result.request_id}" + ), + process_id=self._process_id, + ) + + if result.process_id is None: + result.process_id = self._process_id + + return result diff --git a/tests/test_process_routing.py b/tests/test_process_routing.py new file mode 100644 index 00000000..93fcb734 --- /dev/null +++ b/tests/test_process_routing.py @@ -0,0 +1,106 @@ +import asyncio + +import pytest + +from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.graphmeta import ProcessControlRequest, ProcessControlResponse +from ezmsg.core.graphserver import GraphService +from ezmsg.core.processclient import ProcessControlClient + + +@pytest.mark.asyncio +async def test_process_routing_roundtrip(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + process = ProcessControlClient(address, process_id="proc-route") + await process.connect() + await process.register(["SYS/U1"]) + + async def handler(request: ProcessControlRequest) -> ProcessControlResponse: + assert request.unit_address == "SYS/U1" + assert request.operation == "ECHO" + return ProcessControlResponse( + request_id=request.request_id, + ok=True, + payload=request.payload, + ) + + process.set_request_handler(handler) + + try: + response = await observer.process_request( + "SYS/U1", + "ECHO", + payload=b"hello", + timeout=1.0, + ) + assert response.ok + assert response.payload == b"hello" + assert response.process_id == "proc-route" + finally: + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_routing_missing_owner_returns_error(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + try: + response = await observer.process_request( + "SYS/UNKNOWN", + "PING", + payload=b"", + timeout=0.25, + ) + assert not response.ok + assert response.error is not None + assert "No process owns unit" in response.error + finally: + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_routing_timeout_returns_error(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + process = ProcessControlClient(address, process_id="proc-timeout") + await process.connect() + await process.register(["SYS/U2"]) + + block = asyncio.Event() + + async def blocking_handler(_request: ProcessControlRequest) -> ProcessControlResponse: + await block.wait() + return ProcessControlResponse(request_id="", ok=False) + + process.set_request_handler(blocking_handler) + + try: + response = await observer.process_request( + "SYS/U2", + "SLOW", + timeout=0.05, + ) + assert not response.ok + assert response.error is not None + assert "Timed out waiting for process response" in response.error + assert response.process_id == "proc-timeout" + finally: + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop() From ab746c59618aaadd5fbd07488bbb2a031f2a7e21 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 16 Mar 2026 11:17:25 -0400 Subject: [PATCH 14/52] first process-control commands implemented --- src/ezmsg/core/graphcontext.py | 66 ++++++++++++++++++++++++++++++++- src/ezmsg/core/graphmeta.py | 36 +++++++++++++++++- src/ezmsg/core/graphserver.py | 6 +++ src/ezmsg/core/processclient.py | 62 +++++++++++++++++++++++++++++-- tests/test_process_routing.py | 40 +++++++++++++++++++- 5 files changed, 202 insertions(+), 8 deletions(-) diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index 72f26235..d0e0f326 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -23,8 +23,11 @@ from .pubclient import Publisher from .subclient import Subscriber from .graphmeta import ( + ProcessControlOperation, GraphMetadata, GraphSnapshot, + ProcessPing, + ProcessStats, ProcessControlResponse, SettingsChangedEvent, SettingsSnapshotValue, @@ -409,15 +412,25 @@ async def subscribe_settings_events( async def process_request( self, unit_address: str, - operation: str, + operation: ProcessControlOperation | str, *, payload: bytes | None = None, + payload_obj: object | None = None, timeout: float = 2.0, ) -> ProcessControlResponse: + if payload is not None and payload_obj is not None: + raise ValueError("Specify only one of payload or payload_obj") + + if payload_obj is not None: + payload = pickle.dumps(payload_obj) + + operation_name = ( + operation.value if isinstance(operation, ProcessControlOperation) else operation + ) response = await self._session_command( Command.SESSION_PROCESS_REQUEST, unit_address, - operation, + operation_name, str(timeout), payload=payload if payload is not None else b"", response_kind=_SessionResponseKind.PICKLED, @@ -426,6 +439,55 @@ async def process_request( raise RuntimeError("Session process request payload was not ProcessControlResponse") return response + async def process_ping( + self, + unit_address: str, + *, + timeout: float = 2.0, + ) -> ProcessPing: + response = await self.process_request( + unit_address, + ProcessControlOperation.PING, + timeout=timeout, + ) + return typing.cast(ProcessPing, self.decode_process_payload(response, ProcessPing)) + + async def process_stats( + self, + unit_address: str, + *, + timeout: float = 2.0, + ) -> ProcessStats: + response = await self.process_request( + unit_address, + ProcessControlOperation.GET_PROCESS_STATS, + timeout=timeout, + ) + return typing.cast( + ProcessStats, self.decode_process_payload(response, ProcessStats) + ) + + def decode_process_payload( + self, + response: ProcessControlResponse, + expected_type: type[object] = object, + ) -> object: + if not response.ok: + raise RuntimeError( + f"Process request failed ({response.error_code}): {response.error}" + ) + if response.payload is None: + raise RuntimeError("Process response did not include a payload") + decoded = pickle.loads(response.payload) + if expected_type is object: + return decoded + if not isinstance(decoded, expected_type): + raise RuntimeError( + "Unexpected process payload type: " + f"{type(decoded).__name__} (expected {expected_type.__name__})" + ) + return decoded + async def _shutdown_servers(self) -> None: if self._graph_server is not None: self._graph_server.stop() diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index eff42612..254a55ac 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -174,19 +174,53 @@ class ProcessSettingsUpdate: class ProcessControlRequest: request_id: str unit_address: str - operation: str + operation: "ProcessControlOperation | str" payload: bytes | None = None +class ProcessControlOperation(enum.Enum): + PING = "PING" + GET_PROCESS_STATS = "GET_PROCESS_STATS" + + +class ProcessControlErrorCode(enum.Enum): + UNROUTABLE_UNIT = "UNROUTABLE_UNIT" + ROUTE_WRITE_FAILED = "ROUTE_WRITE_FAILED" + TIMEOUT = "TIMEOUT" + PROCESS_DISCONNECTED = "PROCESS_DISCONNECTED" + UNSUPPORTED_OPERATION = "UNSUPPORTED_OPERATION" + HANDLER_NOT_CONFIGURED = "HANDLER_NOT_CONFIGURED" + HANDLER_ERROR = "HANDLER_ERROR" + INVALID_RESPONSE = "INVALID_RESPONSE" + + @dataclass class ProcessControlResponse: request_id: str ok: bool payload: bytes | None = None error: str | None = None + error_code: ProcessControlErrorCode | None = None process_id: str | None = None +@dataclass +class ProcessPing: + process_id: str + pid: int + host: str + timestamp: float + + +@dataclass +class ProcessStats: + process_id: str + pid: int + host: str + owned_units: list[str] + timestamp: float + + class Edge(NamedTuple): from_topic: str to_topic: str diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 331c98a4..686ccb69 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -14,6 +14,7 @@ from .graph_util import get_compactified_graph, graph_string, prune_graph_connections from .graphmeta import ( Edge, + ProcessControlErrorCode, GraphMetadata, GraphSnapshot, ProcessControlRequest, @@ -632,6 +633,7 @@ async def _handle_process( request_id=request_id, ok=False, error="Owning process disconnected before response", + error_code=ProcessControlErrorCode.PROCESS_DISCONNECTED, process_id=( process_info.process_id if process_info is not None else None ), @@ -899,6 +901,7 @@ async def _handle_process_route_response_request( "Received response from unexpected process " f"{process_client_id}; expected {owner_process_id}" ), + error_code=ProcessControlErrorCode.INVALID_RESPONSE, process_id=response.process_id, ) ) @@ -938,6 +941,7 @@ async def _route_process_request( request_id=request_id, ok=False, error=f"No process owns unit '{unit_address}'", + error_code=ProcessControlErrorCode.UNROUTABLE_UNIT, ) self._pending_process_requests[request_id] = (process_info.id, response_fut) @@ -956,6 +960,7 @@ async def _route_process_request( request_id=request_id, ok=False, error=f"Failed to route request to owning process: {exc}", + error_code=ProcessControlErrorCode.ROUTE_WRITE_FAILED, process_id=process_info.process_id, ) @@ -971,6 +976,7 @@ async def _route_process_request( f"Timed out waiting for process response " f"(unit={unit_address}, operation={operation}, timeout={timeout}s)" ), + error_code=ProcessControlErrorCode.TIMEOUT, process_id=process_info.process_id, ) diff --git a/src/ezmsg/core/processclient.py b/src/ezmsg/core/processclient.py index 70cd534f..1950f023 100644 --- a/src/ezmsg/core/processclient.py +++ b/src/ezmsg/core/processclient.py @@ -10,9 +10,13 @@ from collections.abc import Awaitable, Callable from .graphmeta import ( + ProcessControlErrorCode, + ProcessControlOperation, ProcessControlRequest, ProcessControlResponse, + ProcessPing, ProcessRegistration, + ProcessStats, ProcessOwnershipUpdate, ProcessSettingsUpdate, SettingsSnapshotValue, @@ -42,6 +46,7 @@ class ProcessControlClient: _request_handler: Callable[ [ProcessControlRequest], ProcessControlResponse | Awaitable[ProcessControlResponse] ] | None + _owned_units: set[str] def __init__( self, graph_address: AddressType | None = None, process_id: str | None = None @@ -55,6 +60,7 @@ def __init__( self._ack_queue = asyncio.Queue() self._io_task = None self._request_handler = None + self._owned_units = set() @property def process_id(self) -> str: @@ -97,13 +103,15 @@ def set_request_handler( async def register(self, units: list[str]) -> None: await self.connect() + normalized_units = sorted(set(units)) payload = ProcessRegistration( process_id=self._process_id, pid=os.getpid(), host=socket.gethostname(), - units=sorted(set(units)), + units=normalized_units, ) await self._payload_command(Command.PROCESS_REGISTER, payload) + self._owned_units = set(normalized_units) async def update_ownership( self, @@ -111,12 +119,16 @@ async def update_ownership( removed_units: list[str] | None = None, ) -> None: await self.connect() + added = sorted(set(added_units or [])) + removed = sorted(set(removed_units or [])) payload = ProcessOwnershipUpdate( process_id=self._process_id, - added_units=sorted(set(added_units or [])), - removed_units=sorted(set(removed_units or [])), + added_units=added, + removed_units=removed, ) await self._payload_command(Command.PROCESS_UPDATE_OWNERSHIP, payload) + self._owned_units.update(added) + self._owned_units.difference_update(removed) async def report_settings_update( self, @@ -247,11 +259,50 @@ async def _io_loop(self) -> None: async def _handle_route_request( self, request: ProcessControlRequest ) -> ProcessControlResponse: + operation: ProcessControlOperation | None = None + if isinstance(request.operation, ProcessControlOperation): + operation = request.operation + elif isinstance(request.operation, str): + with suppress(ValueError): + operation = ProcessControlOperation(request.operation) + + if operation == ProcessControlOperation.PING: + return ProcessControlResponse( + request_id=request.request_id, + ok=True, + payload=pickle.dumps( + ProcessPing( + process_id=self._process_id, + pid=os.getpid(), + host=socket.gethostname(), + timestamp=time.time(), + ) + ), + process_id=self._process_id, + ) + + if operation == ProcessControlOperation.GET_PROCESS_STATS: + return ProcessControlResponse( + request_id=request.request_id, + ok=True, + payload=pickle.dumps( + ProcessStats( + process_id=self._process_id, + pid=os.getpid(), + host=socket.gethostname(), + owned_units=sorted(self._owned_units), + timestamp=time.time(), + ) + ), + process_id=self._process_id, + ) + if self._request_handler is None: return ProcessControlResponse( request_id=request.request_id, ok=False, - error="process request handler is not configured", + error=f"Unsupported process control operation: {request.operation}", + error_code=ProcessControlErrorCode.HANDLER_NOT_CONFIGURED, process_id=self._process_id, ) @@ -264,6 +315,7 @@ async def _handle_route_request( request_id=request.request_id, ok=False, error=f"process request handler failed: {exc}", + error_code=ProcessControlErrorCode.HANDLER_ERROR, process_id=self._process_id, ) @@ -275,6 +327,7 @@ async def _handle_route_request( "process request handler returned invalid response type: " f"{type(result).__name__}" ), + error_code=ProcessControlErrorCode.INVALID_RESPONSE, process_id=self._process_id, ) @@ -286,6 +339,7 @@ async def _handle_route_request( "process request handler returned mismatched request_id: " f"{result.request_id}" ), + error_code=ProcessControlErrorCode.INVALID_RESPONSE, process_id=self._process_id, ) diff --git a/tests/test_process_routing.py b/tests/test_process_routing.py index 93fcb734..f093bcb5 100644 --- a/tests/test_process_routing.py +++ b/tests/test_process_routing.py @@ -3,7 +3,11 @@ import pytest from ezmsg.core.graphcontext import GraphContext -from ezmsg.core.graphmeta import ProcessControlRequest, ProcessControlResponse +from ezmsg.core.graphmeta import ( + ProcessControlErrorCode, + ProcessControlRequest, + ProcessControlResponse, +) from ezmsg.core.graphserver import GraphService from ezmsg.core.processclient import ProcessControlClient @@ -47,6 +51,38 @@ async def handler(request: ProcessControlRequest) -> ProcessControlResponse: graph_server.stop() +@pytest.mark.asyncio +async def test_process_routing_builtin_ping_and_stats(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + process = ProcessControlClient(address, process_id="proc-builtins") + await process.connect() + await process.register(["SYS/U1", "SYS/U2"]) + await process.update_ownership(removed_units=["SYS/U2"], added_units=["SYS/U3"]) + + try: + ping = await observer.process_ping("SYS/U1", timeout=1.0) + assert ping.process_id == "proc-builtins" + assert ping.pid > 0 + assert ping.host + assert ping.timestamp > 0.0 + + stats = await observer.process_stats("SYS/U1", timeout=1.0) + assert stats.process_id == "proc-builtins" + assert stats.pid > 0 + assert stats.host + assert stats.owned_units == ["SYS/U1", "SYS/U3"] + assert stats.timestamp > 0.0 + finally: + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop() + + @pytest.mark.asyncio async def test_process_routing_missing_owner_returns_error(): graph_server = GraphService().create_server() @@ -65,6 +101,7 @@ async def test_process_routing_missing_owner_returns_error(): assert not response.ok assert response.error is not None assert "No process owns unit" in response.error + assert response.error_code == ProcessControlErrorCode.UNROUTABLE_UNIT finally: await observer.__aexit__(None, None, None) graph_server.stop() @@ -99,6 +136,7 @@ async def blocking_handler(_request: ProcessControlRequest) -> ProcessControlRes assert not response.ok assert response.error is not None assert "Timed out waiting for process response" in response.error + assert response.error_code == ProcessControlErrorCode.TIMEOUT assert response.process_id == "proc-timeout" finally: await process.close() From ce8e192b7417c1f70e2c7072f5f756be6a8c5dee Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 16 Mar 2026 14:13:35 -0400 Subject: [PATCH 15/52] profiling backend --- src/ezmsg/core/backendprocess.py | 26 +- src/ezmsg/core/graphcontext.py | 106 ++++++++- src/ezmsg/core/graphmeta.py | 80 +++++++ src/ezmsg/core/messagechannel.py | 43 ++++ src/ezmsg/core/processclient.py | 67 ++++++ src/ezmsg/core/profiling.py | 395 +++++++++++++++++++++++++++++++ src/ezmsg/core/pubclient.py | 19 ++ src/ezmsg/core/subclient.py | 30 ++- tests/test_profiling_api.py | 137 +++++++++++ 9 files changed, 882 insertions(+), 21 deletions(-) create mode 100644 src/ezmsg/core/profiling.py create mode 100644 tests/test_profiling_api.py diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 1f193ac4..1c3ed487 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -4,7 +4,6 @@ import inspect import os import pickle -import time import traceback import threading import weakref @@ -28,6 +27,7 @@ from .graphcontext import GraphContext from .graphmeta import SettingsSnapshotValue +from .profiling import PROFILES, PROFILE_TIME from .processclient import ProcessControlClient from .pubclient import Publisher from .subclient import Subscriber @@ -243,6 +243,7 @@ def process(self, loop: asyncio.AbstractEventLoop) -> None: main_func = None context = GraphContext(self.graph_address) process_client = ProcessControlClient(self.graph_address) + PROFILES.set_process_id(process_client.process_id) process_register_future: concurrent.futures.Future[None] | None = None coro_callables: dict[str, Callable[[], Coroutine[Any, Any, None]]] = dict() self._shutdown_errors = False @@ -502,11 +503,12 @@ async def publish(stream: Stream, obj: Any) -> None: await asyncio.sleep(0) async def perf_publish(stream: Stream, obj: Any) -> None: - start = time.perf_counter() + start = PROFILE_TIME() await publish(stream, obj) - stop = time.perf_counter() + stop = PROFILE_TIME() logger.info( - f"{task_address} send duration = " + f"{(stop - start) * 1e3:0.4f}ms" + f"{task_address} send duration = " + f"{((stop - start) / 1_000_000.0):0.4f}ms" ) pub_fn = perf_publish if hasattr(task, TIMEIT_ATTR) else publish @@ -598,7 +600,13 @@ async def handle_subscriber( ) for callable in list(callables): try: - await callable(msg) + span_start_ns = sub.begin_profile() + try: + await callable(msg) + finally: + sub.end_profile( + span_start_ns, getattr(callable, "__name__", None) + ) except (Complete, NormalTermination): callables.remove(callable) finally: @@ -615,7 +623,13 @@ async def handle_subscriber( ) for callable in list(callables): try: - await callable(msg) + span_start_ns = sub.begin_profile() + try: + await callable(msg) + finally: + sub.end_profile( + span_start_ns, getattr(callable, "__name__", None) + ) except (Complete, NormalTermination): callables.remove(callable) finally: diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index d0e0f326..cdf040c0 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -27,8 +27,11 @@ GraphMetadata, GraphSnapshot, ProcessPing, + ProcessProfilingSnapshot, + ProcessProfilingTraceBatch, ProcessStats, ProcessControlResponse, + ProfilingTraceControl, SettingsChangedEvent, SettingsSnapshotValue, ) @@ -52,25 +55,35 @@ class _SessionCommand: class GraphContext: """ - GraphContext maintains a list of created publishers, subscribers, and connections in the graph. - - The GraphContext provides a managed environment for creating and tracking publishers, - subscribers, and graph connections. When the context is no longer needed, it can - revert changes in the graph which disconnects publishers and removes modifications - that this context made. - - It also maintains a context manager that ensures the GraphServer is running. - - :param graph_service: Optional graph service instance to use - :type graph_service: GraphService | None + Session-scoped client for graph mutation, metadata, settings, and process control. + + `GraphContext` opens a session connection to `GraphServer` and acts as a control + plane for both low-level graph operations and high-level API introspection. + + Core capabilities: + - Create/track `Publisher` and `Subscriber` clients. + - Connect/disconnect topic edges owned by this session. + - Register high-level `GraphMetadata`. + - Read graph snapshots (topology, edge ownership, sessions, process ownership). + - Query settings snapshots/events and subscribe to push-based settings updates. + - Route process-control requests (ping/stats/profiling and custom operations). + - Revert all session-owned mutations on context exit (`SESSION_CLEAR`). + + Session semantics: + - Mutations and metadata are tied to the session lifecycle. + - If the session disconnects, session-owned graph state is dropped by server cleanup. + - Low-level pub/sub API usage remains supported independently of metadata. + + :param graph_address: Graph server address. If `None`, defaults are used. + :type graph_address: AddressType | None :param auto_start: Whether to auto-start a GraphServer if connection fails. If None, defaults to auto-start only when graph_address is not provided and no environment override is set. :type auto_start: bool | None .. note:: - The GraphContext is typically managed automatically by the ezmsg runtime - and doesn't need to be instantiated directly by user code. + `GraphContext` is used by the runtime, and can also be used directly by tools + (inspectors, profilers, dashboards, and operational scripts). """ _clients: set[Publisher | Subscriber] @@ -467,6 +480,73 @@ async def process_stats( ProcessStats, self.decode_process_payload(response, ProcessStats) ) + async def process_profiling_snapshot( + self, + unit_address: str, + *, + timeout: float = 2.0, + ) -> ProcessProfilingSnapshot: + response = await self.process_request( + unit_address, + ProcessControlOperation.GET_PROFILING_SNAPSHOT, + timeout=timeout, + ) + return typing.cast( + ProcessProfilingSnapshot, + self.decode_process_payload(response, ProcessProfilingSnapshot), + ) + + async def process_set_profiling_trace( + self, + unit_address: str, + control: ProfilingTraceControl, + *, + timeout: float = 2.0, + ) -> ProcessControlResponse: + return await self.process_request( + unit_address, + ProcessControlOperation.SET_PROFILING_TRACE, + payload_obj=control, + timeout=timeout, + ) + + async def process_profiling_trace_batch( + self, + unit_address: str, + *, + max_samples: int = 1000, + timeout: float = 2.0, + ) -> ProcessProfilingTraceBatch: + response = await self.process_request( + unit_address, + ProcessControlOperation.GET_PROFILING_TRACE_BATCH, + payload_obj=max_samples, + timeout=timeout, + ) + return typing.cast( + ProcessProfilingTraceBatch, + self.decode_process_payload(response, ProcessProfilingTraceBatch), + ) + + async def profiling_snapshot_all( + self, + *, + timeout_per_process: float = 0.5, + ) -> dict[str, ProcessProfilingSnapshot]: + graph_snapshot = await self.snapshot() + out: dict[str, ProcessProfilingSnapshot] = {} + for process in graph_snapshot.processes.values(): + if len(process.units) == 0: + continue + route_unit = process.units[0] + try: + out[process.process_id] = await self.process_profiling_snapshot( + route_unit, timeout=timeout_per_process + ) + except Exception: + continue + return out + def decode_process_payload( self, response: ProcessControlResponse, diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index 254a55ac..42e83722 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -181,6 +181,9 @@ class ProcessControlRequest: class ProcessControlOperation(enum.Enum): PING = "PING" GET_PROCESS_STATS = "GET_PROCESS_STATS" + GET_PROFILING_SNAPSHOT = "GET_PROFILING_SNAPSHOT" + SET_PROFILING_TRACE = "SET_PROFILING_TRACE" + GET_PROFILING_TRACE_BATCH = "GET_PROFILING_TRACE_BATCH" class ProcessControlErrorCode(enum.Enum): @@ -221,6 +224,83 @@ class ProcessStats: timestamp: float +class ProfileChannelType(enum.Enum): + LOCAL = "LOCAL" + SHM = "SHM" + TCP = "TCP" + UNKNOWN = "UNKNOWN" + + +@dataclass +class PublisherProfileSnapshot: + endpoint_id: str + topic: str + messages_published_total: int + messages_published_window: int + publish_delta_ns_avg_window: float + publish_rate_hz_window: float + inflight_messages_current: int + inflight_messages_peak_window: int + backpressure_wait_ns_total: int + backpressure_wait_ns_window: int + timestamp: float + + +@dataclass +class SubscriberProfileSnapshot: + endpoint_id: str + topic: str + messages_received_total: int + messages_received_window: int + lease_time_ns_total: int + lease_time_ns_avg_window: float + user_span_ns_total: int + user_span_ns_avg_window: float + attributable_backpressure_ns_total: int + attributable_backpressure_ns_window: int + attributable_backpressure_events_total: int + channel_kind_last: ProfileChannelType + timestamp: float + + +@dataclass +class ProcessProfilingSnapshot: + process_id: str + pid: int + host: str + window_seconds: float + timestamp: float + publishers: dict[str, PublisherProfileSnapshot] + subscribers: dict[str, SubscriberProfileSnapshot] + + +@dataclass +class ProfilingTraceControl: + enabled: bool + sample_mod: int = 1 + publisher_topics: list[str] | None = None + subscriber_topics: list[str] | None = None + + +@dataclass +class ProfilingTraceSample: + timestamp: float + endpoint_id: str + topic: str + metric: str + value: float + channel_kind: ProfileChannelType | None = None + + +@dataclass +class ProcessProfilingTraceBatch: + process_id: str + pid: int + host: str + timestamp: float + samples: list[ProfilingTraceSample] + + class Edge(NamedTuple): from_topic: str to_topic: str diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 8b50e298..35cc8b26 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -21,6 +21,8 @@ encode_str, close_stream_writer, ) +from .profiling import PROFILES, PROFILE_TIME +from .graphmeta import ProfileChannelType logger = logging.getLogger("ezmsg") @@ -99,6 +101,8 @@ class Channel: _pub_writer: asyncio.StreamWriter _graph_address: AddressType | None _local_backpressure: Backpressure | None + _channel_kind: ProfileChannelType + _lease_start: dict[tuple[UUID, int], int] def __init__( self, @@ -125,6 +129,8 @@ def __init__( self.clients = dict() self._graph_address = graph_address self._local_backpressure = None + self._channel_kind = ProfileChannelType.UNKNOWN + self._lease_start = {} @classmethod async def create( @@ -257,8 +263,10 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: msg_id = await read_int(reader) buf_idx = msg_id % self.num_buffers + channel_kind = ProfileChannelType.UNKNOWN if msg == Command.TX_SHM.value: + channel_kind = ProfileChannelType.SHM shm_name = await read_str(reader) if self.shm is not None and self.shm.name != shm_name: @@ -285,6 +293,7 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: self.cache.put_from_mem(self.shm[buf_idx]) elif msg == Command.TX_TCP.value: + channel_kind = ProfileChannelType.TCP buf_size = await read_int(reader) obj_bytes = await reader.readexactly(buf_size) assert MessageMarshal.msg_id(obj_bytes) == msg_id @@ -293,6 +302,8 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: else: raise ValueError(f"unimplemented data telemetry: {msg}") + self._set_channel_kind(channel_kind) + if not self._notify_clients(msg_id): # Nobody is listening; need to ack! self.cache.release(msg_id) @@ -310,13 +321,31 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: logger.debug(f"disconnected: channel:{self.id} -> pub:{self.pub_id}") + def _set_channel_kind(self, kind: ProfileChannelType) -> None: + if self._channel_kind == ProfileChannelType.UNKNOWN: + self._channel_kind = kind + elif self._channel_kind != kind: + logger.warning( + "Channel %s observed channel kind change: %s -> %s", + self.id, + self._channel_kind.value, + kind.value, + ) + self._channel_kind = kind + + @property + def channel_kind(self) -> ProfileChannelType: + return self._channel_kind + def _notify_clients(self, msg_id: int) -> bool: """notify interested clients and return true if any were notified""" buf_idx = msg_id % self.num_buffers + now_ns = PROFILE_TIME() for client_id, queue in self.clients.items(): if queue is None: continue # queue is none if this is the pub self.backpressure.lease(client_id, buf_idx) + self._lease_start[(client_id, msg_id)] = now_ns queue.put_nowait((self.pub_id, msg_id)) return not self.backpressure.available(buf_idx) @@ -331,6 +360,7 @@ def put_local(self, msg_id: int, msg: typing.Any) -> None: ) buf_idx = msg_id % self.num_buffers + self._set_channel_kind(ProfileChannelType.LOCAL) if self._notify_clients(msg_id): self.cache.put_local(msg, msg_id) self._local_backpressure.lease(self.id, buf_idx) @@ -379,6 +409,14 @@ def _release_backpressure(self, msg_id: int, client_id: UUID) -> None: :param client_id: UUID of client releasing this message :type client_id: UUID """ + now_ns = PROFILE_TIME() + lease = self._lease_start.pop((client_id, msg_id), None) + if lease is not None: + start_ns = lease + PROFILES.subscriber_attributed_backpressure( + client_id, now_ns, now_ns - start_ns, self._channel_kind + ) + buf_idx = msg_id % self.num_buffers self.backpressure.free(client_id, buf_idx) if self.backpressure.buffers[buf_idx].is_empty: @@ -434,6 +472,11 @@ def unregister_client(self, client_id: UUID) -> None: queue.put_nowait((pub_id, msg_id)) self.backpressure.free(client_id) + stale = [ + key for key in self._lease_start.keys() if key[0] == client_id + ] + for key in stale: + self._lease_start.pop(key, None) elif client_id == self.pub_id and self._local_backpressure is not None: self._local_backpressure.free(self.id) diff --git a/src/ezmsg/core/processclient.py b/src/ezmsg/core/processclient.py index 1950f023..f6574b5c 100644 --- a/src/ezmsg/core/processclient.py +++ b/src/ezmsg/core/processclient.py @@ -10,10 +10,13 @@ from collections.abc import Awaitable, Callable from .graphmeta import ( + ProcessProfilingSnapshot, + ProcessProfilingTraceBatch, ProcessControlErrorCode, ProcessControlOperation, ProcessControlRequest, ProcessControlResponse, + ProfilingTraceControl, ProcessPing, ProcessRegistration, ProcessStats, @@ -21,6 +24,7 @@ ProcessSettingsUpdate, SettingsSnapshotValue, ) +from .profiling import PROFILES from .graphserver import GraphService from .netprotocol import ( AddressType, @@ -61,6 +65,7 @@ def __init__( self._io_task = None self._request_handler = None self._owned_units = set() + PROFILES.set_process_id(self._process_id, reset=True) @property def process_id(self) -> str: @@ -103,6 +108,7 @@ def set_request_handler( async def register(self, units: list[str]) -> None: await self.connect() + PROFILES.set_process_id(self._process_id) normalized_units = sorted(set(units)) payload = ProcessRegistration( process_id=self._process_id, @@ -297,6 +303,67 @@ async def _handle_route_request( process_id=self._process_id, ) + if operation == ProcessControlOperation.GET_PROFILING_SNAPSHOT: + snapshot: ProcessProfilingSnapshot = PROFILES.snapshot() + return ProcessControlResponse( + request_id=request.request_id, + ok=True, + payload=pickle.dumps(snapshot), + process_id=self._process_id, + ) + + if operation == ProcessControlOperation.SET_PROFILING_TRACE: + control: ProfilingTraceControl | None = None + try: + if request.payload is not None: + control_obj = pickle.loads(request.payload) + if isinstance(control_obj, ProfilingTraceControl): + control = control_obj + except Exception as exc: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=f"Invalid profiling trace control payload: {exc}", + error_code=ProcessControlErrorCode.INVALID_RESPONSE, + process_id=self._process_id, + ) + + if control is None: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error="Missing profiling trace control payload", + error_code=ProcessControlErrorCode.INVALID_RESPONSE, + process_id=self._process_id, + ) + + PROFILES.set_trace_control(control) + return ProcessControlResponse( + request_id=request.request_id, + ok=True, + process_id=self._process_id, + ) + + if operation == ProcessControlOperation.GET_PROFILING_TRACE_BATCH: + max_samples = 1000 + if request.payload is not None: + try: + max_samples_obj = pickle.loads(request.payload) + if isinstance(max_samples_obj, int): + max_samples = max(1, max_samples_obj) + except Exception: + pass + + batch: ProcessProfilingTraceBatch = PROFILES.trace_batch( + max_samples=max_samples + ) + return ProcessControlResponse( + request_id=request.request_id, + ok=True, + payload=pickle.dumps(batch), + process_id=self._process_id, + ) + if self._request_handler is None: return ProcessControlResponse( request_id=request.request_id, diff --git a/src/ezmsg/core/profiling.py b/src/ezmsg/core/profiling.py new file mode 100644 index 00000000..3537f3d0 --- /dev/null +++ b/src/ezmsg/core/profiling.py @@ -0,0 +1,395 @@ +import os +import socket +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Callable, TypeAlias +from uuid import UUID + +from .graphmeta import ( + ProcessProfilingSnapshot, + ProcessProfilingTraceBatch, + ProfileChannelType, + ProfilingTraceControl, + ProfilingTraceSample, + PublisherProfileSnapshot, + SubscriberProfileSnapshot, +) + + +WINDOW_SECONDS = float(os.environ.get("EZMSG_PROFILE_WINDOW_SECONDS", "10.0")) +BUCKET_SECONDS = float(os.environ.get("EZMSG_PROFILE_BUCKET_SECONDS", "0.1")) +TRACE_MAX_SAMPLES = int(os.environ.get("EZMSG_PROFILE_TRACE_MAX_SAMPLES", "10000")) +# Must return monotonic nanoseconds so *_ns metrics remain unit-consistent. +PROFILE_TIME_TYPE: TypeAlias = Callable[[], int] +PROFILE_TIME: PROFILE_TIME_TYPE = time.perf_counter_ns + + +def _endpoint_id(topic: str, id: UUID) -> str: + return f"{topic}:{id}" + + +@dataclass +class _Rolling: + window_seconds: float = WINDOW_SECONDS + bucket_seconds: float = BUCKET_SECONDS + count: list[int] = field(default_factory=list) + value_sum: list[int] = field(default_factory=list) + max_value: list[int] = field(default_factory=list) + _num_buckets: int = 0 + _bucket_ns: int = 0 + _last_bucket: int | None = None + + def __post_init__(self) -> None: + self._num_buckets = max(1, int(self.window_seconds / self.bucket_seconds)) + self._bucket_ns = max(1, int(self.bucket_seconds * 1e9)) + self.count = [0 for _ in range(self._num_buckets)] + self.value_sum = [0 for _ in range(self._num_buckets)] + self.max_value = [0 for _ in range(self._num_buckets)] + + def _bucket(self, ts_ns: int) -> int: + return (ts_ns // self._bucket_ns) % self._num_buckets + + def _advance(self, ts_ns: int) -> int: + bucket = self._bucket(ts_ns) + if self._last_bucket is None: + self._last_bucket = bucket + return bucket + if bucket == self._last_bucket: + return bucket + idx = (self._last_bucket + 1) % self._num_buckets + while idx != bucket: + self.count[idx] = 0 + self.value_sum[idx] = 0 + self.max_value[idx] = 0 + idx = (idx + 1) % self._num_buckets + self.count[bucket] = 0 + self.value_sum[bucket] = 0 + self.max_value[bucket] = 0 + self._last_bucket = bucket + return bucket + + def add(self, ts_ns: int, value: int) -> None: + idx = self._advance(ts_ns) + self.count[idx] += 1 + self.value_sum[idx] += value + if value > self.max_value[idx]: + self.max_value[idx] = value + + def count_total(self) -> int: + return sum(self.count) + + def sum_total(self) -> int: + return sum(self.value_sum) + + def max_total(self) -> int: + return max(self.max_value) if self.max_value else 0 + + def avg(self) -> float: + c = self.count_total() + if c == 0: + return 0.0 + return float(self.sum_total()) / float(c) + + +@dataclass +class _PublisherMetrics: + topic: str + endpoint_id: str + messages_published_total: int = 0 + backpressure_wait_ns_total: int = 0 + inflight_messages_current: int = 0 + _last_publish_ts_ns: int | None = None + _publish_delta: _Rolling = field(default_factory=_Rolling) + _publish_count: _Rolling = field(default_factory=lambda: _Rolling()) + _backpressure_wait: _Rolling = field(default_factory=_Rolling) + _inflight: _Rolling = field(default_factory=_Rolling) + trace_enabled: bool = False + trace_sample_mod: int = 1 + _trace_counter: int = 0 + trace_samples: deque[ProfilingTraceSample] = field( + default_factory=lambda: deque(maxlen=TRACE_MAX_SAMPLES) + ) + + def record_publish(self, ts_ns: int, inflight: int) -> None: + self.messages_published_total += 1 + self._publish_count.add(ts_ns, 1) + publish_delta_ns = 0 + if self._last_publish_ts_ns is not None: + publish_delta_ns = ts_ns - self._last_publish_ts_ns + self._publish_delta.add(ts_ns, publish_delta_ns) + self._last_publish_ts_ns = ts_ns + self.sample_inflight(ts_ns, inflight) + self._trace_counter += 1 + if self.trace_enabled and (self._trace_counter % max(1, self.trace_sample_mod) == 0): + self.trace_samples.append( + ProfilingTraceSample( + timestamp=float(PROFILE_TIME()), + endpoint_id=self.endpoint_id, + topic=self.topic, + metric="publish_delta_ns", + value=float(publish_delta_ns), + ) + ) + + def record_backpressure_wait(self, ts_ns: int, wait_ns: int) -> None: + self.backpressure_wait_ns_total += wait_ns + self._backpressure_wait.add(ts_ns, wait_ns) + if self.trace_enabled: + self.trace_samples.append( + ProfilingTraceSample( + timestamp=float(PROFILE_TIME()), + endpoint_id=self.endpoint_id, + topic=self.topic, + metric="backpressure_wait_ns", + value=float(wait_ns), + ) + ) + + def sample_inflight(self, ts_ns: int, inflight: int) -> None: + self.inflight_messages_current = inflight + self._inflight.add(ts_ns, inflight) + + def snapshot(self) -> PublisherProfileSnapshot: + window_msgs = self._publish_count.count_total() + return PublisherProfileSnapshot( + endpoint_id=self.endpoint_id, + topic=self.topic, + messages_published_total=self.messages_published_total, + messages_published_window=window_msgs, + publish_delta_ns_avg_window=self._publish_delta.avg(), + publish_rate_hz_window=float(window_msgs) / max(WINDOW_SECONDS, 1e-9), + inflight_messages_current=self.inflight_messages_current, + inflight_messages_peak_window=self._inflight.max_total(), + backpressure_wait_ns_total=self.backpressure_wait_ns_total, + backpressure_wait_ns_window=self._backpressure_wait.sum_total(), + timestamp=float(PROFILE_TIME()), + ) + + +@dataclass +class _SubscriberMetrics: + topic: str + endpoint_id: str + messages_received_total: int = 0 + lease_time_ns_total: int = 0 + user_span_ns_total: int = 0 + attributable_backpressure_ns_total: int = 0 + attributable_backpressure_events_total: int = 0 + channel_kind_last: ProfileChannelType = ProfileChannelType.UNKNOWN + _recv_count: _Rolling = field(default_factory=lambda: _Rolling()) + _lease_time: _Rolling = field(default_factory=_Rolling) + _user_span: _Rolling = field(default_factory=_Rolling) + _attrib_bp: _Rolling = field(default_factory=_Rolling) + trace_enabled: bool = False + trace_sample_mod: int = 1 + _trace_counter: int = 0 + trace_samples: deque[ProfilingTraceSample] = field( + default_factory=lambda: deque(maxlen=TRACE_MAX_SAMPLES) + ) + + def record_receive(self, ts_ns: int, lease_ns: int, channel_kind: ProfileChannelType) -> None: + self.messages_received_total += 1 + self.lease_time_ns_total += lease_ns + self.channel_kind_last = channel_kind + self._recv_count.add(ts_ns, 1) + self._lease_time.add(ts_ns, lease_ns) + self._trace_counter += 1 + if self.trace_enabled and (self._trace_counter % max(1, self.trace_sample_mod) == 0): + self.trace_samples.append( + ProfilingTraceSample( + timestamp=float(PROFILE_TIME()), + endpoint_id=self.endpoint_id, + topic=self.topic, + metric="lease_time_ns", + value=float(lease_ns), + channel_kind=channel_kind, + ) + ) + + def record_user_span(self, ts_ns: int, span_ns: int, label: str | None) -> None: + self.user_span_ns_total += span_ns + self._user_span.add(ts_ns, span_ns) + if self.trace_enabled: + self.trace_samples.append( + ProfilingTraceSample( + timestamp=float(PROFILE_TIME()), + endpoint_id=self.endpoint_id, + topic=self.topic if label is None else f"{self.topic}:{label}", + metric="user_span_ns", + value=float(span_ns), + channel_kind=self.channel_kind_last, + ) + ) + + def record_attributed_backpressure( + self, ts_ns: int, duration_ns: int, channel_kind: ProfileChannelType + ) -> None: + self.attributable_backpressure_ns_total += duration_ns + self.attributable_backpressure_events_total += 1 + self.channel_kind_last = channel_kind + self._attrib_bp.add(ts_ns, duration_ns) + + def snapshot(self) -> SubscriberProfileSnapshot: + recv_count = self._recv_count.count_total() + user_count = self._user_span.count_total() + return SubscriberProfileSnapshot( + endpoint_id=self.endpoint_id, + topic=self.topic, + messages_received_total=self.messages_received_total, + messages_received_window=recv_count, + lease_time_ns_total=self.lease_time_ns_total, + lease_time_ns_avg_window=self._lease_time.avg(), + user_span_ns_total=self.user_span_ns_total, + user_span_ns_avg_window=( + float(self._user_span.sum_total()) / float(user_count) + if user_count > 0 + else 0.0 + ), + attributable_backpressure_ns_total=self.attributable_backpressure_ns_total, + attributable_backpressure_ns_window=self._attrib_bp.sum_total(), + attributable_backpressure_events_total=self.attributable_backpressure_events_total, + channel_kind_last=self.channel_kind_last, + timestamp=float(PROFILE_TIME()), + ) + + +class ProfileRegistry: + def __init__(self) -> None: + self._process_id = "" + self._pid = os.getpid() + self._host = socket.gethostname() + self._publishers: dict[UUID, _PublisherMetrics] = {} + self._subscribers: dict[UUID, _SubscriberMetrics] = {} + self._default_trace_control = ProfilingTraceControl(enabled=False) + + def set_process_id(self, process_id: str, *, reset: bool = False) -> None: + if reset or (self._process_id and self._process_id != process_id): + self._publishers.clear() + self._subscribers.clear() + self._default_trace_control = ProfilingTraceControl(enabled=False) + self._process_id = process_id + + def register_publisher(self, pub_id: UUID, topic: str) -> None: + self._publishers[pub_id] = _PublisherMetrics( + topic=topic, + endpoint_id=_endpoint_id(topic, pub_id), + ) + + def unregister_publisher(self, pub_id: UUID) -> None: + self._publishers.pop(pub_id, None) + + def register_subscriber(self, sub_id: UUID, topic: str) -> None: + self._subscribers[sub_id] = _SubscriberMetrics( + topic=topic, + endpoint_id=_endpoint_id(topic, sub_id), + ) + + def unregister_subscriber(self, sub_id: UUID) -> None: + self._subscribers.pop(sub_id, None) + + def publisher_publish(self, pub_id: UUID, ts_ns: int, inflight: int) -> None: + metric = self._publishers.get(pub_id) + if metric is not None: + metric.record_publish(ts_ns, inflight) + + def publisher_backpressure_wait(self, pub_id: UUID, ts_ns: int, wait_ns: int) -> None: + metric = self._publishers.get(pub_id) + if metric is not None: + metric.record_backpressure_wait(ts_ns, wait_ns) + + def publisher_sample_inflight(self, pub_id: UUID, ts_ns: int, inflight: int) -> None: + metric = self._publishers.get(pub_id) + if metric is not None: + metric.sample_inflight(ts_ns, inflight) + + def subscriber_receive( + self, + sub_id: UUID, + ts_ns: int, + lease_ns: int, + channel_kind: ProfileChannelType, + ) -> None: + metric = self._subscribers.get(sub_id) + if metric is not None: + metric.record_receive(ts_ns, lease_ns, channel_kind) + + def subscriber_user_span( + self, sub_id: UUID, ts_ns: int, span_ns: int, label: str | None + ) -> None: + metric = self._subscribers.get(sub_id) + if metric is not None: + metric.record_user_span(ts_ns, span_ns, label) + + def subscriber_attributed_backpressure( + self, + sub_id: UUID, + ts_ns: int, + duration_ns: int, + channel_kind: ProfileChannelType, + ) -> None: + metric = self._subscribers.get(sub_id) + if metric is not None: + metric.record_attributed_backpressure(ts_ns, duration_ns, channel_kind) + + def snapshot(self) -> ProcessProfilingSnapshot: + return ProcessProfilingSnapshot( + process_id=self._process_id, + pid=self._pid, + host=self._host, + window_seconds=WINDOW_SECONDS, + timestamp=float(PROFILE_TIME()), + publishers={ + metric.endpoint_id: metric.snapshot() + for metric in self._publishers.values() + }, + subscribers={ + metric.endpoint_id: metric.snapshot() + for metric in self._subscribers.values() + }, + ) + + def set_trace_control(self, control: ProfilingTraceControl) -> None: + self._default_trace_control = control + sample_mod = max(1, control.sample_mod) + pub_topics = set(control.publisher_topics or []) + sub_topics = set(control.subscriber_topics or []) + + for metric in self._publishers.values(): + enabled = control.enabled and ( + not pub_topics or metric.topic in pub_topics + ) + metric.trace_enabled = enabled + metric.trace_sample_mod = sample_mod + + for metric in self._subscribers.values(): + enabled = control.enabled and ( + not sub_topics or metric.topic in sub_topics + ) + metric.trace_enabled = enabled + metric.trace_sample_mod = sample_mod + + def trace_batch(self, max_samples: int = 1000) -> ProcessProfilingTraceBatch: + samples: list[ProfilingTraceSample] = [] + for metric in self._publishers.values(): + while metric.trace_samples and len(samples) < max_samples: + samples.append(metric.trace_samples.popleft()) + if len(samples) >= max_samples: + break + if len(samples) < max_samples: + for metric in self._subscribers.values(): + while metric.trace_samples and len(samples) < max_samples: + samples.append(metric.trace_samples.popleft()) + if len(samples) >= max_samples: + break + + return ProcessProfilingTraceBatch( + process_id=self._process_id, + pid=self._pid, + host=self._host, + timestamp=float(PROFILE_TIME()), + samples=samples, + ) + + +PROFILES = ProfileRegistry() diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index f9c42952..270e50b7 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -13,6 +13,7 @@ from .channelmanager import CHANNELS from .messagechannel import Channel from .messagemarshal import MessageMarshal, UninitializedMemory +from .profiling import PROFILES, PROFILE_TIME from .netprotocol import ( Address, @@ -230,6 +231,7 @@ def __init__( self._force_tcp = force_tcp self._last_backpressure_event = -1 self._graph_address = graph_address + PROFILES.register_publisher(self.id, self.topic) @property def log_name(self) -> str: @@ -243,6 +245,7 @@ def close(self) -> None: and all subscriber handling tasks. """ self._graph_task.cancel() + PROFILES.unregister_publisher(self.id) self._shm.close() self._connection_task.cancel() for task in self._channel_tasks.values(): @@ -369,12 +372,18 @@ async def _handle_channel( elif msg == Command.RX_ACK.value: msg_id = await read_int(reader) self._backpressure.free(info.id, msg_id % self._num_buffers) + PROFILES.publisher_sample_inflight( + self.id, PROFILE_TIME(), self._backpressure.pressure + ) except (ConnectionResetError, BrokenPipeError): logger.debug(f"Publisher {self.id}: Channel {info.id} connection fail") finally: self._backpressure.free(info.id) + PROFILES.publisher_sample_inflight( + self.id, PROFILE_TIME(), self._backpressure.pressure + ) await close_stream_writer(self._channels[info.id].writer) del self._channels[info.id] @@ -434,7 +443,12 @@ async def broadcast(self, obj: Any) -> None: if BACKPRESSURE_WARNING and (delta > BACKPRESSURE_REFRACTORY): logger.warning(f"{self.topic} under subscriber backpressure!") self._last_backpressure_event = time.time() + wait_start_ns = PROFILE_TIME() await self._backpressure.wait(buf_idx) + wait_end_ns = PROFILE_TIME() + PROFILES.publisher_backpressure_wait( + self.id, wait_end_ns, wait_end_ns - wait_start_ns + ) # Get local channel and put variable there for local tx self._local_channel.put_local(self._msg_id, obj) @@ -502,10 +516,15 @@ async def broadcast(self, obj: Any) -> None: channel.writer.write(msg) await channel.writer.drain() self._backpressure.lease(channel.id, buf_idx) + PROFILES.publisher_sample_inflight( + self.id, PROFILE_TIME(), self._backpressure.pressure + ) except (ConnectionResetError, BrokenPipeError): logger.debug( f"Publisher {self.id}: Channel {channel.id} connection fail" ) + now_ns = PROFILE_TIME() + PROFILES.publisher_publish(self.id, now_ns, self._backpressure.pressure) self._msg_id += 1 diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index 3ca2dc22..50a7c3c1 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -3,12 +3,14 @@ import typing from uuid import UUID -from contextlib import asynccontextmanager, suppress +from contextlib import asynccontextmanager, contextmanager, suppress from copy import deepcopy from .graphserver import GraphService from .channelmanager import CHANNELS from .messagechannel import NotificationQueue, LeakyQueue, Channel +from .profiling import PROFILES, PROFILE_TIME +from .graphmeta import ProfileChannelType from .netprotocol import ( AddressType, @@ -135,6 +137,7 @@ def __init__( else: self._incoming = asyncio.Queue() self._initialized = asyncio.Event() + PROFILES.register_subscriber(self.id, self.topic) def _handle_dropped_notification( self, notification: typing.Tuple[UUID, int] @@ -160,6 +163,7 @@ def close(self) -> None: and closes all shared memory contexts. """ self._graph_task.cancel() + PROFILES.unregister_subscriber(self.id) async def wait_closed(self) -> None: """ @@ -295,5 +299,27 @@ async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]: break # Stale notification from an unregistered publisher — skip. - with self._channels[pub_id].get(msg_id, self.id) as msg: + channel = self._channels[pub_id] + channel_kind = getattr(channel, "channel_kind", ProfileChannelType.UNKNOWN) + start_ns = PROFILE_TIME() + with channel.get(msg_id, self.id) as msg: yield msg + end_ns = PROFILE_TIME() + PROFILES.subscriber_receive( + self.id, end_ns, end_ns - start_ns, channel_kind + ) + + def begin_profile(self) -> int: + return PROFILE_TIME() + + def end_profile(self, start_ns: int, label: str | None = None) -> None: + end_ns = PROFILE_TIME() + PROFILES.subscriber_user_span(self.id, end_ns, end_ns - start_ns, label) + + @contextmanager + def profile_span(self, label: str | None = None) -> typing.Generator[None, None, None]: + start_ns = self.begin_profile() + try: + yield + finally: + self.end_profile(start_ns, label=label) diff --git a/tests/test_profiling_api.py b/tests/test_profiling_api.py new file mode 100644 index 00000000..fd3ecb69 --- /dev/null +++ b/tests/test_profiling_api.py @@ -0,0 +1,137 @@ +import asyncio + +import pytest + +from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.graphmeta import ( + ProcessControlErrorCode, + ProfilingTraceControl, +) +from ezmsg.core.graphserver import GraphService +from ezmsg.core.processclient import ProcessControlClient + + +@pytest.mark.asyncio +async def test_process_profiling_snapshot_collects_pub_sub_metrics(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address, process_id="proc-prof") + await process.connect() + await process.register(["SYS/U1"]) + + pub = await ctx.publisher("TOPIC_PROF") + sub = await ctx.subscriber("TOPIC_PROF") + + try: + for idx in range(8): + await pub.broadcast(idx) + async with sub.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + snap = await ctx.process_profiling_snapshot("SYS/U1", timeout=1.0) + assert snap.process_id == "proc-prof" + assert snap.window_seconds > 0 + assert len(snap.publishers) >= 1 + assert len(snap.subscribers) >= 1 + + pub_metrics = next(iter(snap.publishers.values())) + assert pub_metrics.messages_published_total >= 8 + assert pub_metrics.publish_rate_hz_window >= 0.0 + + sub_metrics = next(iter(snap.subscribers.values())) + assert sub_metrics.messages_received_total >= 8 + assert sub_metrics.lease_time_ns_total > 0 + assert sub_metrics.lease_time_ns_avg_window >= 0.0 + finally: + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_profiling_trace_control_and_batch(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address, process_id="proc-trace") + await process.connect() + await process.register(["SYS/U2"]) + + pub = await ctx.publisher("TOPIC_TRACE") + sub = await ctx.subscriber("TOPIC_TRACE") + + try: + response = await ctx.process_set_profiling_trace( + "SYS/U2", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_topics=["TOPIC_TRACE"], + subscriber_topics=["TOPIC_TRACE"], + ), + timeout=1.0, + ) + assert response.ok + + for idx in range(5): + await pub.broadcast(idx) + async with sub.recv_zero_copy() as _msg: + span_start_ns = sub.begin_profile() + try: + await asyncio.sleep(0) + finally: + sub.end_profile(span_start_ns, "taskA") + + batch = await ctx.process_profiling_trace_batch( + "SYS/U2", max_samples=200, timeout=1.0 + ) + assert batch.process_id == "proc-trace" + assert len(batch.samples) > 0 + + disable_response = await ctx.process_set_profiling_trace( + "SYS/U2", + ProfilingTraceControl(enabled=False), + timeout=1.0, + ) + assert disable_response.ok + finally: + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_profiling_snapshot_all_and_unroutable_error_code(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address, process_id="proc-all") + await process.connect() + await process.register(["SYS/U3"]) + + try: + snapshots = await ctx.profiling_snapshot_all(timeout_per_process=0.5) + assert "proc-all" in snapshots + assert snapshots["proc-all"].process_id == "proc-all" + + response = await ctx.process_request( + "SYS/MISSING", + "GET_PROFILING_SNAPSHOT", + timeout=0.2, + ) + assert not response.ok + assert response.error_code == ProcessControlErrorCode.UNROUTABLE_UNIT + finally: + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() From bf9e009139bb21d1e50367db690eed3b624dd39b Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 16 Mar 2026 14:46:05 -0400 Subject: [PATCH 16/52] profiling hooks; trace/streams --- examples/profiling_tui.py | 442 +++++++++++++++++++++++++++++++++ src/ezmsg/core/graphcontext.py | 34 +++ src/ezmsg/core/graphmeta.py | 6 + src/ezmsg/core/graphserver.py | 109 ++++++++ src/ezmsg/core/netprotocol.py | 1 + tests/test_profiling_api.py | 54 ++++ 6 files changed, 646 insertions(+) create mode 100644 examples/profiling_tui.py diff --git a/examples/profiling_tui.py b/examples/profiling_tui.py new file mode 100644 index 00000000..13f230c6 --- /dev/null +++ b/examples/profiling_tui.py @@ -0,0 +1,442 @@ +#!/usr/bin/env python3 +""" +Simple live profiling TUI for ezmsg GraphServer. + +Features: +- Periodic profiling snapshot view broken out by publisher/subscriber endpoints +- Live trace sample counts via GraphContext.subscribe_profiling_trace() +- Optional automatic trace enablement for discovered processes + +Usage: + .venv/bin/python examples/profiling_tui.py --host 127.0.0.1 --port 25978 +""" + +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import time +from dataclasses import dataclass + +from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.graphmeta import ProcessProfilingSnapshot, ProfilingTraceControl +from ezmsg.core.netprotocol import DEFAULT_HOST, GRAPHSERVER_PORT_DEFAULT + + +def _truncate(text: str, width: int) -> str: + if width <= 3: + return text[:width] + if len(text) <= width: + return text + return text[: width - 3] + "..." + + +def _fmt_float(value: float, digits: int = 2) -> str: + return f"{value:.{digits}f}" + + +@dataclass +class PublisherView: + process_id: str + topic: str + endpoint_id: str + published_total: int + published_window: int + publish_rate_hz: float + publish_delta_ms_avg: float + inflight_current: int + inflight_peak: int + trace_samples_seen: int + trace_last_age_s: float | None + backpressure_wait_ms_window: float + + +@dataclass +class SubscriberView: + process_id: str + topic: str + endpoint_id: str + channel_kind: str + received_total: int + received_window: int + lease_time_ms_avg: float + user_span_ms_avg: float + attributable_backpressure_ms_window: float + attributable_backpressure_events_total: int + trace_samples_seen: int + trace_last_age_s: float | None + + +class ProfilingTUI: + def __init__( + self, + ctx: GraphContext, + *, + snapshot_interval: float, + trace_interval: float, + trace_max_samples: int, + auto_trace: bool, + trace_sample_mod: int, + max_rows: int, + ) -> None: + self.ctx = ctx + self.snapshot_interval = max(0.2, snapshot_interval) + self.trace_interval = max(0.01, trace_interval) + self.trace_max_samples = max(1, trace_max_samples) + self.auto_trace = auto_trace + self.trace_sample_mod = max(1, trace_sample_mod) + self.max_rows = max(5, max_rows) + + self.snapshots: dict[str, ProcessProfilingSnapshot] = {} + self.route_units: dict[str, str] = {} + self.trace_enabled_processes: set[str] = set() + self.trace_errors: dict[str, str] = {} + self.trace_samples_seen_by_endpoint: dict[str, int] = {} + self.trace_last_timestamp_by_endpoint: dict[str, float] = {} + self.last_snapshot_time: float | None = None + + self._snapshot_task: asyncio.Task[None] | None = None + self._trace_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + await self._refresh_snapshot() + self._snapshot_task = asyncio.create_task(self._snapshot_loop()) + self._trace_task = asyncio.create_task(self._trace_loop()) + + async def close(self) -> None: + for task in (self._snapshot_task, self._trace_task): + if task is not None: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + + if not self.auto_trace: + return + + for process_id, route_unit in self.route_units.items(): + if process_id not in self.trace_enabled_processes: + continue + with contextlib.suppress(Exception): + await self.ctx.process_set_profiling_trace( + route_unit, + ProfilingTraceControl(enabled=False), + timeout=0.5, + ) + + async def _snapshot_loop(self) -> None: + while True: + await self._refresh_snapshot() + await asyncio.sleep(self.snapshot_interval) + + async def _refresh_snapshot(self) -> None: + graph_snapshot = await self.ctx.snapshot() + route_units: dict[str, str] = {} + for process in graph_snapshot.processes.values(): + if process.units: + route_units[process.process_id] = process.units[0] + self.route_units = route_units + + if self.auto_trace: + for process_id, route_unit in route_units.items(): + if process_id in self.trace_enabled_processes: + continue + try: + response = await self.ctx.process_set_profiling_trace( + route_unit, + ProfilingTraceControl( + enabled=True, + sample_mod=self.trace_sample_mod, + ), + timeout=0.5, + ) + if response.ok: + self.trace_enabled_processes.add(process_id) + self.trace_errors.pop(process_id, None) + else: + self.trace_errors[process_id] = str( + response.error or "unknown error" + ) + except Exception as exc: + self.trace_errors[process_id] = str(exc) + + self.snapshots = await self.ctx.profiling_snapshot_all( + timeout_per_process=max(0.1, self.snapshot_interval * 0.8) + ) + self.last_snapshot_time = time.time() + + async def _trace_loop(self) -> None: + async for batch in self.ctx.subscribe_profiling_trace( + interval=self.trace_interval, + max_samples=self.trace_max_samples, + ): + for process_batch in batch.batches.values(): + for sample in process_batch.samples: + endpoint_id = sample.endpoint_id + self.trace_samples_seen_by_endpoint[endpoint_id] = ( + self.trace_samples_seen_by_endpoint.get(endpoint_id, 0) + 1 + ) + self.trace_last_timestamp_by_endpoint[endpoint_id] = batch.timestamp + + def _trace_for_endpoint(self, endpoint_id: str) -> tuple[int, float | None]: + now = time.time() + count = self.trace_samples_seen_by_endpoint.get(endpoint_id, 0) + ts = self.trace_last_timestamp_by_endpoint.get(endpoint_id) + age = None if ts is None else max(0.0, now - ts) + return count, age + + def _publisher_rows(self) -> list[PublisherView]: + rows: list[PublisherView] = [] + for process_id, snapshot in self.snapshots.items(): + for pub in snapshot.publishers.values(): + trace_count, trace_age = self._trace_for_endpoint(pub.endpoint_id) + rows.append( + PublisherView( + process_id=process_id, + topic=pub.topic, + endpoint_id=pub.endpoint_id, + published_total=pub.messages_published_total, + published_window=pub.messages_published_window, + publish_rate_hz=pub.publish_rate_hz_window, + publish_delta_ms_avg=pub.publish_delta_ns_avg_window / 1_000_000.0, + inflight_current=pub.inflight_messages_current, + inflight_peak=pub.inflight_messages_peak_window, + trace_samples_seen=trace_count, + trace_last_age_s=trace_age, + backpressure_wait_ms_window=( + pub.backpressure_wait_ns_window / 1_000_000.0 + ), + ) + ) + rows.sort( + key=lambda row: ( + -row.publish_rate_hz, + -row.published_total, + row.process_id, + row.topic, + ) + ) + return rows + + def _subscriber_rows(self) -> list[SubscriberView]: + rows: list[SubscriberView] = [] + for process_id, snapshot in self.snapshots.items(): + for sub in snapshot.subscribers.values(): + trace_count, trace_age = self._trace_for_endpoint(sub.endpoint_id) + channel_kind = ( + sub.channel_kind_last.value + if hasattr(sub.channel_kind_last, "value") + else str(sub.channel_kind_last) + ) + rows.append( + SubscriberView( + process_id=process_id, + topic=sub.topic, + endpoint_id=sub.endpoint_id, + channel_kind=channel_kind, + received_total=sub.messages_received_total, + received_window=sub.messages_received_window, + lease_time_ms_avg=sub.lease_time_ns_avg_window / 1_000_000.0, + user_span_ms_avg=sub.user_span_ns_avg_window / 1_000_000.0, + attributable_backpressure_ms_window=( + sub.attributable_backpressure_ns_window / 1_000_000.0 + ), + attributable_backpressure_events_total=( + sub.attributable_backpressure_events_total + ), + trace_samples_seen=trace_count, + trace_last_age_s=trace_age, + ) + ) + rows.sort( + key=lambda row: ( + -row.lease_time_ms_avg, + -row.received_total, + row.process_id, + row.topic, + ) + ) + return rows + + def render(self) -> None: + print("\x1bc", end="") + print("ezmsg profiling tui") + print("Ctrl-C to quit") + print( + "snapshot interval=" + f"{self.snapshot_interval:.2f}s, trace interval={self.trace_interval:.2f}s, " + f"trace max_samples={self.trace_max_samples}, auto_trace={self.auto_trace}" + ) + if self.last_snapshot_time is not None: + print( + "last snapshot age: " + f"{_fmt_float(max(0.0, time.time() - self.last_snapshot_time), 2)}s" + ) + print( + f"processes discovered={len(self.route_units)} " + f"publishers={sum(len(s.publishers) for s in self.snapshots.values())} " + f"subscribers={sum(len(s.subscribers) for s in self.snapshots.values())}" + ) + + publisher_rows = self._publisher_rows() + subscriber_rows = self._subscriber_rows() + + print("\nPublishers") + pub_header = ( + f"{'Process':<20} {'Topic':<26} {'Endpoint':<24} " + f"{'Total':>8} {'Win':>6} {'RateHz':>8} {'DeltaMs':>8} " + f"{'InFl':>5} {'InPk':>5} {'BPmsW':>8} {'Trace':>7} {'TAge':>6}" + ) + print(pub_header) + print("-" * len(pub_header)) + if not publisher_rows: + print("") + else: + for row in publisher_rows[: self.max_rows]: + trace_age = ( + "-" if row.trace_last_age_s is None else _fmt_float(row.trace_last_age_s, 2) + ) + print( + f"{_truncate(row.process_id, 20):<20} " + f"{_truncate(row.topic, 26):<26} " + f"{_truncate(row.endpoint_id, 24):<24} " + f"{row.published_total:>8} " + f"{row.published_window:>6} " + f"{_fmt_float(row.publish_rate_hz, 2):>8} " + f"{_fmt_float(row.publish_delta_ms_avg, 2):>8} " + f"{row.inflight_current:>5} " + f"{row.inflight_peak:>5} " + f"{_fmt_float(row.backpressure_wait_ms_window, 2):>8} " + f"{row.trace_samples_seen:>7} " + f"{trace_age:>6}" + ) + + print("\nSubscribers") + sub_header = ( + f"{'Process':<20} {'Topic':<26} {'Endpoint':<24} {'Kind':<6} " + f"{'Total':>8} {'Win':>6} {'LeaseMs':>8} {'UserMs':>8} " + f"{'BPmsW':>8} {'BPev':>6} {'Trace':>7} {'TAge':>6}" + ) + print(sub_header) + print("-" * len(sub_header)) + if not subscriber_rows: + print("") + else: + for row in subscriber_rows[: self.max_rows]: + trace_age = ( + "-" if row.trace_last_age_s is None else _fmt_float(row.trace_last_age_s, 2) + ) + print( + f"{_truncate(row.process_id, 20):<20} " + f"{_truncate(row.topic, 26):<26} " + f"{_truncate(row.endpoint_id, 24):<24} " + f"{_truncate(row.channel_kind, 6):<6} " + f"{row.received_total:>8} " + f"{row.received_window:>6} " + f"{_fmt_float(row.lease_time_ms_avg, 2):>8} " + f"{_fmt_float(row.user_span_ms_avg, 2):>8} " + f"{_fmt_float(row.attributable_backpressure_ms_window, 2):>8} " + f"{row.attributable_backpressure_events_total:>6} " + f"{row.trace_samples_seen:>7} " + f"{trace_age:>6}" + ) + + if self.trace_errors: + print("\ntrace errors:") + for process_id, err in sorted(self.trace_errors.items()): + print(f" {_truncate(process_id, 30)}: {_truncate(err, 120)}") + + +def _parse_address(host: str, port: int) -> tuple[str, int]: + return (host, port) + + +async def _run_tui(args: argparse.Namespace) -> None: + async with GraphContext( + _parse_address(args.host, args.port), auto_start=args.auto_start + ) as ctx: + tui = ProfilingTUI( + ctx, + snapshot_interval=args.snapshot_interval, + trace_interval=args.trace_interval, + trace_max_samples=args.max_samples, + auto_trace=args.auto_trace, + trace_sample_mod=args.sample_mod, + max_rows=args.max_rows, + ) + await tui.start() + try: + while True: + tui.render() + await asyncio.sleep(max(0.1, args.render_interval)) + finally: + await tui.close() + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="ezmsg profiling TUI") + parser.add_argument("--host", default=DEFAULT_HOST, help="GraphServer host") + parser.add_argument( + "--port", + type=int, + default=GRAPHSERVER_PORT_DEFAULT, + help="GraphServer port", + ) + parser.add_argument( + "--auto-start", + action="store_true", + help="Allow GraphContext to auto-start GraphServer if unavailable", + ) + parser.add_argument( + "--snapshot-interval", + type=float, + default=1.0, + help="Seconds between snapshot refreshes", + ) + parser.add_argument( + "--trace-interval", + type=float, + default=0.05, + help="Seconds between GraphServer trace stream batches", + ) + parser.add_argument( + "--max-samples", + type=int, + default=512, + help="Max samples per process per streamed batch", + ) + parser.add_argument( + "--sample-mod", + type=int, + default=1, + help="Trace sampling divisor when auto-enabling trace", + ) + parser.add_argument( + "--render-interval", + type=float, + default=0.5, + help="Seconds between TUI redraws", + ) + parser.add_argument( + "--max-rows", + type=int, + default=30, + help="Max publisher/subscriber rows to render per table", + ) + parser.add_argument( + "--no-auto-trace", + action="store_true", + help="Do not auto-enable trace mode on discovered processes", + ) + return parser + + +def main() -> None: + parser = _build_parser() + args = parser.parse_args() + args.auto_trace = not args.no_auto_trace + asyncio.run(_run_tui(args)) + + +if __name__ == "__main__": + main() diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index cdf040c0..03ef8124 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -29,6 +29,7 @@ ProcessPing, ProcessProfilingSnapshot, ProcessProfilingTraceBatch, + ProfilingTraceStreamBatch, ProcessStats, ProcessControlResponse, ProfilingTraceControl, @@ -422,6 +423,39 @@ async def subscribe_settings_events( finally: await close_stream_writer(writer) + async def subscribe_profiling_trace( + self, + *, + interval: float = 0.05, + max_samples: int = 1000, + ) -> typing.AsyncIterator[ProfilingTraceStreamBatch]: + reader, writer = await GraphService(self.graph_address).open_connection() + writer.write(Command.SESSION_PROFILING_SUBSCRIBE.value) + writer.write(encode_str(str(interval))) + writer.write(encode_str(str(max_samples))) + await writer.drain() + + _subscriber_id = UUID(await read_str(reader)) + response = await reader.read(1) + if response != Command.COMPLETE.value: + await close_stream_writer(writer) + raise RuntimeError("Failed to subscribe to profiling trace stream") + + try: + while True: + payload_size = await read_int(reader) + payload = await reader.readexactly(payload_size) + batch = pickle.loads(payload) + if not isinstance(batch, ProfilingTraceStreamBatch): + raise RuntimeError( + "Profiling subscription received invalid batch payload" + ) + yield batch + except asyncio.IncompleteReadError: + return + finally: + await close_stream_writer(writer) + async def process_request( self, unit_address: str, diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index 42e83722..00d5be2c 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -301,6 +301,12 @@ class ProcessProfilingTraceBatch: samples: list[ProfilingTraceSample] +@dataclass +class ProfilingTraceStreamBatch: + timestamp: float + batches: dict[str, ProcessProfilingTraceBatch] + + class Edge(NamedTuple): from_topic: str to_topic: str diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 686ccb69..6b18ca87 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -14,9 +14,12 @@ from .graph_util import get_compactified_graph, graph_string, prune_graph_connections from .graphmeta import ( Edge, + ProcessControlOperation, ProcessControlErrorCode, GraphMetadata, GraphSnapshot, + ProcessProfilingTraceBatch, + ProfilingTraceStreamBatch, ProcessControlRequest, ProcessControlResponse, ProcessRegistration, @@ -327,6 +330,27 @@ async def api( # to avoid closing writer return + elif req == Command.SESSION_PROFILING_SUBSCRIBE.value: + subscriber_id = uuid1() + interval = float(await read_str(reader)) + max_samples = int(await read_str(reader)) + writer.write(encode_str(str(subscriber_id))) + writer.write(Command.COMPLETE.value) + await writer.drain() + self._client_tasks[subscriber_id] = asyncio.create_task( + self._handle_profiling_subscriber( + subscriber_id, + max(0.01, interval), + max(1, max_samples), + reader, + writer, + ) + ) + + # NOTE: Created a stream client, must return early + # to avoid closing writer + return + elif req == Command.PROCESS.value: process_client_id = uuid1() self.clients[process_client_id] = ProcessInfo(process_client_id, writer) @@ -727,6 +751,91 @@ async def _handle_settings_subscriber( await sender_task await close_stream_writer(writer) + async def _profiling_route_targets(self) -> list[tuple[str, str]]: + targets: list[tuple[str, str]] = [] + async with self._command_lock: + for client_id, info in self.clients.items(): + if not isinstance(info, ProcessInfo): + continue + if len(info.units) == 0: + continue + process_id = info.process_id if info.process_id is not None else str(client_id) + route_unit = sorted(info.units)[0] + targets.append((process_id, route_unit)) + return targets + + async def _collect_profiling_trace_stream_batch( + self, + *, + max_samples: int, + timeout_per_process: float, + ) -> ProfilingTraceStreamBatch: + targets = await self._profiling_route_targets() + batches: dict[str, ProcessProfilingTraceBatch] = {} + request_payload = pickle.dumps(max_samples) + + for process_id, route_unit in targets: + response = await self._route_process_request( + unit_address=route_unit, + operation=ProcessControlOperation.GET_PROFILING_TRACE_BATCH.value, + payload=request_payload, + timeout=timeout_per_process, + ) + if not response.ok or response.payload is None: + continue + try: + payload_obj = pickle.loads(response.payload) + except Exception: + continue + if not isinstance(payload_obj, ProcessProfilingTraceBatch): + continue + if len(payload_obj.samples) == 0: + continue + batches[process_id] = payload_obj + + return ProfilingTraceStreamBatch( + timestamp=time.time(), + batches=batches, + ) + + async def _handle_profiling_subscriber( + self, + subscriber_id: UUID, + interval: float, + max_samples: int, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> None: + try: + while True: + try: + req = await asyncio.wait_for(reader.read(1), timeout=interval) + if not req: + break + # No control commands currently supported on this stream. + continue + except asyncio.TimeoutError: + pass + + batch = await self._collect_profiling_trace_stream_batch( + max_samples=max_samples, + timeout_per_process=max(0.05, interval), + ) + if len(batch.batches) == 0: + continue + + payload = pickle.dumps(batch) + writer.write(uint64_to_bytes(len(payload))) + writer.write(payload) + await writer.drain() + except (ConnectionResetError, BrokenPipeError) as e: + logger.debug(f"Profiling subscriber {subscriber_id} disconnected: {e}") + except asyncio.CancelledError: + raise + finally: + self._client_tasks.pop(subscriber_id, None) + await close_stream_writer(writer) + async def _handle_process_register_request( self, process_client_id: UUID, reader: asyncio.StreamReader ) -> bytes: diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index b1332140..33f0833b 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -334,6 +334,7 @@ def _generate_next_value_(name, start, count, last_values) -> bytes: SESSION_SETTINGS_SNAPSHOT = enum.auto() SESSION_SETTINGS_EVENTS = enum.auto() SESSION_SETTINGS_SUBSCRIBE = enum.auto() + SESSION_PROFILING_SUBSCRIBE = enum.auto() SESSION_PROCESS_REQUEST = enum.auto() # Backend Process Control Commands diff --git a/tests/test_profiling_api.py b/tests/test_profiling_api.py index fd3ecb69..49724588 100644 --- a/tests/test_profiling_api.py +++ b/tests/test_profiling_api.py @@ -135,3 +135,57 @@ async def test_profiling_snapshot_all_and_unroutable_error_code(): await process.close() await ctx.__aexit__(None, None, None) graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_profiling_trace_subscription_push(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address, process_id="proc-stream") + await process.connect() + await process.register(["SYS/U4"]) + + pub = await ctx.publisher("TOPIC_STREAM") + sub = await ctx.subscriber("TOPIC_STREAM") + stream = None + + try: + response = await ctx.process_set_profiling_trace( + "SYS/U4", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_topics=["TOPIC_STREAM"], + subscriber_topics=["TOPIC_STREAM"], + ), + timeout=1.0, + ) + assert response.ok + + stream = ctx.subscribe_profiling_trace(interval=0.02, max_samples=256) + + for idx in range(8): + await pub.broadcast(idx) + async with sub.recv_zero_copy() as _msg: + span_start_ns = sub.begin_profile() + try: + await asyncio.sleep(0) + finally: + sub.end_profile(span_start_ns, "taskA") + + batch = await asyncio.wait_for(anext(stream), timeout=1.0) + assert batch.timestamp > 0.0 + assert "proc-stream" in batch.batches + process_batch = batch.batches["proc-stream"] + assert process_batch.process_id == "proc-stream" + assert len(process_batch.samples) > 0 + finally: + if stream is not None: + await stream.aclose() + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() From 9c2dc79ad9968033d2d9a4944c582bf5bb36503a Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 16 Mar 2026 15:10:46 -0400 Subject: [PATCH 17/52] topology change subscription api --- examples/topology_tui.py | 284 +++++++++++++++++++++++++++++++++ src/ezmsg/core/graphcontext.py | 32 ++++ src/ezmsg/core/graphmeta.py | 15 ++ src/ezmsg/core/graphserver.py | 173 +++++++++++++++++++- src/ezmsg/core/netprotocol.py | 1 + tests/test_topology_api.py | 103 ++++++++++++ 6 files changed, 606 insertions(+), 2 deletions(-) create mode 100644 examples/topology_tui.py create mode 100644 tests/test_topology_api.py diff --git a/examples/topology_tui.py b/examples/topology_tui.py new file mode 100644 index 00000000..07bc82f4 --- /dev/null +++ b/examples/topology_tui.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +""" +Simple live topology TUI for ezmsg GraphServer. + +Features: +- Push-based topology event subscription +- Live graph summary (nodes/edges/sessions/processes) +- Process ownership view +- Current edge list +- Recent topology event log + +Usage: + PYTHONPATH=src .venv/bin/python examples/topology_tui.py --host 127.0.0.1 --port 25978 +""" + +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import time +from collections import deque + +from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.graphmeta import GraphSnapshot, TopologyChangedEvent +from ezmsg.core.netprotocol import DEFAULT_HOST, GRAPHSERVER_PORT_DEFAULT + + +def _truncate(text: str, width: int) -> str: + if width <= 3: + return text[:width] + if len(text) <= width: + return text + return text[: width - 3] + "..." + + +def _fmt_age(age_s: float) -> str: + return f"{age_s:0.2f}s" + + +def _flatten_edges(snapshot: GraphSnapshot) -> list[tuple[str, str]]: + edges: list[tuple[str, str]] = [] + for src, destinations in snapshot.graph.items(): + for dst in destinations: + edges.append((src, dst)) + edges.sort(key=lambda edge: (edge[0], edge[1])) + return edges + + +class TopologyTUI: + def __init__( + self, + ctx: GraphContext, + *, + snapshot_interval: float, + render_interval: float, + max_edges: int, + max_events: int, + max_processes: int, + ) -> None: + self.ctx = ctx + self.snapshot_interval = max(0.2, snapshot_interval) + self.render_interval = max(0.1, render_interval) + self.max_edges = max(10, max_edges) + self.max_events = max(10, max_events) + self.max_processes = max(5, max_processes) + + self.snapshot: GraphSnapshot | None = None + self.last_snapshot_time: float | None = None + self._events: deque[TopologyChangedEvent] = deque(maxlen=self.max_events) + self._event_queue: asyncio.Queue[TopologyChangedEvent] = asyncio.Queue() + + self._watch_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + await self._refresh_snapshot() + self._watch_task = asyncio.create_task(self._watch_topology_events()) + + async def close(self) -> None: + if self._watch_task is not None: + self._watch_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._watch_task + + async def _watch_topology_events(self) -> None: + after_seq = 0 + async for event in self.ctx.subscribe_topology_events(after_seq=after_seq): + after_seq = event.seq + await self._event_queue.put(event) + + async def _refresh_snapshot(self) -> None: + self.snapshot = await self.ctx.snapshot() + self.last_snapshot_time = time.time() + + async def update(self) -> int: + """ + Drain queued topology events and refresh snapshot if needed. + + Returns: + Number of drained events. + """ + drained = 0 + refresh_requested = False + + while True: + try: + event = self._event_queue.get_nowait() + except asyncio.QueueEmpty: + break + self._events.append(event) + refresh_requested = True + drained += 1 + + if self.last_snapshot_time is None: + await self._refresh_snapshot() + elif refresh_requested or (time.time() - self.last_snapshot_time) >= self.snapshot_interval: + await self._refresh_snapshot() + + return drained + + def render(self, drained_events: int) -> None: + print("\x1bc", end="") + print("ezmsg topology tui") + print("Ctrl-C to quit") + print( + f"snapshot_interval={self.snapshot_interval:.2f}s " + f"render_interval={self.render_interval:.2f}s" + ) + if self.last_snapshot_time is not None: + print(f"snapshot_age={_fmt_age(max(0.0, time.time() - self.last_snapshot_time))}") + if drained_events > 0: + print(f"applied_events={drained_events}") + + snapshot = self.snapshot + if snapshot is None: + print("\n") + return + + edges = _flatten_edges(snapshot) + node_names = set(snapshot.graph.keys()) + for _, dst in edges: + node_names.add(dst) + + print( + "\nsummary: " + f"nodes={len(node_names)} edges={len(edges)} " + f"sessions={len(snapshot.sessions)} processes={len(snapshot.processes)}" + ) + + print("\nprocesses") + proc_header = f"{'Process':<30} {'PID':>8} {'Host':<24} {'Units':<80}" + print(proc_header) + print("-" * len(proc_header)) + if not snapshot.processes: + print("") + else: + process_items = sorted(snapshot.processes.values(), key=lambda p: p.process_id) + for proc in process_items[: self.max_processes]: + units = ", ".join(proc.units) if proc.units else "-" + print( + f"{_truncate(proc.process_id, 30):<30} " + f"{str(proc.pid) if proc.pid is not None else '-':>8} " + f"{_truncate(proc.host if proc.host is not None else '-', 24):<24} " + f"{_truncate(units, 80):<80}" + ) + if len(process_items) > self.max_processes: + print(f"... {len(process_items) - self.max_processes} more process rows") + + print("\nedges") + edge_header = f"{'From':<48} {'To':<48}" + print(edge_header) + print("-" * len(edge_header)) + if not edges: + print("") + else: + for src, dst in edges[: self.max_edges]: + print(f"{_truncate(src, 48):<48} {_truncate(dst, 48):<48}") + if len(edges) > self.max_edges: + print(f"... {len(edges) - self.max_edges} more edges") + + print("\nrecent topology events") + event_header = ( + f"{'Seq':>6} {'Type':<15} {'Age':>8} {'Topics':<44} " + f"{'Source Session':<38} {'Source Process':<30}" + ) + print(event_header) + print("-" * len(event_header)) + if not self._events: + print("") + else: + now = time.time() + for event in reversed(self._events): + topics = ", ".join(event.changed_topics) if event.changed_topics else "-" + print( + f"{event.seq:>6} " + f"{_truncate(event.event_type.value, 15):<15} " + f"{_fmt_age(max(0.0, now - event.timestamp)):>8} " + f"{_truncate(topics, 44):<44} " + f"{_truncate(event.source_session_id or '-', 38):<38} " + f"{_truncate(event.source_process_id or '-', 30):<30}" + ) + + +def _parse_address(host: str, port: int) -> tuple[str, int]: + return (host, port) + + +async def _run_tui(args: argparse.Namespace) -> None: + async with GraphContext( + _parse_address(args.host, args.port), auto_start=args.auto_start + ) as ctx: + tui = TopologyTUI( + ctx, + snapshot_interval=args.snapshot_interval, + render_interval=args.render_interval, + max_edges=args.max_edges, + max_events=args.max_events, + max_processes=args.max_processes, + ) + await tui.start() + try: + while True: + drained = await tui.update() + tui.render(drained_events=drained) + await asyncio.sleep(tui.render_interval) + finally: + await tui.close() + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="ezmsg topology TUI") + parser.add_argument("--host", default=DEFAULT_HOST, help="GraphServer host") + parser.add_argument( + "--port", + type=int, + default=GRAPHSERVER_PORT_DEFAULT, + help="GraphServer port", + ) + parser.add_argument( + "--auto-start", + action="store_true", + help="Allow GraphContext to auto-start GraphServer if unavailable", + ) + parser.add_argument( + "--snapshot-interval", + type=float, + default=1.0, + help="Seconds between forced snapshot refreshes", + ) + parser.add_argument( + "--render-interval", + type=float, + default=0.5, + help="Seconds between screen redraws", + ) + parser.add_argument( + "--max-edges", + type=int, + default=50, + help="Max edge rows to render", + ) + parser.add_argument( + "--max-events", + type=int, + default=25, + help="Max recent topology events to retain/render", + ) + parser.add_argument( + "--max-processes", + type=int, + default=20, + help="Max process rows to render", + ) + return parser + + +def main() -> None: + parser = _build_parser() + args = parser.parse_args() + asyncio.run(_run_tui(args)) + + +if __name__ == "__main__": + main() diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index 03ef8124..ae107852 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -35,6 +35,7 @@ ProfilingTraceControl, SettingsChangedEvent, SettingsSnapshotValue, + TopologyChangedEvent, ) logger = logging.getLogger("ezmsg") @@ -423,6 +424,37 @@ async def subscribe_settings_events( finally: await close_stream_writer(writer) + async def subscribe_topology_events( + self, + *, + after_seq: int = 0, + ) -> typing.AsyncIterator[TopologyChangedEvent]: + reader, writer = await GraphService(self.graph_address).open_connection() + writer.write(Command.SESSION_TOPOLOGY_SUBSCRIBE.value) + writer.write(encode_str(str(after_seq))) + await writer.drain() + + _subscriber_id = UUID(await read_str(reader)) + response = await reader.read(1) + if response != Command.COMPLETE.value: + await close_stream_writer(writer) + raise RuntimeError("Failed to subscribe to topology events") + + try: + while True: + payload_size = await read_int(reader) + payload = await reader.readexactly(payload_size) + event = pickle.loads(payload) + if not isinstance(event, TopologyChangedEvent): + raise RuntimeError( + "Topology subscription received invalid event payload" + ) + yield event + except asyncio.IncompleteReadError: + return + finally: + await close_stream_writer(writer) + async def subscribe_profiling_trace( self, *, diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index 00d5be2c..97f21bb1 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -162,6 +162,21 @@ class SettingsChangedEvent: value: SettingsSnapshotValue +class TopologyEventType(enum.Enum): + GRAPH_CHANGED = "GRAPH_CHANGED" + PROCESS_CHANGED = "PROCESS_CHANGED" + + +@dataclass +class TopologyChangedEvent: + seq: int + event_type: TopologyEventType + timestamp: float + changed_topics: list[str] + source_session_id: str | None + source_process_id: str | None + + @dataclass class ProcessSettingsUpdate: process_id: str diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 6b18ca87..2584490e 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -28,6 +28,8 @@ SettingsChangedEvent, SettingsEventType, SettingsSnapshotValue, + TopologyChangedEvent, + TopologyEventType, SnapshotProcess, SnapshotSession, ) @@ -94,6 +96,9 @@ class GraphServer(threading.Thread): _settings_event_seq: int _settings_owned_by_session: dict[UUID, set[str]] _settings_subscribers: dict[UUID, asyncio.Queue[SettingsChangedEvent]] + _topology_events: list[TopologyChangedEvent] + _topology_event_seq: int + _topology_subscribers: dict[UUID, asyncio.Queue[TopologyChangedEvent]] _pending_process_requests: dict[ str, tuple[UUID, "asyncio.Future[ProcessControlResponse]"] ] @@ -118,6 +123,9 @@ def __init__(self, **kwargs) -> None: self._settings_event_seq = 0 self._settings_owned_by_session = {} self._settings_subscribers = {} + self._topology_events = [] + self._topology_event_seq = 0 + self._topology_subscribers = {} self._pending_process_requests = {} @property @@ -330,6 +338,22 @@ async def api( # to avoid closing writer return + elif req == Command.SESSION_TOPOLOGY_SUBSCRIBE.value: + subscriber_id = uuid1() + after_seq = int(await read_str(reader)) + writer.write(encode_str(str(subscriber_id))) + writer.write(Command.COMPLETE.value) + await writer.drain() + self._client_tasks[subscriber_id] = asyncio.create_task( + self._handle_topology_subscriber( + subscriber_id, after_seq, reader, writer + ) + ) + + # NOTE: Created a stream client, must return early + # to avoid closing writer + return + elif req == Command.SESSION_PROFILING_SUBSCRIBE.value: subscriber_id = uuid1() interval = float(await read_str(reader)) @@ -422,6 +446,12 @@ async def api( writer.write(Command.CYCLIC.value) if topology_changed: + self._append_topology_event_locked( + event_type=TopologyEventType.GRAPH_CHANGED, + changed_topics=[to_topic], + source_session_id=None, + source_process_id=None, + ) await self._notify_downstream_for_topic(to_topic) await writer.drain() @@ -663,8 +693,19 @@ async def _handle_process( ), ) ) - - self.clients.pop(process_client_id, None) + if process_info is not None: + source_process_id = ( + process_info.process_id + if process_info.process_id is not None + else str(process_client_id) + ) + self._append_topology_event_locked( + event_type=TopologyEventType.PROCESS_CHANGED, + changed_topics=[], + source_session_id=None, + source_process_id=source_process_id, + ) + self.clients.pop(process_client_id, None) self._client_tasks.pop(process_client_id, None) await close_stream_writer(writer) @@ -697,6 +738,18 @@ def _queue_settings_event( with suppress(asyncio.QueueFull): queue.put_nowait(event) + def _queue_topology_event( + self, queue: asyncio.Queue[TopologyChangedEvent], event: TopologyChangedEvent + ) -> None: + try: + queue.put_nowait(event) + except asyncio.QueueFull: + # Keep most recent samples under backpressure. + with suppress(asyncio.QueueEmpty): + queue.get_nowait() + with suppress(asyncio.QueueFull): + queue.put_nowait(event) + async def _settings_sender( self, subscriber_id: UUID, @@ -751,6 +804,60 @@ async def _handle_settings_subscriber( await sender_task await close_stream_writer(writer) + async def _topology_sender( + self, + subscriber_id: UUID, + queue: asyncio.Queue[TopologyChangedEvent], + writer: asyncio.StreamWriter, + ) -> None: + try: + while True: + event = await queue.get() + payload = pickle.dumps(event) + writer.write(uint64_to_bytes(len(payload))) + writer.write(payload) + await writer.drain() + except (ConnectionResetError, BrokenPipeError): + logger.debug(f"Topology subscriber {subscriber_id} disconnected on send") + except asyncio.CancelledError: + raise + + async def _handle_topology_subscriber( + self, + subscriber_id: UUID, + after_seq: int, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> None: + queue: asyncio.Queue[TopologyChangedEvent] = asyncio.Queue(maxsize=1024) + + async with self._command_lock: + self._topology_subscribers[subscriber_id] = queue + for event in self._topology_events: + if event.seq > after_seq: + self._queue_topology_event(queue, event) + + sender_task = asyncio.create_task( + self._topology_sender(subscriber_id, queue, writer), + name=f"topology-sender-{subscriber_id}", + ) + + try: + while True: + req = await reader.read(1) + if not req: + break + except (ConnectionResetError, BrokenPipeError) as e: + logger.debug(f"Topology subscriber {subscriber_id} disconnected: {e}") + finally: + async with self._command_lock: + self._topology_subscribers.pop(subscriber_id, None) + self._client_tasks.pop(subscriber_id, None) + sender_task.cancel() + with suppress(asyncio.CancelledError): + await sender_task + await close_stream_writer(writer) + async def _profiling_route_targets(self) -> list[tuple[str, str]]: targets: list[tuple[str, str]] = [] async with self._command_lock: @@ -865,10 +972,18 @@ async def _handle_process_register_request( if process_info is None: return Command.COMPLETE.value + prev_units = set(process_info.units) process_info.process_id = registration.process_id process_info.pid = registration.pid process_info.host = registration.host process_info.units = set(registration.units) + if prev_units != process_info.units: + self._append_topology_event_locked( + event_type=TopologyEventType.PROCESS_CHANGED, + changed_topics=[], + source_session_id=None, + source_process_id=registration.process_id, + ) return Command.COMPLETE.value @@ -914,8 +1029,16 @@ async def _handle_process_update_ownership_request( elif process_info.process_id is None: process_info.process_id = update.process_id + prev_units = set(process_info.units) process_info.units.update(update.added_units) process_info.units.difference_update(update.removed_units) + if prev_units != process_info.units: + self._append_topology_event_locked( + event_type=TopologyEventType.PROCESS_CHANGED, + changed_topics=[], + source_session_id=None, + source_process_id=update.process_id, + ) return Command.COMPLETE.value @@ -1168,6 +1291,32 @@ def _append_settings_event_locked( if len(self._settings_events) > max_events: del self._settings_events[0 : len(self._settings_events) - max_events] + def _append_topology_event_locked( + self, + event_type: TopologyEventType, + changed_topics: list[str], + source_session_id: str | None, + source_process_id: str | None, + timestamp: float | None = None, + ) -> None: + self._topology_event_seq += 1 + event = TopologyChangedEvent( + seq=self._topology_event_seq, + event_type=event_type, + timestamp=timestamp if timestamp is not None else time.time(), + changed_topics=sorted(set(changed_topics)), + source_session_id=source_session_id, + source_process_id=source_process_id, + ) + self._topology_events.append(event) + + for queue in self._topology_subscribers.values(): + self._queue_topology_event(queue, event) + + max_events = 10_000 + if len(self._topology_events) > max_events: + del self._topology_events[0 : len(self._topology_events) - max_events] + def _remove_settings_for_session_locked(self, session_id: UUID) -> None: component_addresses = self._settings_owned_by_session.pop(session_id, set()) for component_address in component_addresses: @@ -1220,6 +1369,12 @@ async def _handle_session_edge_request( return Command.CYCLIC.value if topology_changed: + self._append_topology_event_locked( + event_type=TopologyEventType.GRAPH_CHANGED, + changed_topics=[to_topic], + source_session_id=str(session_id), + source_process_id=None, + ) await self._notify_downstream_for_topic(to_topic) return Command.COMPLETE.value @@ -1302,6 +1457,13 @@ def _clear_session_state(self, session_id: UUID) -> set[str]: self._remove_settings_for_session_locked(session_id) session.metadata = None + if notify_topics: + self._append_topology_event_locked( + event_type=TopologyEventType.GRAPH_CHANGED, + changed_topics=list(notify_topics), + source_session_id=str(session_id), + source_process_id=None, + ) return notify_topics def _drop_session(self, session_id: UUID) -> set[str]: @@ -1317,6 +1479,13 @@ def _drop_session(self, session_id: UUID) -> set[str]: self._remove_settings_for_session_locked(session_id) session.metadata = None self.clients.pop(session_id, None) + if notify_topics: + self._append_topology_event_locked( + event_type=TopologyEventType.GRAPH_CHANGED, + changed_topics=list(notify_topics), + source_session_id=str(session_id), + source_process_id=None, + ) return notify_topics def _snapshot(self) -> GraphSnapshot: diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index 33f0833b..22401658 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -334,6 +334,7 @@ def _generate_next_value_(name, start, count, last_values) -> bytes: SESSION_SETTINGS_SNAPSHOT = enum.auto() SESSION_SETTINGS_EVENTS = enum.auto() SESSION_SETTINGS_SUBSCRIBE = enum.auto() + SESSION_TOPOLOGY_SUBSCRIBE = enum.auto() SESSION_PROFILING_SUBSCRIBE = enum.auto() SESSION_PROCESS_REQUEST = enum.auto() diff --git a/tests/test_topology_api.py b/tests/test_topology_api.py new file mode 100644 index 00000000..bd4b66a4 --- /dev/null +++ b/tests/test_topology_api.py @@ -0,0 +1,103 @@ +import asyncio + +import pytest + +from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.graphmeta import TopologyChangedEvent, TopologyEventType +from ezmsg.core.graphserver import GraphService +from ezmsg.core.processclient import ProcessControlClient + + +async def _next_matching_event( + stream, predicate, timeout: float = 1.0 +) -> TopologyChangedEvent: + async def _wait() -> TopologyChangedEvent: + while True: + event = await anext(stream) + if predicate(event): + return event + + return await asyncio.wait_for(_wait(), timeout=timeout) + + +@pytest.mark.asyncio +async def test_topology_subscription_reports_session_edge_changes(): + graph_server = GraphService().create_server() + address = graph_server.address + + owner = GraphContext(address, auto_start=False) + observer = GraphContext(address, auto_start=False) + + await owner.__aenter__() + await observer.__aenter__() + + stream = observer.subscribe_topology_events(after_seq=0) + try: + await owner.connect("SRC", "DST") + event = await _next_matching_event( + stream, + lambda e: ( + e.event_type == TopologyEventType.GRAPH_CHANGED + and "DST" in e.changed_topics + ), + timeout=1.0, + ) + assert event.source_session_id == str(owner._session_id) + + await owner.disconnect("SRC", "DST") + event = await _next_matching_event( + stream, + lambda e: ( + e.event_type == TopologyEventType.GRAPH_CHANGED + and "DST" in e.changed_topics + ), + timeout=1.0, + ) + assert event.source_session_id == str(owner._session_id) + finally: + await stream.aclose() + await owner.__aexit__(None, None, None) + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_topology_subscription_reports_process_changes(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + process = ProcessControlClient(address, process_id="proc-topology") + stream = observer.subscribe_topology_events(after_seq=0) + + try: + await process.connect() + await process.register(["SYS/U1"]) + + registered = await _next_matching_event( + stream, + lambda e: ( + e.event_type == TopologyEventType.PROCESS_CHANGED + and e.source_process_id == "proc-topology" + ), + timeout=1.0, + ) + assert registered.source_session_id is None + + await process.update_ownership(added_units=["SYS/U2"], removed_units=["SYS/U1"]) + updated = await _next_matching_event( + stream, + lambda e: ( + e.event_type == TopologyEventType.PROCESS_CHANGED + and e.source_process_id == "proc-topology" + ), + timeout=1.0, + ) + assert updated.source_session_id is None + finally: + await stream.aclose() + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop() From e183452398c59834cc0f64fb8d6345c94998c737 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 16 Mar 2026 15:51:00 -0400 Subject: [PATCH 18/52] more flexible profile trace configurability --- examples/profiling_tui.py | 12 ++- src/ezmsg/core/graphcontext.py | 13 ++-- src/ezmsg/core/graphmeta.py | 13 ++++ src/ezmsg/core/graphserver.py | 57 ++++++++++++--- src/ezmsg/core/profiling.py | 115 ++++++++++++++++++++++++----- tests/test_profiling_api.py | 129 ++++++++++++++++++++++++++++++++- 6 files changed, 299 insertions(+), 40 deletions(-) diff --git a/examples/profiling_tui.py b/examples/profiling_tui.py index 13f230c6..99059630 100644 --- a/examples/profiling_tui.py +++ b/examples/profiling_tui.py @@ -20,7 +20,11 @@ from dataclasses import dataclass from ezmsg.core.graphcontext import GraphContext -from ezmsg.core.graphmeta import ProcessProfilingSnapshot, ProfilingTraceControl +from ezmsg.core.graphmeta import ( + ProcessProfilingSnapshot, + ProfilingStreamControl, + ProfilingTraceControl, +) from ezmsg.core.netprotocol import DEFAULT_HOST, GRAPHSERVER_PORT_DEFAULT @@ -167,8 +171,10 @@ async def _refresh_snapshot(self) -> None: async def _trace_loop(self) -> None: async for batch in self.ctx.subscribe_profiling_trace( - interval=self.trace_interval, - max_samples=self.trace_max_samples, + ProfilingStreamControl( + interval=self.trace_interval, + max_samples=self.trace_max_samples, + ) ): for process_batch in batch.batches.values(): for sample in process_batch.samples: diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index ae107852..b0a2f924 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -30,6 +30,7 @@ ProcessProfilingSnapshot, ProcessProfilingTraceBatch, ProfilingTraceStreamBatch, + ProfilingStreamControl, ProcessStats, ProcessControlResponse, ProfilingTraceControl, @@ -457,14 +458,16 @@ async def subscribe_topology_events( async def subscribe_profiling_trace( self, - *, - interval: float = 0.05, - max_samples: int = 1000, + control: ProfilingStreamControl, ) -> typing.AsyncIterator[ProfilingTraceStreamBatch]: + """ + Subscribe to streamed profiling trace batches from GraphServer. + """ reader, writer = await GraphService(self.graph_address).open_connection() + payload = pickle.dumps(control) writer.write(Command.SESSION_PROFILING_SUBSCRIBE.value) - writer.write(encode_str(str(interval))) - writer.write(encode_str(str(max_samples))) + writer.write(uint64_to_bytes(len(payload))) + writer.write(payload) await writer.drain() _subscriber_id = UUID(await read_str(reader)) diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index 97f21bb1..13005510 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -295,6 +295,10 @@ class ProfilingTraceControl: sample_mod: int = 1 publisher_topics: list[str] | None = None subscriber_topics: list[str] | None = None + publisher_endpoint_ids: list[str] | None = None + subscriber_endpoint_ids: list[str] | None = None + metrics: list[str] | None = None + ttl_seconds: float | None = None @dataclass @@ -322,6 +326,15 @@ class ProfilingTraceStreamBatch: batches: dict[str, ProcessProfilingTraceBatch] +@dataclass +class ProfilingStreamControl: + interval: float = 0.05 + max_samples: int = 1000 + process_ids: list[str] | None = None + timeout_per_process: float = 0.25 + include_empty_batches: bool = False + + class Edge(NamedTuple): from_topic: str to_topic: str diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 2584490e..c86f2b1f 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -20,6 +20,7 @@ GraphSnapshot, ProcessProfilingTraceBatch, ProfilingTraceStreamBatch, + ProfilingStreamControl, ProcessControlRequest, ProcessControlResponse, ProcessRegistration, @@ -356,16 +357,14 @@ async def api( elif req == Command.SESSION_PROFILING_SUBSCRIBE.value: subscriber_id = uuid1() - interval = float(await read_str(reader)) - max_samples = int(await read_str(reader)) + stream_control = await self._read_profiling_stream_control(reader) writer.write(encode_str(str(subscriber_id))) writer.write(Command.COMPLETE.value) await writer.drain() self._client_tasks[subscriber_id] = asyncio.create_task( self._handle_profiling_subscriber( subscriber_id, - max(0.01, interval), - max(1, max_samples), + stream_control, reader, writer, ) @@ -869,16 +868,48 @@ async def _profiling_route_targets(self) -> list[tuple[str, str]]: process_id = info.process_id if info.process_id is not None else str(client_id) route_unit = sorted(info.units)[0] targets.append((process_id, route_unit)) - return targets + return sorted(targets, key=lambda item: item[0]) + + async def _read_profiling_stream_control( + self, reader: asyncio.StreamReader + ) -> ProfilingStreamControl: + payload_size = await read_int(reader) + payload = await reader.readexactly(payload_size) + + try: + payload_obj = pickle.loads(payload) + except Exception as exc: + raise RuntimeError( + f"Invalid profiling stream control payload: {exc}" + ) from exc + + if not isinstance(payload_obj, ProfilingStreamControl): + raise RuntimeError( + "Invalid profiling stream control payload type: " + f"{type(payload_obj).__name__}" + ) + return payload_obj async def _collect_profiling_trace_stream_batch( self, *, - max_samples: int, - timeout_per_process: float, + stream_control: ProfilingStreamControl, ) -> ProfilingTraceStreamBatch: + process_ids_filter = ( + set(stream_control.process_ids) + if stream_control.process_ids is not None + else None + ) targets = await self._profiling_route_targets() + if process_ids_filter is not None: + targets = [ + (process_id, route_unit) + for process_id, route_unit in targets + if process_id in process_ids_filter + ] batches: dict[str, ProcessProfilingTraceBatch] = {} + max_samples = max(1, int(stream_control.max_samples)) + timeout_per_process = max(0.01, float(stream_control.timeout_per_process)) request_payload = pickle.dumps(max_samples) for process_id, route_unit in targets: @@ -896,7 +927,10 @@ async def _collect_profiling_trace_stream_batch( continue if not isinstance(payload_obj, ProcessProfilingTraceBatch): continue - if len(payload_obj.samples) == 0: + if ( + len(payload_obj.samples) == 0 + and not stream_control.include_empty_batches + ): continue batches[process_id] = payload_obj @@ -908,11 +942,11 @@ async def _collect_profiling_trace_stream_batch( async def _handle_profiling_subscriber( self, subscriber_id: UUID, - interval: float, - max_samples: int, + stream_control: ProfilingStreamControl, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, ) -> None: + interval = max(0.01, float(stream_control.interval)) try: while True: try: @@ -925,8 +959,7 @@ async def _handle_profiling_subscriber( pass batch = await self._collect_profiling_trace_stream_batch( - max_samples=max_samples, - timeout_per_process=max(0.05, interval), + stream_control=stream_control, ) if len(batch.batches) == 0: continue diff --git a/src/ezmsg/core/profiling.py b/src/ezmsg/core/profiling.py index 3537f3d0..94ca0d0c 100644 --- a/src/ezmsg/core/profiling.py +++ b/src/ezmsg/core/profiling.py @@ -106,11 +106,15 @@ class _PublisherMetrics: _inflight: _Rolling = field(default_factory=_Rolling) trace_enabled: bool = False trace_sample_mod: int = 1 + trace_metrics: set[str] | None = None _trace_counter: int = 0 trace_samples: deque[ProfilingTraceSample] = field( default_factory=lambda: deque(maxlen=TRACE_MAX_SAMPLES) ) + def _trace_metric_enabled(self, metric: str) -> bool: + return self.trace_metrics is None or metric in self.trace_metrics + def record_publish(self, ts_ns: int, inflight: int) -> None: self.messages_published_total += 1 self._publish_count.add(ts_ns, 1) @@ -121,7 +125,11 @@ def record_publish(self, ts_ns: int, inflight: int) -> None: self._last_publish_ts_ns = ts_ns self.sample_inflight(ts_ns, inflight) self._trace_counter += 1 - if self.trace_enabled and (self._trace_counter % max(1, self.trace_sample_mod) == 0): + if ( + self.trace_enabled + and self._trace_metric_enabled("publish_delta_ns") + and (self._trace_counter % max(1, self.trace_sample_mod) == 0) + ): self.trace_samples.append( ProfilingTraceSample( timestamp=float(PROFILE_TIME()), @@ -135,7 +143,7 @@ def record_publish(self, ts_ns: int, inflight: int) -> None: def record_backpressure_wait(self, ts_ns: int, wait_ns: int) -> None: self.backpressure_wait_ns_total += wait_ns self._backpressure_wait.add(ts_ns, wait_ns) - if self.trace_enabled: + if self.trace_enabled and self._trace_metric_enabled("backpressure_wait_ns"): self.trace_samples.append( ProfilingTraceSample( timestamp=float(PROFILE_TIME()), @@ -183,11 +191,15 @@ class _SubscriberMetrics: _attrib_bp: _Rolling = field(default_factory=_Rolling) trace_enabled: bool = False trace_sample_mod: int = 1 + trace_metrics: set[str] | None = None _trace_counter: int = 0 trace_samples: deque[ProfilingTraceSample] = field( default_factory=lambda: deque(maxlen=TRACE_MAX_SAMPLES) ) + def _trace_metric_enabled(self, metric: str) -> bool: + return self.trace_metrics is None or metric in self.trace_metrics + def record_receive(self, ts_ns: int, lease_ns: int, channel_kind: ProfileChannelType) -> None: self.messages_received_total += 1 self.lease_time_ns_total += lease_ns @@ -195,7 +207,11 @@ def record_receive(self, ts_ns: int, lease_ns: int, channel_kind: ProfileChannel self._recv_count.add(ts_ns, 1) self._lease_time.add(ts_ns, lease_ns) self._trace_counter += 1 - if self.trace_enabled and (self._trace_counter % max(1, self.trace_sample_mod) == 0): + if ( + self.trace_enabled + and self._trace_metric_enabled("lease_time_ns") + and (self._trace_counter % max(1, self.trace_sample_mod) == 0) + ): self.trace_samples.append( ProfilingTraceSample( timestamp=float(PROFILE_TIME()), @@ -210,7 +226,7 @@ def record_receive(self, ts_ns: int, lease_ns: int, channel_kind: ProfileChannel def record_user_span(self, ts_ns: int, span_ns: int, label: str | None) -> None: self.user_span_ns_total += span_ns self._user_span.add(ts_ns, span_ns) - if self.trace_enabled: + if self.trace_enabled and self._trace_metric_enabled("user_span_ns"): self.trace_samples.append( ProfilingTraceSample( timestamp=float(PROFILE_TIME()), @@ -229,6 +245,17 @@ def record_attributed_backpressure( self.attributable_backpressure_events_total += 1 self.channel_kind_last = channel_kind self._attrib_bp.add(ts_ns, duration_ns) + if self.trace_enabled and self._trace_metric_enabled("attributable_backpressure_ns"): + self.trace_samples.append( + ProfilingTraceSample( + timestamp=float(PROFILE_TIME()), + endpoint_id=self.endpoint_id, + topic=self.topic, + metric="attributable_backpressure_ns", + value=float(duration_ns), + channel_kind=channel_kind, + ) + ) def snapshot(self) -> SubscriberProfileSnapshot: recv_count = self._recv_count.count_total() @@ -262,38 +289,46 @@ def __init__(self) -> None: self._publishers: dict[UUID, _PublisherMetrics] = {} self._subscribers: dict[UUID, _SubscriberMetrics] = {} self._default_trace_control = ProfilingTraceControl(enabled=False) + self._trace_control_expires_ns: int | None = None def set_process_id(self, process_id: str, *, reset: bool = False) -> None: if reset or (self._process_id and self._process_id != process_id): self._publishers.clear() self._subscribers.clear() self._default_trace_control = ProfilingTraceControl(enabled=False) + self._trace_control_expires_ns = None self._process_id = process_id def register_publisher(self, pub_id: UUID, topic: str) -> None: - self._publishers[pub_id] = _PublisherMetrics( + metric = _PublisherMetrics( topic=topic, endpoint_id=_endpoint_id(topic, pub_id), ) + self._publishers[pub_id] = metric + self._apply_trace_control_to_publisher(metric) def unregister_publisher(self, pub_id: UUID) -> None: self._publishers.pop(pub_id, None) def register_subscriber(self, sub_id: UUID, topic: str) -> None: - self._subscribers[sub_id] = _SubscriberMetrics( + metric = _SubscriberMetrics( topic=topic, endpoint_id=_endpoint_id(topic, sub_id), ) + self._subscribers[sub_id] = metric + self._apply_trace_control_to_subscriber(metric) def unregister_subscriber(self, sub_id: UUID) -> None: self._subscribers.pop(sub_id, None) def publisher_publish(self, pub_id: UUID, ts_ns: int, inflight: int) -> None: + self._expire_trace_control_if_needed(ts_ns) metric = self._publishers.get(pub_id) if metric is not None: metric.record_publish(ts_ns, inflight) def publisher_backpressure_wait(self, pub_id: UUID, ts_ns: int, wait_ns: int) -> None: + self._expire_trace_control_if_needed(ts_ns) metric = self._publishers.get(pub_id) if metric is not None: metric.record_backpressure_wait(ts_ns, wait_ns) @@ -310,6 +345,7 @@ def subscriber_receive( lease_ns: int, channel_kind: ProfileChannelType, ) -> None: + self._expire_trace_control_if_needed(ts_ns) metric = self._subscribers.get(sub_id) if metric is not None: metric.record_receive(ts_ns, lease_ns, channel_kind) @@ -317,6 +353,7 @@ def subscriber_receive( def subscriber_user_span( self, sub_id: UUID, ts_ns: int, span_ns: int, label: str | None ) -> None: + self._expire_trace_control_if_needed(ts_ns) metric = self._subscribers.get(sub_id) if metric is not None: metric.record_user_span(ts_ns, span_ns, label) @@ -328,6 +365,7 @@ def subscriber_attributed_backpressure( duration_ns: int, channel_kind: ProfileChannelType, ) -> None: + self._expire_trace_control_if_needed(ts_ns) metric = self._subscribers.get(sub_id) if metric is not None: metric.record_attributed_backpressure(ts_ns, duration_ns, channel_kind) @@ -351,25 +389,21 @@ def snapshot(self) -> ProcessProfilingSnapshot: def set_trace_control(self, control: ProfilingTraceControl) -> None: self._default_trace_control = control - sample_mod = max(1, control.sample_mod) - pub_topics = set(control.publisher_topics or []) - sub_topics = set(control.subscriber_topics or []) + if control.enabled and control.ttl_seconds is not None: + self._trace_control_expires_ns = PROFILE_TIME() + max( + 0, int(control.ttl_seconds * 1e9) + ) + else: + self._trace_control_expires_ns = None for metric in self._publishers.values(): - enabled = control.enabled and ( - not pub_topics or metric.topic in pub_topics - ) - metric.trace_enabled = enabled - metric.trace_sample_mod = sample_mod + self._apply_trace_control_to_publisher(metric) for metric in self._subscribers.values(): - enabled = control.enabled and ( - not sub_topics or metric.topic in sub_topics - ) - metric.trace_enabled = enabled - metric.trace_sample_mod = sample_mod + self._apply_trace_control_to_subscriber(metric) def trace_batch(self, max_samples: int = 1000) -> ProcessProfilingTraceBatch: + self._expire_trace_control_if_needed() samples: list[ProfilingTraceSample] = [] for metric in self._publishers.values(): while metric.trace_samples and len(samples) < max_samples: @@ -391,5 +425,48 @@ def trace_batch(self, max_samples: int = 1000) -> ProcessProfilingTraceBatch: samples=samples, ) + def _expire_trace_control_if_needed(self, now_ns: int | None = None) -> None: + expires_ns = self._trace_control_expires_ns + if expires_ns is None: + return + ts_ns = now_ns if now_ns is not None else PROFILE_TIME() + if ts_ns < expires_ns: + return + self.set_trace_control(ProfilingTraceControl(enabled=False)) + + def _apply_trace_control_to_publisher(self, metric: _PublisherMetrics) -> None: + control = self._default_trace_control + sample_mod = max(1, control.sample_mod) + pub_topics = set(control.publisher_topics or []) + pub_endpoint_ids = set(control.publisher_endpoint_ids or []) + trace_metrics = ( + set(control.metrics) if control.metrics is not None else None + ) + enabled = control.enabled + if enabled and pub_topics and metric.topic not in pub_topics: + enabled = False + if enabled and pub_endpoint_ids and metric.endpoint_id not in pub_endpoint_ids: + enabled = False + metric.trace_enabled = enabled + metric.trace_sample_mod = sample_mod + metric.trace_metrics = trace_metrics + + def _apply_trace_control_to_subscriber(self, metric: _SubscriberMetrics) -> None: + control = self._default_trace_control + sample_mod = max(1, control.sample_mod) + sub_topics = set(control.subscriber_topics or []) + sub_endpoint_ids = set(control.subscriber_endpoint_ids or []) + trace_metrics = ( + set(control.metrics) if control.metrics is not None else None + ) + enabled = control.enabled + if enabled and sub_topics and metric.topic not in sub_topics: + enabled = False + if enabled and sub_endpoint_ids and metric.endpoint_id not in sub_endpoint_ids: + enabled = False + metric.trace_enabled = enabled + metric.trace_sample_mod = sample_mod + metric.trace_metrics = trace_metrics + PROFILES = ProfileRegistry() diff --git a/tests/test_profiling_api.py b/tests/test_profiling_api.py index 49724588..2d8ef51f 100644 --- a/tests/test_profiling_api.py +++ b/tests/test_profiling_api.py @@ -5,6 +5,7 @@ from ezmsg.core.graphcontext import GraphContext from ezmsg.core.graphmeta import ( ProcessControlErrorCode, + ProfilingStreamControl, ProfilingTraceControl, ) from ezmsg.core.graphserver import GraphService @@ -166,7 +167,9 @@ async def test_process_profiling_trace_subscription_push(): ) assert response.ok - stream = ctx.subscribe_profiling_trace(interval=0.02, max_samples=256) + stream = ctx.subscribe_profiling_trace( + ProfilingStreamControl(interval=0.02, max_samples=256) + ) for idx in range(8): await pub.broadcast(idx) @@ -189,3 +192,127 @@ async def test_process_profiling_trace_subscription_push(): await process.close() await ctx.__aexit__(None, None, None) graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_profiling_trace_control_endpoint_metric_and_ttl(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address, process_id="proc-trace-filter") + await process.connect() + await process.register(["SYS/U5"]) + + pub_a = await ctx.publisher("TOPIC_A") + sub_a = await ctx.subscriber("TOPIC_A") + pub_b = await ctx.publisher("TOPIC_B") + sub_b = await ctx.subscriber("TOPIC_B") + + try: + # Warm up and discover endpoint IDs for precise filter targeting. + for idx in range(3): + await pub_a.broadcast(idx) + async with sub_a.recv_zero_copy() as _msg: + await asyncio.sleep(0) + await pub_b.broadcast(idx) + async with sub_b.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + snapshot = await ctx.process_profiling_snapshot("SYS/U5", timeout=1.0) + pub_a_endpoint = next( + pub.endpoint_id + for pub in snapshot.publishers.values() + if pub.topic == "TOPIC_A" + ) + + response = await ctx.process_set_profiling_trace( + "SYS/U5", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_endpoint_ids=[pub_a_endpoint], + metrics=["publish_delta_ns"], + ), + timeout=1.0, + ) + assert response.ok + + for idx in range(8): + await pub_a.broadcast(idx) + async with sub_a.recv_zero_copy() as _msg: + await asyncio.sleep(0) + await pub_b.broadcast(idx) + async with sub_b.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + batch = await ctx.process_profiling_trace_batch( + "SYS/U5", max_samples=512, timeout=1.0 + ) + assert len(batch.samples) > 0 + assert all(sample.metric == "publish_delta_ns" for sample in batch.samples) + assert all(sample.endpoint_id == pub_a_endpoint for sample in batch.samples) + + ttl_response = await ctx.process_set_profiling_trace( + "SYS/U5", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_endpoint_ids=[pub_a_endpoint], + metrics=["publish_delta_ns"], + ttl_seconds=0.01, + ), + timeout=1.0, + ) + assert ttl_response.ok + await asyncio.sleep(0.03) + + for idx in range(3): + await pub_a.broadcast(idx) + async with sub_a.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + expired_batch = await ctx.process_profiling_trace_batch( + "SYS/U5", max_samples=512, timeout=1.0 + ) + assert len(expired_batch.samples) == 0 + finally: + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_profiling_trace_subscription_stream_control(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process_a = ProcessControlClient(address, process_id="proc-stream-a") + await process_a.connect() + await process_a.register(["SYS/U6"]) + + stream = None + try: + stream = ctx.subscribe_profiling_trace( + ProfilingStreamControl( + interval=0.02, + max_samples=64, + process_ids=["proc-stream-a"], + include_empty_batches=True, + timeout_per_process=0.1, + ) + ) + batch = await asyncio.wait_for(anext(stream), timeout=1.0) + assert "proc-stream-a" in batch.batches + assert len(batch.batches) == 1 + finally: + if stream is not None: + await stream.aclose() + await process_a.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() From 0f906681f801e02c6a1f2d6fe7ab736116c9fdad Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 16 Mar 2026 16:57:35 -0400 Subject: [PATCH 19/52] GraphContext settings control APIs --- src/ezmsg/core/backendprocess.py | 150 +++++++++++++++++++++++++++++- src/ezmsg/core/graphcontext.py | 84 +++++++++++++++++ src/ezmsg/core/graphmeta.py | 8 +- src/ezmsg/core/graphserver.py | 155 +++++++++++++++++++++---------- src/ezmsg/core/netprotocol.py | 1 + src/ezmsg/core/processclient.py | 66 +++++++++++++ src/ezmsg/core/profiling.py | 4 + tests/test_profiling_api.py | 1 - tests/test_settings_api.py | 87 +++++++++++++++++ 9 files changed, 503 insertions(+), 53 deletions(-) diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 1c3ed487..0f445a90 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -9,7 +9,7 @@ import weakref from abc import abstractmethod -from dataclasses import asdict, dataclass, is_dataclass +from dataclasses import asdict, dataclass, fields as dataclass_fields, is_dataclass, replace from collections import defaultdict from collections.abc import Awaitable, Callable, Coroutine, Generator, Sequence from functools import wraps, partial @@ -26,7 +26,14 @@ from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR from .graphcontext import GraphContext -from .graphmeta import SettingsSnapshotValue +from .graphmeta import ( + ProcessControlErrorCode, + ProcessControlOperation, + ProcessControlRequest, + ProcessControlResponse, + SettingsFieldUpdateRequest, + SettingsSnapshotValue, +) from .profiling import PROFILES, PROFILE_TIME from .processclient import ProcessControlClient from .pubclient import Publisher @@ -239,6 +246,35 @@ def _settings_snapshot_value(self, value: object) -> SettingsSnapshotValue: return SettingsSnapshotValue(serialized=serialized, repr_value=repr(value)) + def _replace_settings_field( + self, settings_value: object, field_path: str, value: object + ) -> object: + if field_path == "": + raise ValueError("field_path must not be empty") + path = field_path.split(".") + + def apply(current: object, idx: int) -> object: + if not is_dataclass(current): + raise TypeError( + "Cannot patch non-dataclass settings value at " + f"'{'.'.join(path[:idx])}'" + ) + field_name = path[idx] + valid_fields = {f.name for f in dataclass_fields(current)} + if field_name not in valid_fields: + raise AttributeError( + f"Settings field '{field_name}' does not exist on " + f"{type(current).__name__}" + ) + if idx == len(path) - 1: + return replace(current, **{field_name: value}) + + child_value = getattr(current, field_name) + patched_child = apply(child_value, idx + 1) + return replace(current, **{field_name: patched_child}) + + return apply(settings_value, 0) + def process(self, loop: asyncio.AbstractEventLoop) -> None: main_func = None context = GraphContext(self.graph_address) @@ -246,8 +282,114 @@ def process(self, loop: asyncio.AbstractEventLoop) -> None: PROFILES.set_process_id(process_client.process_id) process_register_future: concurrent.futures.Future[None] | None = None coro_callables: dict[str, Callable[[], Coroutine[Any, Any, None]]] = dict() + settings_input_topics: dict[str, str] = {} + current_settings: dict[str, object] = {} + control_publishers: dict[str, Publisher] = {} self._shutdown_errors = False + async def process_request_handler( + request: ProcessControlRequest, + ) -> ProcessControlResponse: + if request.operation != ProcessControlOperation.UPDATE_SETTING_FIELD.value: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=f"Unsupported process control operation: {request.operation}", + error_code=ProcessControlErrorCode.UNSUPPORTED_OPERATION, + process_id=process_client.process_id, + ) + + if request.payload is None: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error="Missing settings field update payload", + error_code=ProcessControlErrorCode.INVALID_RESPONSE, + process_id=process_client.process_id, + ) + + try: + update_obj = pickle.loads(request.payload) + if not isinstance(update_obj, SettingsFieldUpdateRequest): + raise RuntimeError( + "settings field update payload was not SettingsFieldUpdateRequest" + ) + except Exception as exc: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=f"Invalid settings field update payload: {exc}", + error_code=ProcessControlErrorCode.INVALID_RESPONSE, + process_id=process_client.process_id, + ) + + unit_address = request.unit_address + input_topic = settings_input_topics.get(unit_address) + if input_topic is None: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=( + f"Unit '{unit_address}' does not expose INPUT_SETTINGS; " + "settings field update unsupported" + ), + error_code=ProcessControlErrorCode.UNSUPPORTED_OPERATION, + process_id=process_client.process_id, + ) + + if unit_address not in current_settings: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=( + f"No current settings value tracked for unit '{unit_address}'. " + "Send a full settings object first via update_settings()." + ), + error_code=ProcessControlErrorCode.HANDLER_ERROR, + process_id=process_client.process_id, + ) + + try: + patched = self._replace_settings_field( + current_settings[unit_address], + update_obj.field_path, + update_obj.value, + ) + current_settings[unit_address] = patched + control_pub = control_publishers.get(input_topic) + if control_pub is None: + control_pub = await context.publisher(input_topic) + control_publishers[input_topic] = control_pub + except Exception as exc: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=f"Failed to patch settings field: {exc}", + error_code=ProcessControlErrorCode.HANDLER_ERROR, + process_id=process_client.process_id, + ) + + async def publish_patched_settings() -> None: + try: + await control_pub.broadcast(patched) + except Exception as exc: + logger.warning( + "Failed to publish patched settings for %s: %s", + unit_address, + exc, + ) + + asyncio.create_task(publish_patched_settings()) + result_value = self._settings_snapshot_value(patched) + return ProcessControlResponse( + request_id=request.request_id, + ok=True, + payload=pickle.dumps(result_value), + process_id=process_client.process_id, + ) + + process_client.set_request_handler(process_request_handler) + try: self.pubs = dict() @@ -281,6 +423,8 @@ async def setup_state(): main_func = None for unit in self.units: + if unit.SETTINGS is not None: + current_settings[unit.address] = unit.SETTINGS sub_callables: defaultdict[ str, set[Callable[..., Coroutine[Any, Any, None]]] ] = defaultdict(set) @@ -311,12 +455,14 @@ async def setup_state(): ) = None if stream.name == "INPUT_SETTINGS": component_address = unit.address + settings_input_topics[component_address] = stream.address async def report_settings_update_cb( msg: object, *, _component_address: str = component_address, ) -> None: + current_settings[_component_address] = msg value = self._settings_snapshot_value(msg) await process_client.report_settings_update( component_address=_component_address, diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index b0a2f924..1cdb494c 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -34,6 +34,7 @@ ProcessStats, ProcessControlResponse, ProfilingTraceControl, + SettingsFieldUpdateRequest, SettingsChangedEvent, SettingsSnapshotValue, TopologyChangedEvent, @@ -394,6 +395,89 @@ async def settings_events(self, after_seq: int = 0) -> list[SettingsChangedEvent raise RuntimeError("Settings event payload contained invalid entries") return events + async def settings_input_topic(self, component_address: str) -> str: + """ + Resolve the dynamic settings input topic for a component. + + The topic is discovered from currently registered session metadata. + Raises if the component is missing, does not opt in to dynamic settings, + or appears with conflicting dynamic settings topics. + """ + snapshot = await self.snapshot() + topics: set[str] = set() + for session in snapshot.sessions.values(): + metadata = session.metadata + if metadata is None: + continue + component = metadata.components.get(component_address) + if component is None: + continue + dynamic_settings = component.dynamic_settings + if dynamic_settings.enabled and dynamic_settings.input_topic is not None: + topics.add(dynamic_settings.input_topic) + + if len(topics) == 1: + return next(iter(topics)) + if len(topics) > 1: + raise RuntimeError( + "Conflicting dynamic settings topics for component " + f"'{component_address}': {sorted(topics)}" + ) + raise RuntimeError( + f"Component '{component_address}' does not expose dynamic settings metadata" + ) + + async def update_settings( + self, + component_address: str, + value: object, + *, + input_topic: str | None = None, + ) -> None: + """ + Publish a settings value to a component's `INPUT_SETTINGS` inlet. + + By default the target topic is resolved from metadata via + :meth:`settings_input_topic`. Supplying `input_topic` bypasses + metadata lookup. + """ + topic = input_topic if input_topic is not None else await self.settings_input_topic( + component_address + ) + pub = await self.publisher(topic) + try: + await pub.broadcast(value) + finally: + pub.close() + await pub.wait_closed() + self._clients.discard(pub) + + async def update_setting( + self, + component_address: str, + field_path: str, + value: object, + *, + timeout: float = 2.0, + ) -> SettingsSnapshotValue: + """ + Patch one field of a unit's current dynamic settings value. + + The patch is routed to the owning backend process, applied in-process + using dataclass replacement, and then published to `INPUT_SETTINGS`. + Returns a snapshot representation of the patched settings value. + """ + response = await self.process_request( + component_address, + ProcessControlOperation.UPDATE_SETTING_FIELD, + payload_obj=SettingsFieldUpdateRequest(field_path=field_path, value=value), + timeout=timeout, + ) + return typing.cast( + SettingsSnapshotValue, + self.decode_process_payload(response, SettingsSnapshotValue), + ) + async def subscribe_settings_events( self, *, diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index 13005510..fc1bd98c 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -199,6 +199,7 @@ class ProcessControlOperation(enum.Enum): GET_PROFILING_SNAPSHOT = "GET_PROFILING_SNAPSHOT" SET_PROFILING_TRACE = "SET_PROFILING_TRACE" GET_PROFILING_TRACE_BATCH = "GET_PROFILING_TRACE_BATCH" + UPDATE_SETTING_FIELD = "UPDATE_SETTING_FIELD" class ProcessControlErrorCode(enum.Enum): @@ -222,6 +223,12 @@ class ProcessControlResponse: process_id: str | None = None +@dataclass +class SettingsFieldUpdateRequest: + field_path: str + value: Any + + @dataclass class ProcessPing: process_id: str @@ -331,7 +338,6 @@ class ProfilingStreamControl: interval: float = 0.05 max_samples: int = 1000 process_ids: list[str] | None = None - timeout_per_process: float = 0.25 include_empty_batches: bool = False diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index c86f2b1f..a578a95d 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -5,6 +5,7 @@ import socket import threading import time +from collections import deque from contextlib import suppress from uuid import UUID, uuid1 @@ -19,6 +20,7 @@ GraphMetadata, GraphSnapshot, ProcessProfilingTraceBatch, + ProfilingTraceSample, ProfilingTraceStreamBatch, ProfilingStreamControl, ProcessControlRequest, @@ -103,6 +105,8 @@ class GraphServer(threading.Thread): _pending_process_requests: dict[ str, tuple[UUID, "asyncio.Future[ProcessControlResponse]"] ] + _profiling_trace_buffers: dict[str, deque[ProfilingTraceSample]] + _profiling_trace_process_meta: dict[str, tuple[int, str]] def __init__(self, **kwargs) -> None: super().__init__( @@ -128,6 +132,8 @@ def __init__(self, **kwargs) -> None: self._topology_event_seq = 0 self._topology_subscribers = {} self._pending_process_requests = {} + self._profiling_trace_buffers = {} + self._profiling_trace_process_meta = {} @property def address(self) -> Address: @@ -651,6 +657,11 @@ async def _handle_process( process_client_id, writer, response ) + elif req == Command.PROCESS_PROFILING_TRACE_UPDATE.value: + await self._handle_process_profiling_trace_update_request( + process_client_id, reader + ) + elif req == Command.PROCESS_ROUTE_RESPONSE.value: await self._handle_process_route_response_request( process_client_id, reader @@ -698,6 +709,8 @@ async def _handle_process( if process_info.process_id is not None else str(process_client_id) ) + self._profiling_trace_buffers.pop(source_process_id, None) + self._profiling_trace_process_meta.pop(source_process_id, None) self._append_topology_event_locked( event_type=TopologyEventType.PROCESS_CHANGED, changed_topics=[], @@ -857,19 +870,6 @@ async def _handle_topology_subscriber( await sender_task await close_stream_writer(writer) - async def _profiling_route_targets(self) -> list[tuple[str, str]]: - targets: list[tuple[str, str]] = [] - async with self._command_lock: - for client_id, info in self.clients.items(): - if not isinstance(info, ProcessInfo): - continue - if len(info.units) == 0: - continue - process_id = info.process_id if info.process_id is not None else str(client_id) - route_unit = sorted(info.units)[0] - targets.append((process_id, route_unit)) - return sorted(targets, key=lambda item: item[0]) - async def _read_profiling_stream_control( self, reader: asyncio.StreamReader ) -> ProfilingStreamControl: @@ -900,44 +900,50 @@ async def _collect_profiling_trace_stream_batch( if stream_control.process_ids is not None else None ) - targets = await self._profiling_route_targets() - if process_ids_filter is not None: - targets = [ - (process_id, route_unit) - for process_id, route_unit in targets - if process_id in process_ids_filter - ] - batches: dict[str, ProcessProfilingTraceBatch] = {} max_samples = max(1, int(stream_control.max_samples)) - timeout_per_process = max(0.01, float(stream_control.timeout_per_process)) - request_payload = pickle.dumps(max_samples) - - for process_id, route_unit in targets: - response = await self._route_process_request( - unit_address=route_unit, - operation=ProcessControlOperation.GET_PROFILING_TRACE_BATCH.value, - payload=request_payload, - timeout=timeout_per_process, - ) - if not response.ok or response.payload is None: - continue - try: - payload_obj = pickle.loads(response.payload) - except Exception: - continue - if not isinstance(payload_obj, ProcessProfilingTraceBatch): - continue - if ( - len(payload_obj.samples) == 0 - and not stream_control.include_empty_batches - ): - continue - batches[process_id] = payload_obj + now_ts = time.time() + batches: dict[str, ProcessProfilingTraceBatch] = {} - return ProfilingTraceStreamBatch( - timestamp=time.time(), - batches=batches, - ) + async with self._command_lock: + connected_processes: dict[str, tuple[int, str]] = {} + for client_id, info in self.clients.items(): + if not isinstance(info, ProcessInfo): + continue + process_id = ( + info.process_id if info.process_id is not None else str(client_id) + ) + pid = info.pid if info.pid is not None else -1 + host = info.host if info.host is not None else "" + connected_processes[process_id] = (pid, host) + + process_ids: list[str] + if process_ids_filter is not None: + process_ids = sorted(process_ids_filter) + else: + process_ids = sorted(connected_processes.keys()) + + for process_id in process_ids: + sample_buffer = self._profiling_trace_buffers.get(process_id) + samples: list[ProfilingTraceSample] = [] + while sample_buffer and len(samples) < max_samples: + samples.append(sample_buffer.popleft()) + + if len(samples) == 0 and not stream_control.include_empty_batches: + continue + + pid, host = connected_processes.get( + process_id, + self._profiling_trace_process_meta.get(process_id, (-1, "")), + ) + batches[process_id] = ProcessProfilingTraceBatch( + process_id=process_id, + pid=pid, + host=host, + timestamp=now_ts, + samples=samples, + ) + + return ProfilingTraceStreamBatch(timestamp=now_ts, batches=batches) async def _handle_profiling_subscriber( self, @@ -972,10 +978,57 @@ async def _handle_profiling_subscriber( logger.debug(f"Profiling subscriber {subscriber_id} disconnected: {e}") except asyncio.CancelledError: raise + except Exception as exc: + logger.error( + "Profiling subscriber %s failed: %s", + subscriber_id, + exc, + ) finally: self._client_tasks.pop(subscriber_id, None) await close_stream_writer(writer) + async def _handle_process_profiling_trace_update_request( + self, process_client_id: UUID, reader: asyncio.StreamReader + ) -> None: + payload_size = await read_int(reader) + payload = await reader.readexactly(payload_size) + batch: ProcessProfilingTraceBatch | None = None + try: + payload_obj = pickle.loads(payload) + if isinstance(payload_obj, ProcessProfilingTraceBatch): + batch = payload_obj + else: + raise RuntimeError( + "process profiling trace payload was not ProcessProfilingTraceBatch" + ) + except Exception as exc: + logger.warning( + "Process control %s trace update parse failed; ignoring payload: %s", + process_client_id, + exc, + ) + + if batch is None: + return + + async with self._command_lock: + process_info = self._process_info(process_client_id) + process_id = ( + ( + process_info.process_id + if process_info is not None and process_info.process_id is not None + else str(process_client_id) + ) + if batch.process_id == "" + else batch.process_id + ) + trace_buffer = self._profiling_trace_buffers.setdefault( + process_id, deque(maxlen=200_000) + ) + trace_buffer.extend(batch.samples) + self._profiling_trace_process_meta[process_id] = (batch.pid, batch.host) + async def _handle_process_register_request( self, process_client_id: UUID, reader: asyncio.StreamReader ) -> bytes: @@ -1006,10 +1059,14 @@ async def _handle_process_register_request( return Command.COMPLETE.value prev_units = set(process_info.units) + prev_process_id = process_info.process_id process_info.process_id = registration.process_id process_info.pid = registration.pid process_info.host = registration.host process_info.units = set(registration.units) + if prev_process_id is not None and prev_process_id != registration.process_id: + self._profiling_trace_buffers.pop(prev_process_id, None) + self._profiling_trace_process_meta.pop(prev_process_id, None) if prev_units != process_info.units: self._append_topology_event_locked( event_type=TopologyEventType.PROCESS_CHANGED, diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index 22401658..13c582b3 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -343,6 +343,7 @@ def _generate_next_value_(name, start, count, last_values) -> bytes: PROCESS_REGISTER = enum.auto() PROCESS_UPDATE_OWNERSHIP = enum.auto() PROCESS_SETTINGS_UPDATE = enum.auto() + PROCESS_PROFILING_TRACE_UPDATE = enum.auto() PROCESS_ROUTE_REQUEST = enum.auto() PROCESS_ROUTE_RESPONSE = enum.auto() diff --git a/src/ezmsg/core/processclient.py b/src/ezmsg/core/processclient.py index f6574b5c..d5279395 100644 --- a/src/ezmsg/core/processclient.py +++ b/src/ezmsg/core/processclient.py @@ -51,6 +51,9 @@ class ProcessControlClient: [ProcessControlRequest], ProcessControlResponse | Awaitable[ProcessControlResponse] ] | None _owned_units: set[str] + _trace_push_task: asyncio.Task[None] | None + _trace_push_interval_s: float + _trace_push_max_samples: int def __init__( self, graph_address: AddressType | None = None, process_id: str | None = None @@ -65,6 +68,13 @@ def __init__( self._io_task = None self._request_handler = None self._owned_units = set() + self._trace_push_task = None + self._trace_push_interval_s = float( + os.environ.get("EZMSG_PROFILE_TRACE_PUSH_INTERVAL_S", "0.05") + ) + self._trace_push_max_samples = int( + os.environ.get("EZMSG_PROFILE_TRACE_PUSH_MAX_SAMPLES", "1000") + ) PROFILES.set_process_id(self._process_id, reset=True) @property @@ -156,6 +166,13 @@ async def close(self) -> None: if writer is None: return + trace_task = self._trace_push_task + self._trace_push_task = None + if trace_task is not None: + trace_task.cancel() + with suppress(asyncio.CancelledError): + await trace_task + io_task = self._io_task self._io_task = None if io_task is not None: @@ -338,6 +355,11 @@ async def _handle_route_request( ) PROFILES.set_trace_control(control) + if control.enabled: + await self._ensure_trace_push_task() + else: + await self._cancel_trace_push_task() + return ProcessControlResponse( request_id=request.request_id, ok=True, @@ -414,3 +436,47 @@ async def _handle_route_request( result.process_id = self._process_id return result + + async def _ensure_trace_push_task(self) -> None: + task = self._trace_push_task + if task is not None and not task.done(): + return + self._trace_push_task = asyncio.create_task( + self._trace_push_loop(), + name=f"proc-trace-push-{self._process_id}", + ) + + async def _cancel_trace_push_task(self) -> None: + task = self._trace_push_task + self._trace_push_task = None + if task is None: + return + task.cancel() + with suppress(asyncio.CancelledError): + await task + + async def _trace_push_loop(self) -> None: + try: + while True: + await asyncio.sleep(max(0.01, self._trace_push_interval_s)) + batch: ProcessProfilingTraceBatch = PROFILES.trace_batch( + max_samples=max(1, self._trace_push_max_samples) + ) + if len(batch.samples) > 0: + await self._write_payload( + Command.PROCESS_PROFILING_TRACE_UPDATE, + batch, + expect_complete=False, + ) + + if not PROFILES.trace_enabled(): + break + except asyncio.CancelledError: + raise + except (ConnectionResetError, BrokenPipeError): + logger.debug("Process trace push loop disconnected") + except Exception as exc: + logger.warning(f"Process trace push loop failed: {exc}") + finally: + if asyncio.current_task() is self._trace_push_task: + self._trace_push_task = None diff --git a/src/ezmsg/core/profiling.py b/src/ezmsg/core/profiling.py index 94ca0d0c..760b377b 100644 --- a/src/ezmsg/core/profiling.py +++ b/src/ezmsg/core/profiling.py @@ -425,6 +425,10 @@ def trace_batch(self, max_samples: int = 1000) -> ProcessProfilingTraceBatch: samples=samples, ) + def trace_enabled(self) -> bool: + self._expire_trace_control_if_needed() + return self._default_trace_control.enabled + def _expire_trace_control_if_needed(self, now_ns: int | None = None) -> None: expires_ns = self._trace_control_expires_ns if expires_ns is None: diff --git a/tests/test_profiling_api.py b/tests/test_profiling_api.py index 2d8ef51f..0449520b 100644 --- a/tests/test_profiling_api.py +++ b/tests/test_profiling_api.py @@ -304,7 +304,6 @@ async def test_process_profiling_trace_subscription_stream_control(): max_samples=64, process_ids=["proc-stream-a"], include_empty_batches=True, - timeout_per_process=0.1, ) ) batch = await asyncio.wait_for(anext(stream), timeout=1.0) diff --git a/tests/test_settings_api.py b/tests/test_settings_api.py index b949a526..2920374a 100644 --- a/tests/test_settings_api.py +++ b/tests/test_settings_api.py @@ -1,4 +1,5 @@ import asyncio +import pickle from dataclasses import dataclass import pytest @@ -9,6 +10,8 @@ ComponentMetadata, DynamicSettingsMetadata, GraphMetadata, + ProcessControlResponse, + SettingsFieldUpdateRequest, SettingsEventType, SettingsSnapshotValue, ) @@ -101,6 +104,13 @@ def network(self) -> ez.NetworkDefinition: return ((self.SRC.OUTPUT, self.SINK.INPUT_SETTINGS),) +class _SettingsOnlySystem(ez.Collection): + SINK = _SettingsSink() + + def network(self) -> ez.NetworkDefinition: + return () + + def test_input_settings_hook_reports_to_graphserver(): graph_server = GraphService().create_server() address = graph_server.address @@ -132,6 +142,83 @@ async def observe() -> None: graph_server.stop() +@pytest.mark.asyncio +async def test_graphcontext_update_settings_via_input_settings_topic(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + run_task = asyncio.create_task( + asyncio.to_thread( + ez.run, + components={"SYS": _SettingsOnlySystem()}, + graph_address=address, + force_single_process=True, + ) + ) + + try: + for _ in range(40): + try: + await observer.settings_input_topic("SYS/SINK") + break + except RuntimeError: + await asyncio.sleep(0.05) + else: + raise AssertionError("Timed out waiting for dynamic settings metadata") + + await observer.update_settings("SYS/SINK", _SettingsMsg(gain=11)) + await asyncio.wait_for(run_task, timeout=5.0) + + settings = await observer.settings_snapshot() + assert settings["SYS/SINK"].repr_value == {"gain": 11} + + finally: + if not run_task.done(): + await asyncio.wait_for(run_task, timeout=5.0) + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_graphcontext_update_setting_field_routes_to_process(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + process = ProcessControlClient(address, process_id="proc-setting-patch") + await process.connect() + await process.register(["SYS/SINK"]) + + try: + async def handler(request) -> ProcessControlResponse: + assert request.operation == "UPDATE_SETTING_FIELD" + assert request.payload is not None + update = pickle.loads(request.payload) + assert isinstance(update, SettingsFieldUpdateRequest) + assert update.field_path == "gain" + assert update.value == 11 + return ProcessControlResponse( + request_id=request.request_id, + ok=True, + payload=pickle.dumps( + SettingsSnapshotValue(serialized=None, repr_value={"gain": 11}) + ), + ) + + process.set_request_handler(handler) + + patched = await observer.update_setting("SYS/SINK", "gain", 11) + assert patched.repr_value == {"gain": 11} + finally: + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop() + + @pytest.mark.asyncio async def test_process_reported_settings_update_visible_in_snapshot_and_events(): graph_server = GraphService().create_server() From 8a9ad347b28c39137af59d8f03220218295e6897 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 16 Mar 2026 17:52:56 -0400 Subject: [PATCH 20/52] tests: sync shutdown READY signal with task startup --- tests/shutdown_runner.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/shutdown_runner.py b/tests/shutdown_runner.py index c185a23a..39bb5124 100644 --- a/tests/shutdown_runner.py +++ b/tests/shutdown_runner.py @@ -8,11 +8,14 @@ import ezmsg.core as ez +STARTED = threading.Event() + class BlockingDiskIO(ez.Unit): @ez.task async def blocked_read(self) -> None: # Cross-platform "hung disk I/O" simulation. + STARTED.set() event = threading.Event() self._event = event await asyncio.shield(asyncio.to_thread(event.wait)) @@ -21,6 +24,7 @@ async def blocked_read(self) -> None: class BlockingSocket(ez.Unit): @ez.task async def blocked_recv(self) -> None: + STARTED.set() sock_r, sock_w = socket.socketpair() sock_r.setblocking(True) sock_w.setblocking(True) @@ -33,6 +37,7 @@ async def blocked_recv(self) -> None: class ExplodeOnCancel(ez.Unit): @ez.task async def explode(self) -> None: + STARTED.set() try: while True: await asyncio.sleep(1.0) @@ -43,6 +48,7 @@ async def explode(self) -> None: class StubbornTask(ez.Unit): @ez.task async def ignore_cancel(self) -> None: + STARTED.set() while True: try: await asyncio.sleep(1.0) @@ -84,7 +90,7 @@ def _emit_ready() -> None: def _watch_ready() -> None: while not done.is_set(): - if runner.running: + if runner.running and STARTED.is_set(): _emit_ready() return time.sleep(0.01) From 35ed2320c0537c0bd1d1fc628c13938fb87432f1 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 16 Mar 2026 18:00:22 -0400 Subject: [PATCH 21/52] added settings structure and pydantic/param compat --- src/ezmsg/core/backend.py | 22 ++- src/ezmsg/core/backendprocess.py | 86 +++++--- src/ezmsg/core/graphmeta.py | 22 +++ src/ezmsg/core/graphserver.py | 5 +- src/ezmsg/core/settingsmeta.py | 329 +++++++++++++++++++++++++++++++ tests/shutdown_runner.py | 8 +- tests/test_settings_api.py | 12 ++ 7 files changed, 450 insertions(+), 34 deletions(-) create mode 100644 src/ezmsg/core/settingsmeta.py diff --git a/src/ezmsg/core/backend.py b/src/ezmsg/core/backend.py index 4318f1e2..7a06fbd0 100644 --- a/src/ezmsg/core/backend.py +++ b/src/ezmsg/core/backend.py @@ -1,7 +1,6 @@ import asyncio from collections.abc import Callable, Mapping, Iterable from collections.abc import Collection as AbstractCollection -from dataclasses import asdict, is_dataclass import enum import inspect import logging @@ -53,6 +52,11 @@ UnitMetadata, ) from .relay import _CollectionRelayUnit, _RelaySettings +from .settingsmeta import ( + settings_repr_value, + settings_schema_from_type, + settings_schema_from_value, +) from .graphserver import GraphService from .graphcontext import GraphContext @@ -424,15 +428,7 @@ def _stream_type_name(self, stream_type: object) -> str: return repr(stream_type) def _settings_repr(self, value: object) -> dict[str, object] | str: - if is_dataclass(value): - try: - asdict_value = asdict(value) - if isinstance(asdict_value, dict): - return asdict_value - except Exception: - pass - - return repr(value) + return settings_repr_value(value) def _settings_snapshot(self, value: object) -> tuple[bytes | None, dict[str, object] | str]: try: @@ -579,6 +575,11 @@ def _component_metadata(self) -> GraphMetadata: if inspect.isclass(settings_type) else repr(settings_type) ) + settings_schema = ( + settings_schema_from_value(comp.SETTINGS) + if comp.SETTINGS is not None + else settings_schema_from_type(settings_type) + ) component_common = dict( address=comp.address, @@ -587,6 +588,7 @@ def _component_metadata(self) -> GraphMetadata: settings_type=settings_type_name, initial_settings=self._settings_snapshot(comp.SETTINGS), dynamic_settings=dynamic_settings, + settings_schema=settings_schema, ) metadata_entry: ComponentMetadataType diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 0f445a90..c2355232 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -7,11 +7,12 @@ import traceback import threading import weakref +from copy import deepcopy from abc import abstractmethod -from dataclasses import asdict, dataclass, fields as dataclass_fields, is_dataclass, replace +from dataclasses import dataclass, fields as dataclass_fields, is_dataclass, replace from collections import defaultdict -from collections.abc import Awaitable, Callable, Coroutine, Generator, Sequence +from collections.abc import Awaitable, Callable, Coroutine, Generator, Mapping, Sequence from functools import wraps, partial from concurrent.futures import ThreadPoolExecutor from concurrent.futures.thread import _worker @@ -39,6 +40,11 @@ from .pubclient import Publisher from .subclient import Subscriber from .netprotocol import AddressType +from .settingsmeta import ( + settings_repr_value, + settings_schema_from_value, + settings_structured_value, +) logger = logging.getLogger("ezmsg") @@ -236,15 +242,12 @@ def _settings_snapshot_value(self, value: object) -> SettingsSnapshotValue: except Exception: serialized = None - if is_dataclass(value): - try: - repr_value = asdict(value) - if isinstance(repr_value, dict): - return SettingsSnapshotValue(serialized=serialized, repr_value=repr_value) - except Exception: - pass - - return SettingsSnapshotValue(serialized=serialized, repr_value=repr(value)) + return SettingsSnapshotValue( + serialized=serialized, + repr_value=settings_repr_value(value), + structured_value=settings_structured_value(value), + settings_schema=settings_schema_from_value(value), + ) def _replace_settings_field( self, settings_value: object, field_path: str, value: object @@ -254,27 +257,64 @@ def _replace_settings_field( path = field_path.split(".") def apply(current: object, idx: int) -> object: - if not is_dataclass(current): - raise TypeError( - "Cannot patch non-dataclass settings value at " - f"'{'.'.join(path[:idx])}'" - ) field_name = path[idx] - valid_fields = {f.name for f in dataclass_fields(current)} - if field_name not in valid_fields: + if isinstance(current, Mapping): + if field_name not in current: + raise AttributeError( + f"Settings field '{field_name}' does not exist in mapping" + ) + if idx == len(path) - 1: + updated = dict(current) + updated[field_name] = value + return updated + patched_child = apply(current[field_name], idx + 1) + updated = dict(current) + updated[field_name] = patched_child + return updated + + if not hasattr(current, field_name): raise AttributeError( f"Settings field '{field_name}' does not exist on " f"{type(current).__name__}" ) + if idx == len(path) - 1: - return replace(current, **{field_name: value}) + return self._patch_object_field(current, field_name, value) child_value = getattr(current, field_name) patched_child = apply(child_value, idx + 1) - return replace(current, **{field_name: patched_child}) + return self._patch_object_field(current, field_name, patched_child) return apply(settings_value, 0) + def _patch_object_field( + self, obj: object, field_name: str, value: object + ) -> object: + if is_dataclass(obj): + valid_fields = {f.name for f in dataclass_fields(obj)} + if field_name not in valid_fields: + raise AttributeError( + f"Settings field '{field_name}' does not exist on " + f"{type(obj).__name__}" + ) + return replace(obj, **{field_name: value}) + + if hasattr(obj, "model_copy") and callable(getattr(obj, "model_copy")): + return obj.model_copy(update={field_name: value}) # type: ignore[attr-defined] + + if hasattr(obj, "copy") and callable(getattr(obj, "copy")): + try: + return obj.copy(update={field_name: value}) # type: ignore[attr-defined] + except Exception: + pass + + if hasattr(obj, field_name): + patched = deepcopy(obj) + setattr(patched, field_name, value) + return patched + + raise TypeError(f"Cannot patch settings object of type {type(obj).__name__}") + def process(self, loop: asyncio.AbstractEventLoop) -> None: main_func = None context = GraphContext(self.graph_address) @@ -698,8 +738,10 @@ async def wrapped_task(msg: Any = None) -> None: except Exception: logger.error(f"Exception in Task: {task_address}") logger.error(traceback.format_exc()) - if self.term_ev.is_set(): - self._shutdown_errors = True + # Any task exception should mark shutdown as unclean so + # interrupt-driven teardown can return a non-zero exit code. + # Gating this on term_ev introduces timing-dependent behavior. + self._shutdown_errors = True if strict_shutdown: raise diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index fc1bd98c..44b1a7aa 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -83,6 +83,25 @@ class TaskMetadata: publishes: list[str] = field(default_factory=list) +@dataclass +class SettingsFieldMetadata: + name: str + field_type: str + required: bool + default: Any + description: str | None + bounds: tuple[float | None, float | None] | None + choices: list[Any] | None + widget_hint: str | None + + +@dataclass +class SettingsSchemaMetadata: + provider: str + settings_type: str + fields: list[SettingsFieldMetadata] + + SettingsReprType: TypeAlias = dict[str, Any] | str SerializedSettingsType: TypeAlias = bytes | None InitialSettingsType: TypeAlias = tuple[SerializedSettingsType, SettingsReprType] @@ -96,6 +115,7 @@ class ComponentMetadata: settings_type: str initial_settings: InitialSettingsType dynamic_settings: DynamicSettingsMetadata + settings_schema: SettingsSchemaMetadata | None @dataclass @@ -144,6 +164,8 @@ class ProcessOwnershipUpdate: class SettingsSnapshotValue: serialized: bytes | None repr_value: dict[str, Any] | str + structured_value: dict[str, Any] | None = None + settings_schema: SettingsSchemaMetadata | None = None class SettingsEventType(enum.Enum): diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index a578a95d..a2eb470b 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -1419,9 +1419,12 @@ def _apply_session_metadata_settings_locked( ) -> None: session_components: set[str] = set() for component in metadata.components.values(): + initial_repr = component.initial_settings[1] value = SettingsSnapshotValue( serialized=component.initial_settings[0], - repr_value=component.initial_settings[1], + repr_value=initial_repr, + structured_value=initial_repr if isinstance(initial_repr, dict) else None, + settings_schema=component.settings_schema, ) self._settings_current[component.address] = value self._settings_source_session[component.address] = session_id diff --git a/src/ezmsg/core/settingsmeta.py b/src/ezmsg/core/settingsmeta.py new file mode 100644 index 00000000..1260e8ef --- /dev/null +++ b/src/ezmsg/core/settingsmeta.py @@ -0,0 +1,329 @@ +from __future__ import annotations + +from dataclasses import MISSING, asdict, fields as dataclass_fields, is_dataclass +import enum +from collections.abc import Mapping +from typing import Any, get_args, get_origin + +from .graphmeta import SettingsFieldMetadata, SettingsSchemaMetadata + + +def _type_name(tp: object) -> str: + if isinstance(tp, type): + return f"{tp.__module__}.{tp.__qualname__}" + return str(tp) + + +def _sanitize(value: Any) -> Any: + if value is None or isinstance(value, (bool, int, float, str)): + return value + if isinstance(value, enum.Enum): + return _sanitize(value.value) + if isinstance(value, Mapping): + return {str(key): _sanitize(val) for key, val in value.items()} + if isinstance(value, (list, tuple, set, frozenset)): + return [_sanitize(val) for val in value] + if is_dataclass(value): + try: + return _sanitize(asdict(value)) + except Exception: + return repr(value) + return repr(value) + + +def settings_structured_value(value: object) -> dict[str, Any] | None: + if value is None: + return None + + if is_dataclass(value): + try: + asdict_value = asdict(value) + if isinstance(asdict_value, dict): + return _sanitize(asdict_value) + except Exception: + pass + + if hasattr(value, "model_dump") and callable(getattr(value, "model_dump")): + try: + dumped = value.model_dump() # type: ignore[attr-defined] + if isinstance(dumped, dict): + return _sanitize(dumped) + except Exception: + pass + + if hasattr(value, "dict") and callable(getattr(value, "dict")): + try: + dumped = value.dict() # type: ignore[attr-defined] + if isinstance(dumped, dict): + return _sanitize(dumped) + except Exception: + pass + + if isinstance(value, Mapping): + return _sanitize(dict(value)) + + if hasattr(value, "param"): + param_ns = getattr(value, "param") + if hasattr(param_ns, "values") and callable(param_ns.values): + try: + values = param_ns.values() + if isinstance(values, dict): + return _sanitize(values) + except Exception: + pass + + return None + + +def settings_repr_value(value: object) -> dict[str, Any] | str: + structured = settings_structured_value(value) + if structured is not None: + return structured + return repr(value) + + +def _widget_hint( + *, + field_type: str, + choices: list[Any] | None, + bounds: tuple[float | None, float | None] | None, +) -> str | None: + field_type_lower = field_type.lower() + if choices: + return "select" + if "bool" in field_type_lower: + return "checkbox" + if bounds is not None and ("int" in field_type_lower or "float" in field_type_lower): + return "slider" + if "int" in field_type_lower: + return "int_input" + if "float" in field_type_lower: + return "float_input" + if "str" in field_type_lower: + return "text_input" + return None + + +def _choices_from_annotation(annotation: Any) -> list[Any] | None: + origin = get_origin(annotation) + if origin is None: + return None + origin_name = getattr(origin, "__name__", str(origin)) + if origin_name != "Literal": + return None + return [_sanitize(val) for val in get_args(annotation)] + + +def _extract_bounds(obj: object) -> tuple[float | None, float | None] | None: + lower = None + upper = None + for attr in ("ge", "gt", "min_length"): + if hasattr(obj, attr): + bound_val = getattr(obj, attr) + if isinstance(bound_val, (int, float)): + lower = float(bound_val) + break + for attr in ("le", "lt", "max_length"): + if hasattr(obj, attr): + bound_val = getattr(obj, attr) + if isinstance(bound_val, (int, float)): + upper = float(bound_val) + break + if lower is None and upper is None: + return None + return (lower, upper) + + +def settings_schema_from_type(settings_type: object) -> SettingsSchemaMetadata | None: + if not isinstance(settings_type, type): + return None + + if is_dataclass(settings_type): + fields: list[SettingsFieldMetadata] = [] + for f in dataclass_fields(settings_type): + required = f.default is MISSING and f.default_factory is MISSING + default_val: Any | None = None + if not required: + if f.default is not MISSING: + default_val = _sanitize(f.default) + elif f.default_factory is not MISSING: + try: + default_val = _sanitize(f.default_factory()) + except Exception: + default_val = "" + metadata = f.metadata if isinstance(f.metadata, Mapping) else {} + description = metadata.get("description") + choices = metadata.get("choices") + if isinstance(choices, (list, tuple, set)): + choices = [_sanitize(val) for val in choices] + else: + choices = _choices_from_annotation(f.type) + bounds = None + ge = metadata.get("ge", metadata.get("min")) + le = metadata.get("le", metadata.get("max")) + if isinstance(ge, (int, float)) or isinstance(le, (int, float)): + bounds = ( + float(ge) if isinstance(ge, (int, float)) else None, + float(le) if isinstance(le, (int, float)) else None, + ) + field_type = _type_name(f.type) + fields.append( + SettingsFieldMetadata( + name=f.name, + field_type=field_type, + required=required, + default=default_val, + description=description if isinstance(description, str) else None, + bounds=bounds, + choices=choices if isinstance(choices, list) else None, + widget_hint=_widget_hint( + field_type=field_type, + choices=choices if isinstance(choices, list) else None, + bounds=bounds, + ), + ) + ) + return SettingsSchemaMetadata( + provider="dataclass", + settings_type=_type_name(settings_type), + fields=fields, + ) + + if hasattr(settings_type, "model_fields"): + model_fields = getattr(settings_type, "model_fields") + if isinstance(model_fields, dict): + fields: list[SettingsFieldMetadata] = [] + for name, field_info in model_fields.items(): + annotation = getattr(field_info, "annotation", Any) + is_required_attr = getattr(field_info, "is_required", None) + required = ( + bool(is_required_attr()) + if callable(is_required_attr) + else bool(is_required_attr) + ) + default_val = None + if not required: + default = getattr(field_info, "default", None) + default_val = _sanitize(default) + description = getattr(field_info, "description", None) + choices = _choices_from_annotation(annotation) + bounds = _extract_bounds(field_info) + field_type = _type_name(annotation) + fields.append( + SettingsFieldMetadata( + name=name, + field_type=field_type, + required=required, + default=default_val, + description=description if isinstance(description, str) else None, + bounds=bounds, + choices=choices, + widget_hint=_widget_hint( + field_type=field_type, choices=choices, bounds=bounds + ), + ) + ) + return SettingsSchemaMetadata( + provider="pydantic", + settings_type=_type_name(settings_type), + fields=fields, + ) + + if hasattr(settings_type, "__fields__"): + model_fields = getattr(settings_type, "__fields__") + if isinstance(model_fields, dict): + fields: list[SettingsFieldMetadata] = [] + for name, field_info in model_fields.items(): + annotation = getattr(field_info, "outer_type_", Any) + required = bool(getattr(field_info, "required", False)) + default_val = None if required else _sanitize(getattr(field_info, "default", None)) + fi = getattr(field_info, "field_info", None) + description = getattr(fi, "description", None) if fi is not None else None + choices = _choices_from_annotation(annotation) + bounds = _extract_bounds(fi if fi is not None else field_info) + field_type = _type_name(annotation) + fields.append( + SettingsFieldMetadata( + name=name, + field_type=field_type, + required=required, + default=default_val, + description=description if isinstance(description, str) else None, + bounds=bounds, + choices=choices, + widget_hint=_widget_hint( + field_type=field_type, choices=choices, bounds=bounds + ), + ) + ) + return SettingsSchemaMetadata( + provider="pydantic", + settings_type=_type_name(settings_type), + fields=fields, + ) + + param_ns = getattr(settings_type, "param", None) + if param_ns is not None and hasattr(param_ns, "objects"): + try: + objects = param_ns.objects("existing") + except Exception: + try: + objects = param_ns.objects() + except Exception: + objects = None + if isinstance(objects, dict): + fields: list[SettingsFieldMetadata] = [] + for name, param_obj in objects.items(): + if name == "name": + continue + choices_obj = getattr(param_obj, "objects", None) + choices = None + if isinstance(choices_obj, Mapping): + choices = [_sanitize(choice) for choice in choices_obj.keys()] + elif isinstance(choices_obj, (list, tuple, set)): + choices = [_sanitize(choice) for choice in choices_obj] + bounds_obj = getattr(param_obj, "bounds", None) + bounds = None + if ( + isinstance(bounds_obj, tuple) + and len(bounds_obj) == 2 + and all( + bound is None or isinstance(bound, (int, float)) + for bound in bounds_obj + ) + ): + bounds = ( + float(bounds_obj[0]) if isinstance(bounds_obj[0], (int, float)) else None, + float(bounds_obj[1]) if isinstance(bounds_obj[1], (int, float)) else None, + ) + default_val = _sanitize(getattr(param_obj, "default", None)) + description = getattr(param_obj, "doc", None) + field_type = _type_name(type(param_obj)) + fields.append( + SettingsFieldMetadata( + name=name, + field_type=field_type, + required=False, + default=default_val, + description=description if isinstance(description, str) else None, + bounds=bounds, + choices=choices, + widget_hint=_widget_hint( + field_type=field_type, choices=choices, bounds=bounds + ), + ) + ) + return SettingsSchemaMetadata( + provider="param", + settings_type=_type_name(settings_type), + fields=fields, + ) + + return None + + +def settings_schema_from_value(value: object) -> SettingsSchemaMetadata | None: + if value is None: + return None + return settings_schema_from_type(type(value)) + diff --git a/tests/shutdown_runner.py b/tests/shutdown_runner.py index c185a23a..39bb5124 100644 --- a/tests/shutdown_runner.py +++ b/tests/shutdown_runner.py @@ -8,11 +8,14 @@ import ezmsg.core as ez +STARTED = threading.Event() + class BlockingDiskIO(ez.Unit): @ez.task async def blocked_read(self) -> None: # Cross-platform "hung disk I/O" simulation. + STARTED.set() event = threading.Event() self._event = event await asyncio.shield(asyncio.to_thread(event.wait)) @@ -21,6 +24,7 @@ async def blocked_read(self) -> None: class BlockingSocket(ez.Unit): @ez.task async def blocked_recv(self) -> None: + STARTED.set() sock_r, sock_w = socket.socketpair() sock_r.setblocking(True) sock_w.setblocking(True) @@ -33,6 +37,7 @@ async def blocked_recv(self) -> None: class ExplodeOnCancel(ez.Unit): @ez.task async def explode(self) -> None: + STARTED.set() try: while True: await asyncio.sleep(1.0) @@ -43,6 +48,7 @@ async def explode(self) -> None: class StubbornTask(ez.Unit): @ez.task async def ignore_cancel(self) -> None: + STARTED.set() while True: try: await asyncio.sleep(1.0) @@ -84,7 +90,7 @@ def _emit_ready() -> None: def _watch_ready() -> None: while not done.is_set(): - if runner.running: + if runner.running and STARTED.is_set(): _emit_ready() return time.sleep(0.01) diff --git a/tests/test_settings_api.py b/tests/test_settings_api.py index 2920374a..6972e127 100644 --- a/tests/test_settings_api.py +++ b/tests/test_settings_api.py @@ -35,6 +35,7 @@ def _metadata_with_component(component_address: str) -> GraphMetadata: input_topic=f"{component_address}/INPUT_SETTINGS", settings_type="example.Settings", ), + settings_schema=None, ) }, ) @@ -58,6 +59,8 @@ async def test_settings_snapshot_and_events_from_metadata_registration(): settings = await observer.settings_snapshot() assert component_address in settings assert settings[component_address].repr_value == {"alpha": 1} + assert settings[component_address].structured_value == {"alpha": 1} + assert settings[component_address].settings_schema is None events = await observer.settings_events(after_seq=0) matching = [ @@ -125,6 +128,15 @@ async def observe() -> None: sink_address = "SYS/SINK" assert sink_address in settings assert settings[sink_address].repr_value == {"gain": 7} + assert settings[sink_address].structured_value == {"gain": 7} + assert settings[sink_address].settings_schema is not None + schema = settings[sink_address].settings_schema + assert schema is not None + assert schema.provider == "dataclass" + assert any( + field.name == "gain" and "int" in field.field_type.lower() + for field in schema.fields + ) events = await observer.settings_events(after_seq=0) matching = [ From 038843b575f1094383f5fcc767b6da0a2cc43d27 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 16 Mar 2026 19:51:41 -0400 Subject: [PATCH 22/52] a few bugfixes and some code condensation --- src/ezmsg/core/backendprocess.py | 14 +- src/ezmsg/core/graphcontext.py | 104 ++++------ src/ezmsg/core/graphserver.py | 335 +++++++++++++------------------ src/ezmsg/core/netprotocol.py | 1 + src/ezmsg/core/processclient.py | 6 +- src/ezmsg/core/subclient.py | 10 +- tests/test_process_control.py | 78 +++++++ tests/test_profiling_api.py | 60 ++++++ tests/test_settings_api.py | 37 ++++ 9 files changed, 368 insertions(+), 277 deletions(-) diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index c2355232..3fc6a61f 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -395,11 +395,12 @@ async def process_request_handler( update_obj.field_path, update_obj.value, ) - current_settings[unit_address] = patched control_pub = control_publishers.get(input_topic) if control_pub is None: control_pub = await context.publisher(input_topic) control_publishers[input_topic] = control_pub + await control_pub.broadcast(patched) + current_settings[unit_address] = patched except Exception as exc: return ProcessControlResponse( request_id=request.request_id, @@ -409,17 +410,6 @@ async def process_request_handler( process_id=process_client.process_id, ) - async def publish_patched_settings() -> None: - try: - await control_pub.broadcast(patched) - except Exception as exc: - logger.warning( - "Failed to publish patched settings for %s: %s", - unit_address, - exc, - ) - - asyncio.create_task(publish_patched_settings()) result_value = self._settings_snapshot_value(patched) return ProcessControlResponse( request_id=request.request_id, diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index 1cdb494c..cd1e9d48 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -483,62 +483,28 @@ async def subscribe_settings_events( *, after_seq: int = 0, ) -> typing.AsyncIterator[SettingsChangedEvent]: - reader, writer = await GraphService(self.graph_address).open_connection() - writer.write(Command.SESSION_SETTINGS_SUBSCRIBE.value) - writer.write(encode_str(str(after_seq))) - await writer.drain() - - _subscriber_id = UUID(await read_str(reader)) - response = await reader.read(1) - if response != Command.COMPLETE.value: - await close_stream_writer(writer) - raise RuntimeError("Failed to subscribe to settings events") - - try: - while True: - payload_size = await read_int(reader) - payload = await reader.readexactly(payload_size) - event = pickle.loads(payload) - if not isinstance(event, SettingsChangedEvent): - raise RuntimeError( - "Settings subscription received invalid event payload" - ) - yield event - except asyncio.IncompleteReadError: - return - finally: - await close_stream_writer(writer) + async for event in self._subscribe_pickled_stream( + command=Command.SESSION_SETTINGS_SUBSCRIBE, + setup_payload=encode_str(str(after_seq)), + expected_type=SettingsChangedEvent, + subscribe_error="Failed to subscribe to settings events", + payload_error="Settings subscription received invalid event payload", + ): + yield typing.cast(SettingsChangedEvent, event) async def subscribe_topology_events( self, *, after_seq: int = 0, ) -> typing.AsyncIterator[TopologyChangedEvent]: - reader, writer = await GraphService(self.graph_address).open_connection() - writer.write(Command.SESSION_TOPOLOGY_SUBSCRIBE.value) - writer.write(encode_str(str(after_seq))) - await writer.drain() - - _subscriber_id = UUID(await read_str(reader)) - response = await reader.read(1) - if response != Command.COMPLETE.value: - await close_stream_writer(writer) - raise RuntimeError("Failed to subscribe to topology events") - - try: - while True: - payload_size = await read_int(reader) - payload = await reader.readexactly(payload_size) - event = pickle.loads(payload) - if not isinstance(event, TopologyChangedEvent): - raise RuntimeError( - "Topology subscription received invalid event payload" - ) - yield event - except asyncio.IncompleteReadError: - return - finally: - await close_stream_writer(writer) + async for event in self._subscribe_pickled_stream( + command=Command.SESSION_TOPOLOGY_SUBSCRIBE, + setup_payload=encode_str(str(after_seq)), + expected_type=TopologyChangedEvent, + subscribe_error="Failed to subscribe to topology events", + payload_error="Topology subscription received invalid event payload", + ): + yield typing.cast(TopologyChangedEvent, event) async def subscribe_profiling_trace( self, @@ -547,29 +513,45 @@ async def subscribe_profiling_trace( """ Subscribe to streamed profiling trace batches from GraphServer. """ - reader, writer = await GraphService(self.graph_address).open_connection() payload = pickle.dumps(control) - writer.write(Command.SESSION_PROFILING_SUBSCRIBE.value) - writer.write(uint64_to_bytes(len(payload))) - writer.write(payload) + setup_payload = uint64_to_bytes(len(payload)) + payload + async for batch in self._subscribe_pickled_stream( + command=Command.SESSION_PROFILING_SUBSCRIBE, + setup_payload=setup_payload, + expected_type=ProfilingTraceStreamBatch, + subscribe_error="Failed to subscribe to profiling trace stream", + payload_error="Profiling subscription received invalid batch payload", + ): + yield typing.cast(ProfilingTraceStreamBatch, batch) + + async def _subscribe_pickled_stream( + self, + *, + command: Command, + setup_payload: bytes, + expected_type: type[object], + subscribe_error: str, + payload_error: str, + ) -> typing.AsyncIterator[object]: + reader, writer = await GraphService(self.graph_address).open_connection() + writer.write(command.value) + writer.write(setup_payload) await writer.drain() _subscriber_id = UUID(await read_str(reader)) response = await reader.read(1) if response != Command.COMPLETE.value: await close_stream_writer(writer) - raise RuntimeError("Failed to subscribe to profiling trace stream") + raise RuntimeError(subscribe_error) try: while True: payload_size = await read_int(reader) payload = await reader.readexactly(payload_size) - batch = pickle.loads(payload) - if not isinstance(batch, ProfilingTraceStreamBatch): - raise RuntimeError( - "Profiling subscription received invalid batch payload" - ) - yield batch + value = pickle.loads(payload) + if not isinstance(value, expected_type): + raise RuntimeError(payload_error) + yield value except asyncio.IncompleteReadError: return finally: diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index a2eb470b..37adebc3 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -6,6 +6,7 @@ import threading import time from collections import deque +from collections.abc import Sequence from contextlib import suppress from uuid import UUID, uuid1 @@ -98,15 +99,16 @@ class GraphServer(threading.Thread): _settings_events: list[SettingsChangedEvent] _settings_event_seq: int _settings_owned_by_session: dict[UUID, set[str]] - _settings_subscribers: dict[UUID, asyncio.Queue[SettingsChangedEvent]] + _settings_subscribers: dict[UUID, asyncio.Queue[object]] _topology_events: list[TopologyChangedEvent] _topology_event_seq: int - _topology_subscribers: dict[UUID, asyncio.Queue[TopologyChangedEvent]] + _topology_subscribers: dict[UUID, asyncio.Queue[object]] _pending_process_requests: dict[ str, tuple[UUID, "asyncio.Future[ProcessControlResponse]"] ] - _profiling_trace_buffers: dict[str, deque[ProfilingTraceSample]] + _profiling_trace_buffers: dict[str, deque[tuple[int, ProfilingTraceSample]]] _profiling_trace_process_meta: dict[str, tuple[int, str]] + _profiling_trace_seq: dict[str, int] def __init__(self, **kwargs) -> None: super().__init__( @@ -134,6 +136,7 @@ def __init__(self, **kwargs) -> None: self._pending_process_requests = {} self._profiling_trace_buffers = {} self._profiling_trace_process_meta = {} + self._profiling_trace_seq = {} @property def address(self) -> Address: @@ -711,6 +714,7 @@ async def _handle_process( ) self._profiling_trace_buffers.pop(source_process_id, None) self._profiling_trace_process_meta.pop(source_process_id, None) + self._profiling_trace_seq.pop(source_process_id, None) self._append_topology_event_locked( event_type=TopologyEventType.PROCESS_CHANGED, changed_topics=[], @@ -738,20 +742,33 @@ async def _write_process_response( writer.write(response) await writer.drain() - def _queue_settings_event( - self, queue: asyncio.Queue[SettingsChangedEvent], event: SettingsChangedEvent - ) -> None: + async def _read_pickled_payload(self, reader: asyncio.StreamReader) -> object: + payload_size = await read_int(reader) + payload = await reader.readexactly(payload_size) + return pickle.loads(payload) + + async def _read_typed_payload( + self, + reader: asyncio.StreamReader, + expected_type: type[object], + *, + log_prefix: str, + ) -> object | None: try: - queue.put_nowait(event) - except asyncio.QueueFull: - # Keep most recent samples under backpressure. - with suppress(asyncio.QueueEmpty): - queue.get_nowait() - with suppress(asyncio.QueueFull): - queue.put_nowait(event) + payload_obj = await self._read_pickled_payload(reader) + if not isinstance(payload_obj, expected_type): + raise RuntimeError( + f"payload was not {expected_type.__name__}: {type(payload_obj).__name__}" + ) + return payload_obj + except Exception as exc: + logger.warning("%s parse failed; ignoring payload: %s", log_prefix, exc) + return None - def _queue_topology_event( - self, queue: asyncio.Queue[TopologyChangedEvent], event: TopologyChangedEvent + def _queue_stream_event( + self, + queue: asyncio.Queue[object], + event: object, ) -> None: try: queue.put_nowait(event) @@ -762,11 +779,12 @@ def _queue_topology_event( with suppress(asyncio.QueueFull): queue.put_nowait(event) - async def _settings_sender( + async def _stream_sender( self, subscriber_id: UUID, - queue: asyncio.Queue[SettingsChangedEvent], + queue: asyncio.Queue[object], writer: asyncio.StreamWriter, + label: str, ) -> None: try: while True: @@ -776,28 +794,31 @@ async def _settings_sender( writer.write(payload) await writer.drain() except (ConnectionResetError, BrokenPipeError): - logger.debug(f"Settings subscriber {subscriber_id} disconnected on send") + logger.debug(f"{label} subscriber {subscriber_id} disconnected on send") except asyncio.CancelledError: raise - async def _handle_settings_subscriber( + async def _handle_event_subscriber( self, + *, subscriber_id: UUID, after_seq: int, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, + queue: asyncio.Queue[object], + subscribers: dict[UUID, asyncio.Queue[object]], + events: Sequence[object], + label: str, ) -> None: - queue: asyncio.Queue[SettingsChangedEvent] = asyncio.Queue(maxsize=1024) - async with self._command_lock: - self._settings_subscribers[subscriber_id] = queue - for event in self._settings_events: - if event.seq > after_seq: - self._queue_settings_event(queue, event) + subscribers[subscriber_id] = queue + for event in events: + if getattr(event, "seq", 0) > after_seq: + self._queue_stream_event(queue, event) sender_task = asyncio.create_task( - self._settings_sender(subscriber_id, queue, writer), - name=f"settings-sender-{subscriber_id}", + self._stream_sender(subscriber_id, queue, writer, label), + name=f"{label}-sender-{subscriber_id}", ) try: @@ -806,33 +827,34 @@ async def _handle_settings_subscriber( if not req: break except (ConnectionResetError, BrokenPipeError) as e: - logger.debug(f"Settings subscriber {subscriber_id} disconnected: {e}") + logger.debug(f"{label} subscriber {subscriber_id} disconnected: {e}") finally: async with self._command_lock: - self._settings_subscribers.pop(subscriber_id, None) + subscribers.pop(subscriber_id, None) self._client_tasks.pop(subscriber_id, None) sender_task.cancel() with suppress(asyncio.CancelledError): await sender_task await close_stream_writer(writer) - async def _topology_sender( + async def _handle_settings_subscriber( self, subscriber_id: UUID, - queue: asyncio.Queue[TopologyChangedEvent], + after_seq: int, + reader: asyncio.StreamReader, writer: asyncio.StreamWriter, ) -> None: - try: - while True: - event = await queue.get() - payload = pickle.dumps(event) - writer.write(uint64_to_bytes(len(payload))) - writer.write(payload) - await writer.drain() - except (ConnectionResetError, BrokenPipeError): - logger.debug(f"Topology subscriber {subscriber_id} disconnected on send") - except asyncio.CancelledError: - raise + queue: asyncio.Queue[object] = asyncio.Queue(maxsize=1024) + await self._handle_event_subscriber( + subscriber_id=subscriber_id, + after_seq=after_seq, + reader=reader, + writer=writer, + queue=queue, + subscribers=self._settings_subscribers, + events=self._settings_events, + label="settings", + ) async def _handle_topology_subscriber( self, @@ -841,48 +863,22 @@ async def _handle_topology_subscriber( reader: asyncio.StreamReader, writer: asyncio.StreamWriter, ) -> None: - queue: asyncio.Queue[TopologyChangedEvent] = asyncio.Queue(maxsize=1024) - - async with self._command_lock: - self._topology_subscribers[subscriber_id] = queue - for event in self._topology_events: - if event.seq > after_seq: - self._queue_topology_event(queue, event) - - sender_task = asyncio.create_task( - self._topology_sender(subscriber_id, queue, writer), - name=f"topology-sender-{subscriber_id}", + queue: asyncio.Queue[object] = asyncio.Queue(maxsize=1024) + await self._handle_event_subscriber( + subscriber_id=subscriber_id, + after_seq=after_seq, + reader=reader, + writer=writer, + queue=queue, + subscribers=self._topology_subscribers, + events=self._topology_events, + label="topology", ) - try: - while True: - req = await reader.read(1) - if not req: - break - except (ConnectionResetError, BrokenPipeError) as e: - logger.debug(f"Topology subscriber {subscriber_id} disconnected: {e}") - finally: - async with self._command_lock: - self._topology_subscribers.pop(subscriber_id, None) - self._client_tasks.pop(subscriber_id, None) - sender_task.cancel() - with suppress(asyncio.CancelledError): - await sender_task - await close_stream_writer(writer) - async def _read_profiling_stream_control( self, reader: asyncio.StreamReader ) -> ProfilingStreamControl: - payload_size = await read_int(reader) - payload = await reader.readexactly(payload_size) - - try: - payload_obj = pickle.loads(payload) - except Exception as exc: - raise RuntimeError( - f"Invalid profiling stream control payload: {exc}" - ) from exc - + payload_obj = await self._read_pickled_payload(reader) if not isinstance(payload_obj, ProfilingStreamControl): raise RuntimeError( "Invalid profiling stream control payload type: " @@ -894,6 +890,7 @@ async def _collect_profiling_trace_stream_batch( self, *, stream_control: ProfilingStreamControl, + last_seq_by_process: dict[str, int], ) -> ProfilingTraceStreamBatch: process_ids_filter = ( set(stream_control.process_ids) @@ -925,8 +922,19 @@ async def _collect_profiling_trace_stream_batch( for process_id in process_ids: sample_buffer = self._profiling_trace_buffers.get(process_id) samples: list[ProfilingTraceSample] = [] - while sample_buffer and len(samples) < max_samples: - samples.append(sample_buffer.popleft()) + if sample_buffer: + last_seq = last_seq_by_process.get(process_id, 0) + oldest_seq = sample_buffer[0][0] + if last_seq < oldest_seq - 1: + last_seq = oldest_seq - 1 + for seq, sample in sample_buffer: + if seq <= last_seq: + continue + samples.append(sample) + last_seq = seq + if len(samples) >= max_samples: + break + last_seq_by_process[process_id] = last_seq if len(samples) == 0 and not stream_control.include_empty_batches: continue @@ -953,6 +961,7 @@ async def _handle_profiling_subscriber( writer: asyncio.StreamWriter, ) -> None: interval = max(0.01, float(stream_control.interval)) + last_seq_by_process: dict[str, int] = {} try: while True: try: @@ -966,6 +975,7 @@ async def _handle_profiling_subscriber( batch = await self._collect_profiling_trace_stream_batch( stream_control=stream_control, + last_seq_by_process=last_seq_by_process, ) if len(batch.batches) == 0: continue @@ -991,24 +1001,11 @@ async def _handle_profiling_subscriber( async def _handle_process_profiling_trace_update_request( self, process_client_id: UUID, reader: asyncio.StreamReader ) -> None: - payload_size = await read_int(reader) - payload = await reader.readexactly(payload_size) - batch: ProcessProfilingTraceBatch | None = None - try: - payload_obj = pickle.loads(payload) - if isinstance(payload_obj, ProcessProfilingTraceBatch): - batch = payload_obj - else: - raise RuntimeError( - "process profiling trace payload was not ProcessProfilingTraceBatch" - ) - except Exception as exc: - logger.warning( - "Process control %s trace update parse failed; ignoring payload: %s", - process_client_id, - exc, - ) - + batch = await self._read_typed_payload( + reader, + ProcessProfilingTraceBatch, + log_prefix=f"Process control {process_client_id} trace update", + ) if batch is None: return @@ -1026,37 +1023,28 @@ async def _handle_process_profiling_trace_update_request( trace_buffer = self._profiling_trace_buffers.setdefault( process_id, deque(maxlen=200_000) ) - trace_buffer.extend(batch.samples) + next_seq = self._profiling_trace_seq.get(process_id, 0) + for sample in batch.samples: + next_seq += 1 + trace_buffer.append((next_seq, sample)) + self._profiling_trace_seq[process_id] = next_seq self._profiling_trace_process_meta[process_id] = (batch.pid, batch.host) async def _handle_process_register_request( self, process_client_id: UUID, reader: asyncio.StreamReader ) -> bytes: - num_bytes = await read_int(reader) - payload = await reader.readexactly(num_bytes) - registration: ProcessRegistration | None = None - try: - payload_obj = pickle.loads(payload) - if isinstance(payload_obj, ProcessRegistration): - registration = payload_obj - else: - raise RuntimeError( - "process registration payload was not ProcessRegistration" - ) - except Exception as exc: - logger.warning( - "Process control %s registration parse failed; ignoring payload: %s", - process_client_id, - exc, - ) - + registration = await self._read_typed_payload( + reader, + ProcessRegistration, + log_prefix=f"Process control {process_client_id} registration", + ) if registration is None: - return Command.COMPLETE.value + return Command.ERROR.value async with self._command_lock: process_info = self._process_info(process_client_id) if process_info is None: - return Command.COMPLETE.value + return Command.ERROR.value prev_units = set(process_info.units) prev_process_id = process_info.process_id @@ -1067,6 +1055,7 @@ async def _handle_process_register_request( if prev_process_id is not None and prev_process_id != registration.process_id: self._profiling_trace_buffers.pop(prev_process_id, None) self._profiling_trace_process_meta.pop(prev_process_id, None) + self._profiling_trace_seq.pop(prev_process_id, None) if prev_units != process_info.units: self._append_topology_event_locked( event_type=TopologyEventType.PROCESS_CHANGED, @@ -1080,31 +1069,18 @@ async def _handle_process_register_request( async def _handle_process_update_ownership_request( self, process_client_id: UUID, reader: asyncio.StreamReader ) -> bytes: - num_bytes = await read_int(reader) - payload = await reader.readexactly(num_bytes) - update: ProcessOwnershipUpdate | None = None - try: - update_obj = pickle.loads(payload) - if isinstance(update_obj, ProcessOwnershipUpdate): - update = update_obj - else: - raise RuntimeError( - "process ownership payload was not ProcessOwnershipUpdate" - ) - except Exception as exc: - logger.warning( - "Process control %s ownership update parse failed; ignoring payload: %s", - process_client_id, - exc, - ) - + update = await self._read_typed_payload( + reader, + ProcessOwnershipUpdate, + log_prefix=f"Process control {process_client_id} ownership update", + ) if update is None: - return Command.COMPLETE.value + return Command.ERROR.value async with self._command_lock: process_info = self._process_info(process_client_id) if process_info is None: - return Command.COMPLETE.value + return Command.ERROR.value if ( process_info.process_id is not None @@ -1135,35 +1111,27 @@ async def _handle_process_update_ownership_request( async def _handle_process_settings_update_request( self, process_client_id: UUID, reader: asyncio.StreamReader ) -> bytes: - num_bytes = await read_int(reader) - payload = await reader.readexactly(num_bytes) - update: ProcessSettingsUpdate | None = None - try: - update_obj = pickle.loads(payload) - if isinstance(update_obj, ProcessSettingsUpdate): - update = update_obj - else: - raise RuntimeError( - "process settings payload was not ProcessSettingsUpdate" - ) - except Exception as exc: - logger.warning( - "Process control %s settings update parse failed; ignoring payload: %s", - process_client_id, - exc, - ) - + update = await self._read_typed_payload( + reader, + ProcessSettingsUpdate, + log_prefix=f"Process control {process_client_id} settings update", + ) if update is None: - return Command.COMPLETE.value + return Command.ERROR.value async with self._command_lock: process_info = self._process_info(process_client_id) if process_info is None: - return Command.COMPLETE.value + return Command.ERROR.value if process_info.process_id is None: process_info.process_id = update.process_id + source_process_id = ( + process_info.process_id + if process_info.process_id is not None + else update.process_id + ) self._settings_current[update.component_address] = update.value self._settings_source_session[update.component_address] = None self._append_settings_event_locked( @@ -1171,7 +1139,7 @@ async def _handle_process_settings_update_request( component_address=update.component_address, value=update.value, source_session_id=None, - source_process_id=update.process_id, + source_process_id=source_process_id, timestamp=update.timestamp, ) @@ -1180,24 +1148,11 @@ async def _handle_process_settings_update_request( async def _handle_process_route_response_request( self, process_client_id: UUID, reader: asyncio.StreamReader ) -> None: - num_bytes = await read_int(reader) - payload = await reader.readexactly(num_bytes) - response: ProcessControlResponse | None = None - try: - response_obj = pickle.loads(payload) - if isinstance(response_obj, ProcessControlResponse): - response = response_obj - else: - raise RuntimeError( - "process route response payload was not ProcessControlResponse" - ) - except Exception as exc: - logger.warning( - "Process control %s route response parse failed; ignoring payload: %s", - process_client_id, - exc, - ) - + response = await self._read_typed_payload( + reader, + ProcessControlResponse, + log_prefix=f"Process control {process_client_id} route response", + ) if response is None: return @@ -1374,7 +1329,7 @@ def _append_settings_event_locked( self._settings_events.append(event) for queue in self._settings_subscribers.values(): - self._queue_settings_event(queue, event) + self._queue_stream_event(queue, event) # Bound memory growth for long-lived servers. max_events = 10_000 @@ -1401,7 +1356,7 @@ def _append_topology_event_locked( self._topology_events.append(event) for queue in self._topology_subscribers.values(): - self._queue_topology_event(queue, event) + self._queue_stream_event(queue, event) max_events = 10_000 if len(self._topology_events) > max_events: @@ -1482,19 +1437,11 @@ async def _handle_session_clear_request(self, session_id: UUID) -> bytes: async def _handle_session_register_request( self, session_id: UUID, reader: asyncio.StreamReader ) -> bytes: - num_bytes = await read_int(reader) - payload = await reader.readexactly(num_bytes) - metadata: GraphMetadata | None = None - try: - metadata_obj = pickle.loads(payload) - if isinstance(metadata_obj, GraphMetadata): - metadata = metadata_obj - else: - raise RuntimeError("metadata payload was not GraphMetadata") - except Exception as exc: - logger.warning( - f"Session {session_id} metadata parse failed; ignoring payload: {exc}" - ) + metadata = await self._read_typed_payload( + reader, + GraphMetadata, + log_prefix=f"Session {session_id} metadata", + ) async with self._command_lock: session = self._session_info(session_id) diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index 13c582b3..67d9c1da 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -346,6 +346,7 @@ def _generate_next_value_(name, start, count, last_values) -> bytes: PROCESS_PROFILING_TRACE_UPDATE = enum.auto() PROCESS_ROUTE_REQUEST = enum.auto() PROCESS_ROUTE_RESPONSE = enum.auto() + ERROR = enum.auto() def create_socket( diff --git a/src/ezmsg/core/processclient.py b/src/ezmsg/core/processclient.py index d5279395..e3caeacc 100644 --- a/src/ezmsg/core/processclient.py +++ b/src/ezmsg/core/processclient.py @@ -218,6 +218,10 @@ async def _write_payload( ) from exc if response != Command.COMPLETE.value: + if response == Command.ERROR.value: + raise RuntimeError( + f"Process control command failed: {command.name}" + ) raise RuntimeError( f"Unexpected response to process control command: {command.name}" ) @@ -234,7 +238,7 @@ async def _io_loop(self) -> None: if not req: break - if req == Command.COMPLETE.value: + if req in (Command.COMPLETE.value, Command.ERROR.value): self._ack_queue.put_nowait(req) continue diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index 50a7c3c1..0f650421 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -3,7 +3,7 @@ import typing from uuid import UUID -from contextlib import asynccontextmanager, contextmanager, suppress +from contextlib import asynccontextmanager, suppress from copy import deepcopy from .graphserver import GraphService @@ -315,11 +315,3 @@ def begin_profile(self) -> int: def end_profile(self, start_ns: int, label: str | None = None) -> None: end_ns = PROFILE_TIME() PROFILES.subscriber_user_span(self.id, end_ns, end_ns - start_ns, label) - - @contextmanager - def profile_span(self, label: str | None = None) -> typing.Generator[None, None, None]: - start_ns = self.begin_profile() - try: - yield - finally: - self.end_profile(start_ns, label=label) diff --git a/tests/test_process_control.py b/tests/test_process_control.py index dbf70a2d..ad5db53b 100644 --- a/tests/test_process_control.py +++ b/tests/test_process_control.py @@ -1,10 +1,13 @@ import asyncio +import pickle import pytest from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.graphmeta import ProcessRegistration from ezmsg.core.processclient import ProcessControlClient from ezmsg.core.graphserver import GraphService +from ezmsg.core.netprotocol import Command, close_stream_writer, read_str, uint64_to_bytes @pytest.mark.asyncio @@ -64,3 +67,78 @@ async def test_process_snapshot_entry_drops_on_disconnect(): await process.close() await observer.__aexit__(None, None, None) graph_server.stop() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "command", + [ + Command.PROCESS_REGISTER, + Command.PROCESS_UPDATE_OWNERSHIP, + Command.PROCESS_SETTINGS_UPDATE, + ], +) +async def test_process_payload_parse_failures_return_error_ack(command: Command): + graph_server = GraphService().create_server() + address = graph_server.address + + reader, writer = await GraphService(address).open_connection() + try: + writer.write(Command.PROCESS.value) + await writer.drain() + _client_id = await read_str(reader) + response = await reader.read(1) + assert response == Command.COMPLETE.value + + # Non-pickled bytes intentionally trigger parse failure in process handlers. + bad_payload = b"not-a-pickle-payload" + writer.write(command.value) + writer.write(uint64_to_bytes(len(bad_payload))) + writer.write(bad_payload) + await writer.drain() + + response = await reader.read(1) + assert response == Command.ERROR.value + finally: + await close_stream_writer(writer) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_register_succeeds_after_error_ack(): + graph_server = GraphService().create_server() + address = graph_server.address + + reader, writer = await GraphService(address).open_connection() + try: + writer.write(Command.PROCESS.value) + await writer.drain() + _client_id = await read_str(reader) + response = await reader.read(1) + assert response == Command.COMPLETE.value + + bad_payload = b"not-a-pickle-payload" + writer.write(Command.PROCESS_REGISTER.value) + writer.write(uint64_to_bytes(len(bad_payload))) + writer.write(bad_payload) + await writer.drain() + response = await reader.read(1) + assert response == Command.ERROR.value + + good_payload = pickle.dumps( + ProcessRegistration( + process_id="proc-after-error", + pid=123, + host="test-host", + units=["SYS/U1"], + ) + ) + writer.write(Command.PROCESS_REGISTER.value) + writer.write(uint64_to_bytes(len(good_payload))) + writer.write(good_payload) + await writer.drain() + response = await reader.read(1) + assert response == Command.COMPLETE.value + finally: + await close_stream_writer(writer) + graph_server.stop() diff --git a/tests/test_profiling_api.py b/tests/test_profiling_api.py index 0449520b..67fcf9ed 100644 --- a/tests/test_profiling_api.py +++ b/tests/test_profiling_api.py @@ -315,3 +315,63 @@ async def test_process_profiling_trace_subscription_stream_control(): await process_a.close() await ctx.__aexit__(None, None, None) graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_profiling_trace_subscription_does_not_starve_peer_subscribers(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address, process_id="proc-stream-multi") + await process.connect() + await process.register(["SYS/U7"]) + + pub = await ctx.publisher("TOPIC_STREAM_MULTI") + sub = await ctx.subscriber("TOPIC_STREAM_MULTI") + stream_a = None + stream_b = None + + try: + response = await ctx.process_set_profiling_trace( + "SYS/U7", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_topics=["TOPIC_STREAM_MULTI"], + subscriber_topics=["TOPIC_STREAM_MULTI"], + ), + timeout=1.0, + ) + assert response.ok + + control = ProfilingStreamControl(interval=0.02, max_samples=256) + stream_a = ctx.subscribe_profiling_trace(control) + stream_b = ctx.subscribe_profiling_trace(control) + + for idx in range(12): + await pub.broadcast(idx) + async with sub.recv_zero_copy() as _msg: + span_start_ns = sub.begin_profile() + try: + await asyncio.sleep(0) + finally: + sub.end_profile(span_start_ns, "taskA") + + batch_a = await asyncio.wait_for(anext(stream_a), timeout=1.0) + batch_b = await asyncio.wait_for(anext(stream_b), timeout=1.0) + + assert "proc-stream-multi" in batch_a.batches + assert "proc-stream-multi" in batch_b.batches + assert len(batch_a.batches["proc-stream-multi"].samples) > 0 + assert len(batch_b.batches["proc-stream-multi"].samples) > 0 + finally: + if stream_a is not None: + await stream_a.aclose() + if stream_b is not None: + await stream_b.aclose() + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() diff --git a/tests/test_settings_api.py b/tests/test_settings_api.py index 6972e127..e987a98a 100644 --- a/tests/test_settings_api.py +++ b/tests/test_settings_api.py @@ -1,5 +1,6 @@ import asyncio import pickle +import time from dataclasses import dataclass import pytest @@ -10,6 +11,7 @@ ComponentMetadata, DynamicSettingsMetadata, GraphMetadata, + ProcessControlErrorCode, ProcessControlResponse, SettingsFieldUpdateRequest, SettingsEventType, @@ -231,6 +233,41 @@ async def handler(request) -> ProcessControlResponse: graph_server.stop() +@pytest.mark.asyncio +async def test_graphcontext_update_setting_waits_and_propagates_process_failure(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + process = ProcessControlClient(address, process_id="proc-setting-fail") + await process.connect() + await process.register(["SYS/SINK"]) + + try: + async def handler(request) -> ProcessControlResponse: + assert request.operation == "UPDATE_SETTING_FIELD" + await asyncio.sleep(0.05) + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error="Simulated publish failure", + error_code=ProcessControlErrorCode.HANDLER_ERROR, + ) + + process.set_request_handler(handler) + + start = time.perf_counter() + with pytest.raises(RuntimeError, match="Simulated publish failure"): + await observer.update_setting("SYS/SINK", "gain", 99, timeout=1.0) + elapsed = time.perf_counter() - start + assert elapsed >= 0.04 + finally: + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop() + + @pytest.mark.asyncio async def test_process_reported_settings_update_visible_in_snapshot_and_events(): graph_server = GraphService().create_server() From 01780c23e02b56f7105f63e65a5e8d2b16ff7c88 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 16 Mar 2026 20:34:38 -0400 Subject: [PATCH 23/52] using UUIDs and errors on unit collision --- examples/profiling_tui.py | 22 +++-- examples/topology_tui.py | 3 +- src/ezmsg/core/backendprocess.py | 1 - src/ezmsg/core/graphcontext.py | 4 +- src/ezmsg/core/graphmeta.py | 26 +++--- src/ezmsg/core/graphserver.py | 154 ++++++++++++++++++------------- src/ezmsg/core/netprotocol.py | 1 - src/ezmsg/core/processclient.py | 54 ++++++----- src/ezmsg/core/profiling.py | 4 +- tests/test_process_control.py | 50 +++++++++- tests/test_process_routing.py | 20 ++-- tests/test_profiling_api.py | 52 +++++++---- tests/test_settings_api.py | 38 +++++++- tests/test_topology_api.py | 8 +- 14 files changed, 274 insertions(+), 163 deletions(-) diff --git a/examples/profiling_tui.py b/examples/profiling_tui.py index 99059630..17582968 100644 --- a/examples/profiling_tui.py +++ b/examples/profiling_tui.py @@ -18,6 +18,7 @@ import contextlib import time from dataclasses import dataclass +from uuid import UUID from ezmsg.core.graphcontext import GraphContext from ezmsg.core.graphmeta import ( @@ -28,7 +29,8 @@ from ezmsg.core.netprotocol import DEFAULT_HOST, GRAPHSERVER_PORT_DEFAULT -def _truncate(text: str, width: int) -> str: +def _truncate(text: object, width: int) -> str: + text = str(text) if width <= 3: return text[:width] if len(text) <= width: @@ -42,7 +44,7 @@ def _fmt_float(value: float, digits: int = 2) -> str: @dataclass class PublisherView: - process_id: str + process_id: UUID topic: str endpoint_id: str published_total: int @@ -58,7 +60,7 @@ class PublisherView: @dataclass class SubscriberView: - process_id: str + process_id: UUID topic: str endpoint_id: str channel_kind: str @@ -92,10 +94,10 @@ def __init__( self.trace_sample_mod = max(1, trace_sample_mod) self.max_rows = max(5, max_rows) - self.snapshots: dict[str, ProcessProfilingSnapshot] = {} - self.route_units: dict[str, str] = {} - self.trace_enabled_processes: set[str] = set() - self.trace_errors: dict[str, str] = {} + self.snapshots: dict[UUID, ProcessProfilingSnapshot] = {} + self.route_units: dict[UUID, str] = {} + self.trace_enabled_processes: set[UUID] = set() + self.trace_errors: dict[UUID, str] = {} self.trace_samples_seen_by_endpoint: dict[str, int] = {} self.trace_last_timestamp_by_endpoint: dict[str, float] = {} self.last_snapshot_time: float | None = None @@ -135,7 +137,7 @@ async def _snapshot_loop(self) -> None: async def _refresh_snapshot(self) -> None: graph_snapshot = await self.ctx.snapshot() - route_units: dict[str, str] = {} + route_units: dict[UUID, str] = {} for process in graph_snapshot.processes.values(): if process.units: route_units[process.process_id] = process.units[0] @@ -349,8 +351,8 @@ def render(self) -> None: if self.trace_errors: print("\ntrace errors:") - for process_id, err in sorted(self.trace_errors.items()): - print(f" {_truncate(process_id, 30)}: {_truncate(err, 120)}") + for process_id, err in sorted(self.trace_errors.items(), key=lambda item: str(item[0])): + print(f" {_truncate(str(process_id), 30)}: {_truncate(err, 120)}") def _parse_address(host: str, port: int) -> tuple[str, int]: diff --git a/examples/topology_tui.py b/examples/topology_tui.py index 07bc82f4..56876a34 100644 --- a/examples/topology_tui.py +++ b/examples/topology_tui.py @@ -26,7 +26,8 @@ from ezmsg.core.netprotocol import DEFAULT_HOST, GRAPHSERVER_PORT_DEFAULT -def _truncate(text: str, width: int) -> str: +def _truncate(text: object, width: int) -> str: + text = str(text) if width <= 3: return text[:width] if len(text) <= width: diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 3fc6a61f..54214252 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -319,7 +319,6 @@ def process(self, loop: asyncio.AbstractEventLoop) -> None: main_func = None context = GraphContext(self.graph_address) process_client = ProcessControlClient(self.graph_address) - PROFILES.set_process_id(process_client.process_id) process_register_future: concurrent.futures.Future[None] | None = None coro_callables: dict[str, Callable[[], Coroutine[Any, Any, None]]] = dict() settings_input_topics: dict[str, str] = {} diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index cd1e9d48..0c07438b 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -667,9 +667,9 @@ async def profiling_snapshot_all( self, *, timeout_per_process: float = 0.5, - ) -> dict[str, ProcessProfilingSnapshot]: + ) -> dict[UUID, ProcessProfilingSnapshot]: graph_snapshot = await self.snapshot() - out: dict[str, ProcessProfilingSnapshot] = {} + out: dict[UUID, ProcessProfilingSnapshot] = {} for process in graph_snapshot.processes.values(): if len(process.units) == 0: continue diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index 44b1a7aa..0eafc11a 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, field from typing import Any, TypeAlias, NamedTuple +from uuid import UUID @dataclass @@ -147,7 +148,6 @@ class GraphMetadata: @dataclass class ProcessRegistration: - process_id: str pid: int host: str units: list[str] @@ -155,7 +155,6 @@ class ProcessRegistration: @dataclass class ProcessOwnershipUpdate: - process_id: str added_units: list[str] = field(default_factory=list) removed_units: list[str] = field(default_factory=list) @@ -180,7 +179,7 @@ class SettingsChangedEvent: component_address: str timestamp: float source_session_id: str | None - source_process_id: str | None + source_process_id: UUID | None value: SettingsSnapshotValue @@ -196,12 +195,11 @@ class TopologyChangedEvent: timestamp: float changed_topics: list[str] source_session_id: str | None - source_process_id: str | None + source_process_id: UUID | None @dataclass class ProcessSettingsUpdate: - process_id: str component_address: str value: SettingsSnapshotValue timestamp: float @@ -242,7 +240,7 @@ class ProcessControlResponse: payload: bytes | None = None error: str | None = None error_code: ProcessControlErrorCode | None = None - process_id: str | None = None + process_id: UUID | None = None @dataclass @@ -253,7 +251,7 @@ class SettingsFieldUpdateRequest: @dataclass class ProcessPing: - process_id: str + process_id: UUID pid: int host: str timestamp: float @@ -261,7 +259,7 @@ class ProcessPing: @dataclass class ProcessStats: - process_id: str + process_id: UUID pid: int host: str owned_units: list[str] @@ -309,7 +307,7 @@ class SubscriberProfileSnapshot: @dataclass class ProcessProfilingSnapshot: - process_id: str + process_id: UUID pid: int host: str window_seconds: float @@ -342,7 +340,7 @@ class ProfilingTraceSample: @dataclass class ProcessProfilingTraceBatch: - process_id: str + process_id: UUID pid: int host: str timestamp: float @@ -352,14 +350,14 @@ class ProcessProfilingTraceBatch: @dataclass class ProfilingTraceStreamBatch: timestamp: float - batches: dict[str, ProcessProfilingTraceBatch] + batches: dict[UUID, ProcessProfilingTraceBatch] @dataclass class ProfilingStreamControl: interval: float = 0.05 max_samples: int = 1000 - process_ids: list[str] | None = None + process_ids: list[UUID] | None = None include_empty_batches: bool = False @@ -376,7 +374,7 @@ class SnapshotSession: @dataclass class SnapshotProcess: - process_id: str + process_id: UUID pid: int | None host: str | None units: list[str] @@ -387,4 +385,4 @@ class GraphSnapshot: graph: dict[str, list[str]] edge_owners: dict[Edge, list[str]] sessions: dict[str, SnapshotSession] - processes: dict[str, SnapshotProcess] = field(default_factory=dict) + processes: dict[UUID, SnapshotProcess] = field(default_factory=dict) diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 37adebc3..a1d4e76e 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -106,9 +106,9 @@ class GraphServer(threading.Thread): _pending_process_requests: dict[ str, tuple[UUID, "asyncio.Future[ProcessControlResponse]"] ] - _profiling_trace_buffers: dict[str, deque[tuple[int, ProfilingTraceSample]]] - _profiling_trace_process_meta: dict[str, tuple[int, str]] - _profiling_trace_seq: dict[str, int] + _profiling_trace_buffers: dict[UUID, deque[tuple[int, ProfilingTraceSample]]] + _profiling_trace_process_meta: dict[UUID, tuple[int, str]] + _profiling_trace_seq: dict[UUID, int] def __init__(self, **kwargs) -> None: super().__init__( @@ -621,6 +621,9 @@ def _process_info(self, process_client_id: UUID) -> ProcessInfo | None: return info return None + def _process_key(self, process_client_id: UUID) -> UUID: + return process_client_id + async def _handle_process( self, process_client_id: UUID, @@ -701,17 +704,11 @@ async def _handle_process( ok=False, error="Owning process disconnected before response", error_code=ProcessControlErrorCode.PROCESS_DISCONNECTED, - process_id=( - process_info.process_id if process_info is not None else None - ), + process_id=self._process_key(process_client_id), ) ) if process_info is not None: - source_process_id = ( - process_info.process_id - if process_info.process_id is not None - else str(process_client_id) - ) + source_process_id = self._process_key(process_client_id) self._profiling_trace_buffers.pop(source_process_id, None) self._profiling_trace_process_meta.pop(source_process_id, None) self._profiling_trace_seq.pop(source_process_id, None) @@ -890,7 +887,7 @@ async def _collect_profiling_trace_stream_batch( self, *, stream_control: ProfilingStreamControl, - last_seq_by_process: dict[str, int], + last_seq_by_process: dict[UUID, int], ) -> ProfilingTraceStreamBatch: process_ids_filter = ( set(stream_control.process_ids) @@ -899,25 +896,23 @@ async def _collect_profiling_trace_stream_batch( ) max_samples = max(1, int(stream_control.max_samples)) now_ts = time.time() - batches: dict[str, ProcessProfilingTraceBatch] = {} + batches: dict[UUID, ProcessProfilingTraceBatch] = {} async with self._command_lock: - connected_processes: dict[str, tuple[int, str]] = {} + connected_processes: dict[UUID, tuple[int, str]] = {} for client_id, info in self.clients.items(): if not isinstance(info, ProcessInfo): continue - process_id = ( - info.process_id if info.process_id is not None else str(client_id) - ) + process_id = self._process_key(client_id) pid = info.pid if info.pid is not None else -1 host = info.host if info.host is not None else "" connected_processes[process_id] = (pid, host) - process_ids: list[str] + process_ids: list[UUID] if process_ids_filter is not None: - process_ids = sorted(process_ids_filter) + process_ids = sorted(process_ids_filter, key=str) else: - process_ids = sorted(connected_processes.keys()) + process_ids = sorted(connected_processes.keys(), key=str) for process_id in process_ids: sample_buffer = self._profiling_trace_buffers.get(process_id) @@ -961,7 +956,7 @@ async def _handle_profiling_subscriber( writer: asyncio.StreamWriter, ) -> None: interval = max(0.01, float(stream_control.interval)) - last_seq_by_process: dict[str, int] = {} + last_seq_by_process: dict[UUID, int] = {} try: while True: try: @@ -1010,16 +1005,7 @@ async def _handle_process_profiling_trace_update_request( return async with self._command_lock: - process_info = self._process_info(process_client_id) - process_id = ( - ( - process_info.process_id - if process_info is not None and process_info.process_id is not None - else str(process_client_id) - ) - if batch.process_id == "" - else batch.process_id - ) + process_id = self._process_key(process_client_id) trace_buffer = self._profiling_trace_buffers.setdefault( process_id, deque(maxlen=200_000) ) @@ -1046,22 +1032,34 @@ async def _handle_process_register_request( if process_info is None: return Command.ERROR.value + conflicts = sorted( + { + unit + for unit in set(registration.units) + if ( + (owner := self._process_owner_for_unit(unit)) is not None + and owner != process_client_id + ) + } + ) + if conflicts: + logger.warning( + "Process control %s register rejected due to unit ownership conflict(s): %s", + process_client_id, + ", ".join(conflicts), + ) + return Command.ERROR.value + prev_units = set(process_info.units) - prev_process_id = process_info.process_id - process_info.process_id = registration.process_id process_info.pid = registration.pid process_info.host = registration.host process_info.units = set(registration.units) - if prev_process_id is not None and prev_process_id != registration.process_id: - self._profiling_trace_buffers.pop(prev_process_id, None) - self._profiling_trace_process_meta.pop(prev_process_id, None) - self._profiling_trace_seq.pop(prev_process_id, None) if prev_units != process_info.units: self._append_topology_event_locked( event_type=TopologyEventType.PROCESS_CHANGED, changed_topics=[], source_session_id=None, - source_process_id=registration.process_id, + source_process_id=self._process_key(process_client_id), ) return Command.COMPLETE.value @@ -1082,18 +1080,23 @@ async def _handle_process_update_ownership_request( if process_info is None: return Command.ERROR.value - if ( - process_info.process_id is not None - and process_info.process_id != update.process_id - ): + conflicts = sorted( + { + unit + for unit in set(update.added_units) + if ( + (owner := self._process_owner_for_unit(unit)) is not None + and owner != process_client_id + ) + } + ) + if conflicts: logger.warning( - "Process control %s process_id mismatch: %s != %s", + "Process control %s ownership update rejected due to unit ownership conflict(s): %s", process_client_id, - process_info.process_id, - update.process_id, + ", ".join(conflicts), ) - elif process_info.process_id is None: - process_info.process_id = update.process_id + return Command.ERROR.value prev_units = set(process_info.units) process_info.units.update(update.added_units) @@ -1103,7 +1106,7 @@ async def _handle_process_update_ownership_request( event_type=TopologyEventType.PROCESS_CHANGED, changed_topics=[], source_session_id=None, - source_process_id=update.process_id, + source_process_id=self._process_key(process_client_id), ) return Command.COMPLETE.value @@ -1124,14 +1127,7 @@ async def _handle_process_settings_update_request( if process_info is None: return Command.ERROR.value - if process_info.process_id is None: - process_info.process_id = update.process_id - - source_process_id = ( - process_info.process_id - if process_info.process_id is not None - else update.process_id - ) + source_process_id = self._process_key(process_client_id) self._settings_current[update.component_address] = update.value self._settings_source_session[update.component_address] = None self._append_settings_event_locked( @@ -1193,6 +1189,28 @@ def _process_for_unit(self, unit_address: str) -> ProcessInfo | None: return info return None + def _process_owner_for_unit(self, unit_address: str) -> UUID | None: + for client_id, info in self.clients.items(): + if isinstance(info, ProcessInfo) and unit_address in info.units: + return client_id + return None + + def _metadata_collisions( + self, session_id: UUID, metadata: GraphMetadata + ) -> list[str]: + collisions: list[str] = [] + requested = set(metadata.components.keys()) + if not requested: + return collisions + for other_session_id, info in self.clients.items(): + if other_session_id == session_id or not isinstance(info, SessionInfo): + continue + if info.metadata is None: + continue + overlap = requested.intersection(info.metadata.components.keys()) + collisions.extend(overlap) + return sorted(set(collisions)) + async def _route_process_request( self, unit_address: str, @@ -1238,7 +1256,7 @@ async def _route_process_request( ok=False, error=f"Failed to route request to owning process: {exc}", error_code=ProcessControlErrorCode.ROUTE_WRITE_FAILED, - process_id=process_info.process_id, + process_id=self._process_key(process_info.id), ) try: @@ -1254,7 +1272,7 @@ async def _route_process_request( f"(unit={unit_address}, operation={operation}, timeout={timeout}s)" ), error_code=ProcessControlErrorCode.TIMEOUT, - process_id=process_info.process_id, + process_id=self._process_key(process_info.id), ) async def _handle_session_process_request( @@ -1313,7 +1331,7 @@ def _append_settings_event_locked( component_address: str, value: SettingsSnapshotValue, source_session_id: str | None, - source_process_id: str | None, + source_process_id: UUID | None, timestamp: float | None = None, ) -> None: self._settings_event_seq += 1 @@ -1341,7 +1359,7 @@ def _append_topology_event_locked( event_type: TopologyEventType, changed_topics: list[str], source_session_id: str | None, - source_process_id: str | None, + source_process_id: UUID | None, timestamp: float | None = None, ) -> None: self._topology_event_seq += 1 @@ -1446,6 +1464,14 @@ async def _handle_session_register_request( async with self._command_lock: session = self._session_info(session_id) if session is not None and metadata is not None: + collisions = self._metadata_collisions(session_id, metadata) + if collisions: + logger.warning( + "Session %s metadata registration rejected due to component address collision(s): %s", + session_id, + ", ".join(collisions), + ) + return Command.ERROR.value self._remove_settings_for_session_locked(session_id) session.metadata = metadata self._apply_session_metadata_settings_locked(session_id, metadata) @@ -1560,12 +1586,8 @@ def _snapshot(self) -> GraphSnapshot: ) } processes = { - str(client_id): SnapshotProcess( - process_id=( - process.process_id - if process.process_id is not None - else str(client_id) - ), + client_id: SnapshotProcess( + process_id=self._process_key(client_id), pid=process.pid, host=process.host, units=sorted(process.units), diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index 67d9c1da..ff6bf242 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -182,7 +182,6 @@ class ProcessInfo(ClientInfo): Process-scoped control-plane client information. """ - process_id: str | None = None pid: int | None = None host: str | None = None units: set[str] = field(default_factory=set) diff --git a/src/ezmsg/core/processclient.py b/src/ezmsg/core/processclient.py index e3caeacc..df62ba9d 100644 --- a/src/ezmsg/core/processclient.py +++ b/src/ezmsg/core/processclient.py @@ -5,7 +5,7 @@ import socket import time -from uuid import UUID, uuid1 +from uuid import UUID from contextlib import suppress from collections.abc import Awaitable, Callable @@ -40,7 +40,6 @@ class ProcessControlClient: _graph_address: AddressType | None - _process_id: str _client_id: UUID | None _reader: asyncio.StreamReader | None _writer: asyncio.StreamWriter | None @@ -55,11 +54,8 @@ class ProcessControlClient: _trace_push_interval_s: float _trace_push_max_samples: int - def __init__( - self, graph_address: AddressType | None = None, process_id: str | None = None - ) -> None: + def __init__(self, graph_address: AddressType | None = None) -> None: self._graph_address = graph_address - self._process_id = process_id if process_id is not None else str(uuid1()) self._client_id = None self._reader = None self._writer = None @@ -75,11 +71,16 @@ def __init__( self._trace_push_max_samples = int( os.environ.get("EZMSG_PROFILE_TRACE_PUSH_MAX_SAMPLES", "1000") ) - PROFILES.set_process_id(self._process_id, reset=True) + PROFILES.set_process_id(UUID(int=0), reset=True) + + def _require_client_id(self) -> UUID: + if self._client_id is None: + raise RuntimeError("Process control connection is not active") + return self._client_id @property - def process_id(self) -> str: - return self._process_id + def process_id(self) -> UUID: + return self._require_client_id() @property def client_id(self) -> UUID | None: @@ -100,6 +101,7 @@ async def connect(self) -> None: raise RuntimeError("Failed to create process control connection") self._client_id = client_id + PROFILES.set_process_id(client_id, reset=True) self._reader = reader self._writer = writer self._io_task = asyncio.create_task( @@ -118,10 +120,8 @@ def set_request_handler( async def register(self, units: list[str]) -> None: await self.connect() - PROFILES.set_process_id(self._process_id) normalized_units = sorted(set(units)) payload = ProcessRegistration( - process_id=self._process_id, pid=os.getpid(), host=socket.gethostname(), units=normalized_units, @@ -138,7 +138,6 @@ async def update_ownership( added = sorted(set(added_units or [])) removed = sorted(set(removed_units or [])) payload = ProcessOwnershipUpdate( - process_id=self._process_id, added_units=added, removed_units=removed, ) @@ -154,7 +153,6 @@ async def report_settings_update( ) -> None: await self.connect() payload = ProcessSettingsUpdate( - process_id=self._process_id, component_address=component_address, value=value, timestamp=timestamp if timestamp is not None else time.time(), @@ -299,13 +297,13 @@ async def _handle_route_request( ok=True, payload=pickle.dumps( ProcessPing( - process_id=self._process_id, + process_id=self.process_id, pid=os.getpid(), host=socket.gethostname(), timestamp=time.time(), ) ), - process_id=self._process_id, + process_id=self.process_id, ) if operation == ProcessControlOperation.GET_PROCESS_STATS: @@ -314,14 +312,14 @@ async def _handle_route_request( ok=True, payload=pickle.dumps( ProcessStats( - process_id=self._process_id, + process_id=self.process_id, pid=os.getpid(), host=socket.gethostname(), owned_units=sorted(self._owned_units), timestamp=time.time(), ) ), - process_id=self._process_id, + process_id=self.process_id, ) if operation == ProcessControlOperation.GET_PROFILING_SNAPSHOT: @@ -330,7 +328,7 @@ async def _handle_route_request( request_id=request.request_id, ok=True, payload=pickle.dumps(snapshot), - process_id=self._process_id, + process_id=self.process_id, ) if operation == ProcessControlOperation.SET_PROFILING_TRACE: @@ -346,7 +344,7 @@ async def _handle_route_request( ok=False, error=f"Invalid profiling trace control payload: {exc}", error_code=ProcessControlErrorCode.INVALID_RESPONSE, - process_id=self._process_id, + process_id=self.process_id, ) if control is None: @@ -355,7 +353,7 @@ async def _handle_route_request( ok=False, error="Missing profiling trace control payload", error_code=ProcessControlErrorCode.INVALID_RESPONSE, - process_id=self._process_id, + process_id=self.process_id, ) PROFILES.set_trace_control(control) @@ -367,7 +365,7 @@ async def _handle_route_request( return ProcessControlResponse( request_id=request.request_id, ok=True, - process_id=self._process_id, + process_id=self.process_id, ) if operation == ProcessControlOperation.GET_PROFILING_TRACE_BATCH: @@ -387,7 +385,7 @@ async def _handle_route_request( request_id=request.request_id, ok=True, payload=pickle.dumps(batch), - process_id=self._process_id, + process_id=self.process_id, ) if self._request_handler is None: @@ -396,7 +394,7 @@ async def _handle_route_request( ok=False, error=f"Unsupported process control operation: {request.operation}", error_code=ProcessControlErrorCode.HANDLER_NOT_CONFIGURED, - process_id=self._process_id, + process_id=self.process_id, ) try: @@ -409,7 +407,7 @@ async def _handle_route_request( ok=False, error=f"process request handler failed: {exc}", error_code=ProcessControlErrorCode.HANDLER_ERROR, - process_id=self._process_id, + process_id=self.process_id, ) if not isinstance(result, ProcessControlResponse): @@ -421,7 +419,7 @@ async def _handle_route_request( f"{type(result).__name__}" ), error_code=ProcessControlErrorCode.INVALID_RESPONSE, - process_id=self._process_id, + process_id=self.process_id, ) if result.request_id != request.request_id: @@ -433,11 +431,11 @@ async def _handle_route_request( f"{result.request_id}" ), error_code=ProcessControlErrorCode.INVALID_RESPONSE, - process_id=self._process_id, + process_id=self.process_id, ) if result.process_id is None: - result.process_id = self._process_id + result.process_id = self.process_id return result @@ -447,7 +445,7 @@ async def _ensure_trace_push_task(self) -> None: return self._trace_push_task = asyncio.create_task( self._trace_push_loop(), - name=f"proc-trace-push-{self._process_id}", + name=f"proc-trace-push-{self.process_id}", ) async def _cancel_trace_push_task(self) -> None: diff --git a/src/ezmsg/core/profiling.py b/src/ezmsg/core/profiling.py index 760b377b..6c8cbeca 100644 --- a/src/ezmsg/core/profiling.py +++ b/src/ezmsg/core/profiling.py @@ -283,7 +283,7 @@ def snapshot(self) -> SubscriberProfileSnapshot: class ProfileRegistry: def __init__(self) -> None: - self._process_id = "" + self._process_id = UUID(int=0) self._pid = os.getpid() self._host = socket.gethostname() self._publishers: dict[UUID, _PublisherMetrics] = {} @@ -291,7 +291,7 @@ def __init__(self) -> None: self._default_trace_control = ProfilingTraceControl(enabled=False) self._trace_control_expires_ns: int | None = None - def set_process_id(self, process_id: str, *, reset: bool = False) -> None: + def set_process_id(self, process_id: UUID, *, reset: bool = False) -> None: if reset or (self._process_id and self._process_id != process_id): self._publishers.clear() self._subscribers.clear() diff --git a/tests/test_process_control.py b/tests/test_process_control.py index ad5db53b..81a22df1 100644 --- a/tests/test_process_control.py +++ b/tests/test_process_control.py @@ -18,8 +18,10 @@ async def test_process_registration_visible_in_snapshot(): observer = GraphContext(address, auto_start=False) await observer.__aenter__() - process = ProcessControlClient(address, process_id="proc-A") + process = ProcessControlClient(address) await process.connect() + assert process.client_id is not None + process_key = process.client_id try: await process.register(["SYS/U1", "SYS/U2"]) @@ -29,7 +31,7 @@ async def test_process_registration_visible_in_snapshot(): assert len(snapshot.processes) == 1 process_entry = next(iter(snapshot.processes.values())) - assert process_entry.process_id == "proc-A" + assert process_entry.process_id == process_key assert process_entry.pid is not None assert process_entry.host is not None assert process_entry.units == ["SYS/U2", "SYS/U3"] @@ -49,7 +51,7 @@ async def test_process_snapshot_entry_drops_on_disconnect(): observer = GraphContext(address, auto_start=False) await observer.__aenter__() - process = ProcessControlClient(address, process_id="proc-B") + process = ProcessControlClient(address) await process.connect() try: @@ -127,7 +129,6 @@ async def test_process_register_succeeds_after_error_ack(): good_payload = pickle.dumps( ProcessRegistration( - process_id="proc-after-error", pid=123, host="test-host", units=["SYS/U1"], @@ -142,3 +143,44 @@ async def test_process_register_succeeds_after_error_ack(): finally: await close_stream_writer(writer) graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_register_rejects_unit_ownership_collision(): + graph_server = GraphService().create_server() + address = graph_server.address + + process_a = ProcessControlClient(address) + process_b = ProcessControlClient(address) + await process_a.connect() + await process_b.connect() + + try: + await process_a.register(["SYS/U1"]) + with pytest.raises(RuntimeError, match="PROCESS_REGISTER"): + await process_b.register(["SYS/U1"]) + finally: + await process_a.close() + await process_b.close() + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_update_ownership_rejects_unit_ownership_collision(): + graph_server = GraphService().create_server() + address = graph_server.address + + process_a = ProcessControlClient(address) + process_b = ProcessControlClient(address) + await process_a.connect() + await process_b.connect() + + try: + await process_a.register(["SYS/U1"]) + await process_b.register(["SYS/U2"]) + with pytest.raises(RuntimeError, match="PROCESS_UPDATE_OWNERSHIP"): + await process_b.update_ownership(added_units=["SYS/U1"]) + finally: + await process_a.close() + await process_b.close() + graph_server.stop() diff --git a/tests/test_process_routing.py b/tests/test_process_routing.py index f093bcb5..2402466d 100644 --- a/tests/test_process_routing.py +++ b/tests/test_process_routing.py @@ -20,8 +20,10 @@ async def test_process_routing_roundtrip(): observer = GraphContext(address, auto_start=False) await observer.__aenter__() - process = ProcessControlClient(address, process_id="proc-route") + process = ProcessControlClient(address) await process.connect() + assert process.client_id is not None + process_key = process.client_id await process.register(["SYS/U1"]) async def handler(request: ProcessControlRequest) -> ProcessControlResponse: @@ -44,7 +46,7 @@ async def handler(request: ProcessControlRequest) -> ProcessControlResponse: ) assert response.ok assert response.payload == b"hello" - assert response.process_id == "proc-route" + assert response.process_id == process_key finally: await process.close() await observer.__aexit__(None, None, None) @@ -59,20 +61,22 @@ async def test_process_routing_builtin_ping_and_stats(): observer = GraphContext(address, auto_start=False) await observer.__aenter__() - process = ProcessControlClient(address, process_id="proc-builtins") + process = ProcessControlClient(address) await process.connect() + assert process.client_id is not None + process_key = process.client_id await process.register(["SYS/U1", "SYS/U2"]) await process.update_ownership(removed_units=["SYS/U2"], added_units=["SYS/U3"]) try: ping = await observer.process_ping("SYS/U1", timeout=1.0) - assert ping.process_id == "proc-builtins" + assert ping.process_id == process_key assert ping.pid > 0 assert ping.host assert ping.timestamp > 0.0 stats = await observer.process_stats("SYS/U1", timeout=1.0) - assert stats.process_id == "proc-builtins" + assert stats.process_id == process_key assert stats.pid > 0 assert stats.host assert stats.owned_units == ["SYS/U1", "SYS/U3"] @@ -115,8 +119,10 @@ async def test_process_routing_timeout_returns_error(): observer = GraphContext(address, auto_start=False) await observer.__aenter__() - process = ProcessControlClient(address, process_id="proc-timeout") + process = ProcessControlClient(address) await process.connect() + assert process.client_id is not None + process_key = process.client_id await process.register(["SYS/U2"]) block = asyncio.Event() @@ -137,7 +143,7 @@ async def blocking_handler(_request: ProcessControlRequest) -> ProcessControlRes assert response.error is not None assert "Timed out waiting for process response" in response.error assert response.error_code == ProcessControlErrorCode.TIMEOUT - assert response.process_id == "proc-timeout" + assert response.process_id == process_key finally: await process.close() await observer.__aexit__(None, None, None) diff --git a/tests/test_profiling_api.py b/tests/test_profiling_api.py index 67fcf9ed..78df76d7 100644 --- a/tests/test_profiling_api.py +++ b/tests/test_profiling_api.py @@ -20,8 +20,10 @@ async def test_process_profiling_snapshot_collects_pub_sub_metrics(): ctx = GraphContext(address, auto_start=False) await ctx.__aenter__() - process = ProcessControlClient(address, process_id="proc-prof") + process = ProcessControlClient(address) await process.connect() + assert process.client_id is not None + process_key = process.client_id await process.register(["SYS/U1"]) pub = await ctx.publisher("TOPIC_PROF") @@ -34,7 +36,7 @@ async def test_process_profiling_snapshot_collects_pub_sub_metrics(): await asyncio.sleep(0) snap = await ctx.process_profiling_snapshot("SYS/U1", timeout=1.0) - assert snap.process_id == "proc-prof" + assert snap.process_id == process_key assert snap.window_seconds > 0 assert len(snap.publishers) >= 1 assert len(snap.subscribers) >= 1 @@ -61,8 +63,10 @@ async def test_process_profiling_trace_control_and_batch(): ctx = GraphContext(address, auto_start=False) await ctx.__aenter__() - process = ProcessControlClient(address, process_id="proc-trace") + process = ProcessControlClient(address) await process.connect() + assert process.client_id is not None + process_key = process.client_id await process.register(["SYS/U2"]) pub = await ctx.publisher("TOPIC_TRACE") @@ -93,7 +97,7 @@ async def test_process_profiling_trace_control_and_batch(): batch = await ctx.process_profiling_trace_batch( "SYS/U2", max_samples=200, timeout=1.0 ) - assert batch.process_id == "proc-trace" + assert batch.process_id == process_key assert len(batch.samples) > 0 disable_response = await ctx.process_set_profiling_trace( @@ -116,14 +120,16 @@ async def test_profiling_snapshot_all_and_unroutable_error_code(): ctx = GraphContext(address, auto_start=False) await ctx.__aenter__() - process = ProcessControlClient(address, process_id="proc-all") + process = ProcessControlClient(address) await process.connect() + assert process.client_id is not None + process_key = process.client_id await process.register(["SYS/U3"]) try: snapshots = await ctx.profiling_snapshot_all(timeout_per_process=0.5) - assert "proc-all" in snapshots - assert snapshots["proc-all"].process_id == "proc-all" + assert process_key in snapshots + assert snapshots[process_key].process_id == process_key response = await ctx.process_request( "SYS/MISSING", @@ -146,8 +152,10 @@ async def test_process_profiling_trace_subscription_push(): ctx = GraphContext(address, auto_start=False) await ctx.__aenter__() - process = ProcessControlClient(address, process_id="proc-stream") + process = ProcessControlClient(address) await process.connect() + assert process.client_id is not None + process_key = process.client_id await process.register(["SYS/U4"]) pub = await ctx.publisher("TOPIC_STREAM") @@ -182,9 +190,9 @@ async def test_process_profiling_trace_subscription_push(): batch = await asyncio.wait_for(anext(stream), timeout=1.0) assert batch.timestamp > 0.0 - assert "proc-stream" in batch.batches - process_batch = batch.batches["proc-stream"] - assert process_batch.process_id == "proc-stream" + assert process_key in batch.batches + process_batch = batch.batches[process_key] + assert process_batch.process_id == process_key assert len(process_batch.samples) > 0 finally: if stream is not None: @@ -202,7 +210,7 @@ async def test_process_profiling_trace_control_endpoint_metric_and_ttl(): ctx = GraphContext(address, auto_start=False) await ctx.__aenter__() - process = ProcessControlClient(address, process_id="proc-trace-filter") + process = ProcessControlClient(address) await process.connect() await process.register(["SYS/U5"]) @@ -292,8 +300,10 @@ async def test_process_profiling_trace_subscription_stream_control(): ctx = GraphContext(address, auto_start=False) await ctx.__aenter__() - process_a = ProcessControlClient(address, process_id="proc-stream-a") + process_a = ProcessControlClient(address) await process_a.connect() + assert process_a.client_id is not None + process_a_key = process_a.client_id await process_a.register(["SYS/U6"]) stream = None @@ -302,12 +312,12 @@ async def test_process_profiling_trace_subscription_stream_control(): ProfilingStreamControl( interval=0.02, max_samples=64, - process_ids=["proc-stream-a"], + process_ids=[process_a_key], include_empty_batches=True, ) ) batch = await asyncio.wait_for(anext(stream), timeout=1.0) - assert "proc-stream-a" in batch.batches + assert process_a_key in batch.batches assert len(batch.batches) == 1 finally: if stream is not None: @@ -325,8 +335,10 @@ async def test_process_profiling_trace_subscription_does_not_starve_peer_subscri ctx = GraphContext(address, auto_start=False) await ctx.__aenter__() - process = ProcessControlClient(address, process_id="proc-stream-multi") + process = ProcessControlClient(address) await process.connect() + assert process.client_id is not None + process_key = process.client_id await process.register(["SYS/U7"]) pub = await ctx.publisher("TOPIC_STREAM_MULTI") @@ -363,10 +375,10 @@ async def test_process_profiling_trace_subscription_does_not_starve_peer_subscri batch_a = await asyncio.wait_for(anext(stream_a), timeout=1.0) batch_b = await asyncio.wait_for(anext(stream_b), timeout=1.0) - assert "proc-stream-multi" in batch_a.batches - assert "proc-stream-multi" in batch_b.batches - assert len(batch_a.batches["proc-stream-multi"].samples) > 0 - assert len(batch_b.batches["proc-stream-multi"].samples) > 0 + assert process_key in batch_a.batches + assert process_key in batch_b.batches + assert len(batch_a.batches[process_key].samples) > 0 + assert len(batch_b.batches[process_key].samples) > 0 finally: if stream_a is not None: await stream_a.aclose() diff --git a/tests/test_settings_api.py b/tests/test_settings_api.py index e987a98a..c5938e9d 100644 --- a/tests/test_settings_api.py +++ b/tests/test_settings_api.py @@ -203,8 +203,10 @@ async def test_graphcontext_update_setting_field_routes_to_process(): observer = GraphContext(address, auto_start=False) await observer.__aenter__() - process = ProcessControlClient(address, process_id="proc-setting-patch") + process = ProcessControlClient(address) await process.connect() + assert process.client_id is not None + process_key = process.client_id await process.register(["SYS/SINK"]) try: @@ -240,7 +242,7 @@ async def test_graphcontext_update_setting_waits_and_propagates_process_failure( observer = GraphContext(address, auto_start=False) await observer.__aenter__() - process = ProcessControlClient(address, process_id="proc-setting-fail") + process = ProcessControlClient(address) await process.connect() await process.register(["SYS/SINK"]) @@ -276,8 +278,10 @@ async def test_process_reported_settings_update_visible_in_snapshot_and_events() observer = GraphContext(address, auto_start=False) await observer.__aenter__() - process = ProcessControlClient(address, process_id="proc-settings") + process = ProcessControlClient(address) await process.connect() + assert process.client_id is not None + process_key = process.client_id try: await process.register(["SYS/UNIT_B"]) @@ -297,7 +301,7 @@ async def test_process_reported_settings_update_visible_in_snapshot_and_events() and event.event_type == SettingsEventType.SETTINGS_UPDATED ] assert matching - assert matching[-1].source_process_id == "proc-settings" + assert matching[-1].source_process_id == process_key stream = observer.subscribe_settings_events(after_seq=0) streamed = await asyncio.wait_for(anext(stream), timeout=1.0) @@ -335,3 +339,29 @@ async def test_session_owned_settings_removed_when_session_drops(): await owner.__aexit__(None, None, None) await observer.__aexit__(None, None, None) graph_server.stop() + + +@pytest.mark.asyncio +async def test_metadata_registration_rejects_component_address_collision(): + graph_server = GraphService().create_server() + address = graph_server.address + + owner_a = GraphContext(address, auto_start=False) + owner_b = GraphContext(address, auto_start=False) + + await owner_a.__aenter__() + await owner_b.__aenter__() + + try: + component_address = "SYS/UNIT_COLLIDE" + metadata = _metadata_with_component(component_address) + await owner_a.register_metadata(metadata) + with pytest.raises( + RuntimeError, + match="Unexpected response to session metadata registration", + ): + await owner_b.register_metadata(metadata) + finally: + await owner_a.__aexit__(None, None, None) + await owner_b.__aexit__(None, None, None) + graph_server.stop() diff --git a/tests/test_topology_api.py b/tests/test_topology_api.py index bd4b66a4..4633c093 100644 --- a/tests/test_topology_api.py +++ b/tests/test_topology_api.py @@ -69,18 +69,20 @@ async def test_topology_subscription_reports_process_changes(): observer = GraphContext(address, auto_start=False) await observer.__aenter__() - process = ProcessControlClient(address, process_id="proc-topology") + process = ProcessControlClient(address) stream = observer.subscribe_topology_events(after_seq=0) try: await process.connect() + assert process.client_id is not None + process_key = process.client_id await process.register(["SYS/U1"]) registered = await _next_matching_event( stream, lambda e: ( e.event_type == TopologyEventType.PROCESS_CHANGED - and e.source_process_id == "proc-topology" + and e.source_process_id == process_key ), timeout=1.0, ) @@ -91,7 +93,7 @@ async def test_topology_subscription_reports_process_changes(): stream, lambda e: ( e.event_type == TopologyEventType.PROCESS_CHANGED - and e.source_process_id == "proc-topology" + and e.source_process_id == process_key ), timeout=1.0, ) From 565528870c7cf5b4df2c3eaf032878a57655d5ad Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 17 Mar 2026 13:21:26 -0400 Subject: [PATCH 24/52] fix: profiling registry --- src/ezmsg/core/processclient.py | 3 +-- src/ezmsg/core/profiling.py | 2 +- tests/test_profiling_api.py | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/src/ezmsg/core/processclient.py b/src/ezmsg/core/processclient.py index df62ba9d..e84f1e46 100644 --- a/src/ezmsg/core/processclient.py +++ b/src/ezmsg/core/processclient.py @@ -71,7 +71,6 @@ def __init__(self, graph_address: AddressType | None = None) -> None: self._trace_push_max_samples = int( os.environ.get("EZMSG_PROFILE_TRACE_PUSH_MAX_SAMPLES", "1000") ) - PROFILES.set_process_id(UUID(int=0), reset=True) def _require_client_id(self) -> UUID: if self._client_id is None: @@ -101,7 +100,7 @@ async def connect(self) -> None: raise RuntimeError("Failed to create process control connection") self._client_id = client_id - PROFILES.set_process_id(client_id, reset=True) + PROFILES.set_process_id(client_id) self._reader = reader self._writer = writer self._io_task = asyncio.create_task( diff --git a/src/ezmsg/core/profiling.py b/src/ezmsg/core/profiling.py index 6c8cbeca..318ec002 100644 --- a/src/ezmsg/core/profiling.py +++ b/src/ezmsg/core/profiling.py @@ -292,7 +292,7 @@ def __init__(self) -> None: self._trace_control_expires_ns: int | None = None def set_process_id(self, process_id: UUID, *, reset: bool = False) -> None: - if reset or (self._process_id and self._process_id != process_id): + if reset: self._publishers.clear() self._subscribers.clear() self._default_trace_control = ProfilingTraceControl(enabled=False) diff --git a/tests/test_profiling_api.py b/tests/test_profiling_api.py index 78df76d7..21047390 100644 --- a/tests/test_profiling_api.py +++ b/tests/test_profiling_api.py @@ -55,6 +55,39 @@ async def test_process_profiling_snapshot_collects_pub_sub_metrics(): graph_server.stop() +@pytest.mark.asyncio +async def test_process_connect_does_not_clear_preexisting_profile_metrics(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + pub = await ctx.publisher("TOPIC_PRECONNECT") + sub = await ctx.subscriber("TOPIC_PRECONNECT") + for idx in range(6): + await pub.broadcast(idx) + async with sub.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + process = ProcessControlClient(address) + await process.connect() + await process.register(["SYS/U_PRE"]) + + try: + snap = await ctx.process_profiling_snapshot("SYS/U_PRE", timeout=1.0) + assert len(snap.publishers) >= 1 + assert len(snap.subscribers) >= 1 + pub_metrics = next(iter(snap.publishers.values())) + sub_metrics = next(iter(snap.subscribers.values())) + assert pub_metrics.messages_published_total >= 6 + assert sub_metrics.messages_received_total >= 6 + finally: + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() + + @pytest.mark.asyncio async def test_process_profiling_trace_control_and_batch(): graph_server = GraphService().create_server() From 2b5ff3134ff1b88dbc32da2a6a3c42eebf99639c Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 17 Mar 2026 13:29:15 -0400 Subject: [PATCH 25/52] better errors for high-level API name collisions, and fixed brittle tests --- src/ezmsg/core/graphcontext.py | 34 ++++++++++++++++++++++++++++++++-- tests/test_profiling_api.py | 16 ++++++++++++---- tests/test_settings_api.py | 2 +- 3 files changed, 45 insertions(+), 7 deletions(-) diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index 0c07438b..81970387 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -360,8 +360,38 @@ async def register_metadata(self, metadata: GraphMetadata) -> None: payload=payload, response_kind=_SessionResponseKind.BYTE, ) - if response != Command.COMPLETE.value: - raise RuntimeError("Unexpected response to session metadata registration") + if response == Command.COMPLETE.value: + return + if response == Command.ERROR.value: + requested = set(metadata.components.keys()) + collisions: set[str] = set() + if len(requested) > 0: + own_session_id = str(self._session_id) if self._session_id is not None else None + try: + snapshot = await self.snapshot() + for session_id, session in snapshot.sessions.items(): + if own_session_id is not None and session_id == own_session_id: + continue + if session.metadata is None: + continue + collisions.update( + requested.intersection(session.metadata.components.keys()) + ) + except Exception: + # Fall back to a generic error if snapshot lookup fails. + pass + + if len(collisions) > 0: + collision_str = ", ".join(sorted(collisions)) + raise RuntimeError( + "Session metadata registration rejected by GraphServer due to " + f"component address collision(s): {collision_str}" + ) + raise RuntimeError("Session metadata registration rejected by GraphServer") + raise RuntimeError( + "Unexpected response to session metadata registration: " + f"{response!r}" + ) async def snapshot(self) -> GraphSnapshot: snapshot = await self._session_command( diff --git a/tests/test_profiling_api.py b/tests/test_profiling_api.py index 21047390..fe00d732 100644 --- a/tests/test_profiling_api.py +++ b/tests/test_profiling_api.py @@ -41,11 +41,15 @@ async def test_process_profiling_snapshot_collects_pub_sub_metrics(): assert len(snap.publishers) >= 1 assert len(snap.subscribers) >= 1 - pub_metrics = next(iter(snap.publishers.values())) + pub_metrics = next( + pub for pub in snap.publishers.values() if pub.topic == "TOPIC_PROF" + ) assert pub_metrics.messages_published_total >= 8 assert pub_metrics.publish_rate_hz_window >= 0.0 - sub_metrics = next(iter(snap.subscribers.values())) + sub_metrics = next( + sub for sub in snap.subscribers.values() if sub.topic == "TOPIC_PROF" + ) assert sub_metrics.messages_received_total >= 8 assert sub_metrics.lease_time_ns_total > 0 assert sub_metrics.lease_time_ns_avg_window >= 0.0 @@ -78,8 +82,12 @@ async def test_process_connect_does_not_clear_preexisting_profile_metrics(): snap = await ctx.process_profiling_snapshot("SYS/U_PRE", timeout=1.0) assert len(snap.publishers) >= 1 assert len(snap.subscribers) >= 1 - pub_metrics = next(iter(snap.publishers.values())) - sub_metrics = next(iter(snap.subscribers.values())) + pub_metrics = next( + pub for pub in snap.publishers.values() if pub.topic == "TOPIC_PRECONNECT" + ) + sub_metrics = next( + sub for sub in snap.subscribers.values() if sub.topic == "TOPIC_PRECONNECT" + ) assert pub_metrics.messages_published_total >= 6 assert sub_metrics.messages_received_total >= 6 finally: diff --git a/tests/test_settings_api.py b/tests/test_settings_api.py index c5938e9d..55f4b8f6 100644 --- a/tests/test_settings_api.py +++ b/tests/test_settings_api.py @@ -358,7 +358,7 @@ async def test_metadata_registration_rejects_component_address_collision(): await owner_a.register_metadata(metadata) with pytest.raises( RuntimeError, - match="Unexpected response to session metadata registration", + match="component address collision\\(s\\): SYS/UNIT_COLLIDE", ): await owner_b.register_metadata(metadata) finally: From 09810f109b0102b19cb32f906588593bf40e37e0 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 19 Mar 2026 14:04:20 -0400 Subject: [PATCH 26/52] Include num_buffers in publisher profiling snapshots --- src/ezmsg/core/graphmeta.py | 1 + src/ezmsg/core/profiling.py | 5 ++++- src/ezmsg/core/pubclient.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index 0eafc11a..38fe8737 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -282,6 +282,7 @@ class PublisherProfileSnapshot: publish_delta_ns_avg_window: float publish_rate_hz_window: float inflight_messages_current: int + num_buffers: int inflight_messages_peak_window: int backpressure_wait_ns_total: int backpressure_wait_ns_window: int diff --git a/src/ezmsg/core/profiling.py b/src/ezmsg/core/profiling.py index 318ec002..3fba37e4 100644 --- a/src/ezmsg/core/profiling.py +++ b/src/ezmsg/core/profiling.py @@ -96,6 +96,7 @@ def avg(self) -> float: class _PublisherMetrics: topic: str endpoint_id: str + num_buffers: int messages_published_total: int = 0 backpressure_wait_ns_total: int = 0 inflight_messages_current: int = 0 @@ -168,6 +169,7 @@ def snapshot(self) -> PublisherProfileSnapshot: publish_delta_ns_avg_window=self._publish_delta.avg(), publish_rate_hz_window=float(window_msgs) / max(WINDOW_SECONDS, 1e-9), inflight_messages_current=self.inflight_messages_current, + num_buffers=self.num_buffers, inflight_messages_peak_window=self._inflight.max_total(), backpressure_wait_ns_total=self.backpressure_wait_ns_total, backpressure_wait_ns_window=self._backpressure_wait.sum_total(), @@ -299,10 +301,11 @@ def set_process_id(self, process_id: UUID, *, reset: bool = False) -> None: self._trace_control_expires_ns = None self._process_id = process_id - def register_publisher(self, pub_id: UUID, topic: str) -> None: + def register_publisher(self, pub_id: UUID, topic: str, num_buffers: int) -> None: metric = _PublisherMetrics( topic=topic, endpoint_id=_endpoint_id(topic, pub_id), + num_buffers=max(1, int(num_buffers)), ) self._publishers[pub_id] = metric self._apply_trace_control_to_publisher(metric) diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 270e50b7..abc97641 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -231,7 +231,7 @@ def __init__( self._force_tcp = force_tcp self._last_backpressure_event = -1 self._graph_address = graph_address - PROFILES.register_publisher(self.id, self.topic) + PROFILES.register_publisher(self.id, self.topic, self._num_buffers) @property def log_name(self) -> str: From 39bb093cf701f35bd7fada2b53d31cb79911275b Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 19 Mar 2026 15:17:52 -0400 Subject: [PATCH 27/52] modified the toy example to add some dynamic settings --- examples/ezmsg_toy.py | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/examples/ezmsg_toy.py b/examples/ezmsg_toy.py index 08f51f5f..69f7141c 100644 --- a/examples/ezmsg_toy.py +++ b/examples/ezmsg_toy.py @@ -24,20 +24,33 @@ class LFOSettings(ez.Settings): update_rate: float = 2.0 # Hz, update rate +class LFOState(ez.State): + start_time: float + cur_settings: LFOSettings + + class LFO(ez.Unit): SETTINGS = LFOSettings + STATE = LFOState OUTPUT = ez.OutputStream(float) + INPUT_SETTINGS = ez.InputStream(LFOSettings) + async def initialize(self) -> None: - self.start_time = time.time() + self.STATE.cur_settings = self.SETTINGS + self.STATE.start_time = time.time() + @ez.subscriber(INPUT_SETTINGS) + async def on_settings(self, msg: LFOSettings) -> None: + self.STATE.cur_settings = msg + @ez.publisher(OUTPUT) async def generate(self) -> AsyncGenerator: while True: - t = time.time() - self.start_time - yield self.OUTPUT, math.sin(2.0 * math.pi * self.SETTINGS.freq * t) - await asyncio.sleep(1.0 / self.SETTINGS.update_rate) + t = time.time() - self.STATE.start_time + yield self.OUTPUT, math.sin(2.0 * math.pi * self.STATE.cur_settings.freq * t) + await asyncio.sleep(1.0 / self.STATE.cur_settings.update_rate) # MESSAGE GENERATOR @@ -45,17 +58,30 @@ class MessageGeneratorSettings(ez.Settings): message: str +class MessageGeneratorState(ez.State): + cur_settings: MessageGeneratorSettings + + class MessageGenerator(ez.Unit): SETTINGS = MessageGeneratorSettings + STATE = MessageGeneratorState OUTPUT = ez.OutputStream(str) + INPUT_SETTINGS = ez.InputStream(MessageGeneratorSettings) + + async def initialize(self) -> None: + self.STATE.cur_settings = self.SETTINGS + + @ez.subscriber(INPUT_SETTINGS) + async def on_settings(self, msg: MessageGeneratorSettings) -> None: + self.STATE.cur_settings = msg @ez.publisher(OUTPUT) async def spawn_message(self) -> AsyncGenerator: while True: await asyncio.sleep(1.0) - ez.logger.info(f"Spawning {self.SETTINGS.message}") - yield self.OUTPUT, self.SETTINGS.message + ez.logger.info(f"Spawning {self.STATE.cur_settings.message}") + yield self.OUTPUT, self.STATE.cur_settings.message @ez.publisher(OUTPUT) async def spawn_once(self) -> AsyncGenerator: From 66e8261e838a441516861d3ceae74fc5b395d10d Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 19 Mar 2026 15:57:13 -0400 Subject: [PATCH 28/52] Add trace sample sequence IDs for timing alignment --- src/ezmsg/core/graphmeta.py | 1 + src/ezmsg/core/messagechannel.py | 6 +++- src/ezmsg/core/profiling.py | 60 +++++++++++++++++++++++++------- src/ezmsg/core/pubclient.py | 12 +++++-- src/ezmsg/core/subclient.py | 27 +++++++++----- 5 files changed, 82 insertions(+), 24 deletions(-) diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index 38fe8737..c96e4d7d 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -337,6 +337,7 @@ class ProfilingTraceSample: metric: str value: float channel_kind: ProfileChannelType | None = None + sample_seq: int | None = None @dataclass diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 35cc8b26..3c86fa58 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -414,7 +414,11 @@ def _release_backpressure(self, msg_id: int, client_id: UUID) -> None: if lease is not None: start_ns = lease PROFILES.subscriber_attributed_backpressure( - client_id, now_ns, now_ns - start_ns, self._channel_kind + client_id, + now_ns, + now_ns - start_ns, + self._channel_kind, + msg_seq=msg_id, ) buf_idx = msg_id % self.num_buffers diff --git a/src/ezmsg/core/profiling.py b/src/ezmsg/core/profiling.py index 3fba37e4..77530242 100644 --- a/src/ezmsg/core/profiling.py +++ b/src/ezmsg/core/profiling.py @@ -116,7 +116,9 @@ class _PublisherMetrics: def _trace_metric_enabled(self, metric: str) -> bool: return self.trace_metrics is None or metric in self.trace_metrics - def record_publish(self, ts_ns: int, inflight: int) -> None: + def record_publish( + self, ts_ns: int, inflight: int, msg_seq: int | None = None + ) -> None: self.messages_published_total += 1 self._publish_count.add(ts_ns, 1) publish_delta_ns = 0 @@ -138,10 +140,13 @@ def record_publish(self, ts_ns: int, inflight: int) -> None: topic=self.topic, metric="publish_delta_ns", value=float(publish_delta_ns), + sample_seq=msg_seq, ) ) - def record_backpressure_wait(self, ts_ns: int, wait_ns: int) -> None: + def record_backpressure_wait( + self, ts_ns: int, wait_ns: int, msg_seq: int | None = None + ) -> None: self.backpressure_wait_ns_total += wait_ns self._backpressure_wait.add(ts_ns, wait_ns) if self.trace_enabled and self._trace_metric_enabled("backpressure_wait_ns"): @@ -152,6 +157,7 @@ def record_backpressure_wait(self, ts_ns: int, wait_ns: int) -> None: topic=self.topic, metric="backpressure_wait_ns", value=float(wait_ns), + sample_seq=msg_seq, ) ) @@ -202,7 +208,13 @@ class _SubscriberMetrics: def _trace_metric_enabled(self, metric: str) -> bool: return self.trace_metrics is None or metric in self.trace_metrics - def record_receive(self, ts_ns: int, lease_ns: int, channel_kind: ProfileChannelType) -> None: + def record_receive( + self, + ts_ns: int, + lease_ns: int, + channel_kind: ProfileChannelType, + msg_seq: int | None = None, + ) -> None: self.messages_received_total += 1 self.lease_time_ns_total += lease_ns self.channel_kind_last = channel_kind @@ -222,10 +234,13 @@ def record_receive(self, ts_ns: int, lease_ns: int, channel_kind: ProfileChannel metric="lease_time_ns", value=float(lease_ns), channel_kind=channel_kind, + sample_seq=msg_seq, ) ) - def record_user_span(self, ts_ns: int, span_ns: int, label: str | None) -> None: + def record_user_span( + self, ts_ns: int, span_ns: int, label: str | None, msg_seq: int | None = None + ) -> None: self.user_span_ns_total += span_ns self._user_span.add(ts_ns, span_ns) if self.trace_enabled and self._trace_metric_enabled("user_span_ns"): @@ -237,11 +252,16 @@ def record_user_span(self, ts_ns: int, span_ns: int, label: str | None) -> None: metric="user_span_ns", value=float(span_ns), channel_kind=self.channel_kind_last, + sample_seq=msg_seq, ) ) def record_attributed_backpressure( - self, ts_ns: int, duration_ns: int, channel_kind: ProfileChannelType + self, + ts_ns: int, + duration_ns: int, + channel_kind: ProfileChannelType, + msg_seq: int | None = None, ) -> None: self.attributable_backpressure_ns_total += duration_ns self.attributable_backpressure_events_total += 1 @@ -256,6 +276,7 @@ def record_attributed_backpressure( metric="attributable_backpressure_ns", value=float(duration_ns), channel_kind=channel_kind, + sample_seq=msg_seq, ) ) @@ -324,17 +345,21 @@ def register_subscriber(self, sub_id: UUID, topic: str) -> None: def unregister_subscriber(self, sub_id: UUID) -> None: self._subscribers.pop(sub_id, None) - def publisher_publish(self, pub_id: UUID, ts_ns: int, inflight: int) -> None: + def publisher_publish( + self, pub_id: UUID, ts_ns: int, inflight: int, msg_seq: int | None = None + ) -> None: self._expire_trace_control_if_needed(ts_ns) metric = self._publishers.get(pub_id) if metric is not None: - metric.record_publish(ts_ns, inflight) + metric.record_publish(ts_ns, inflight, msg_seq) - def publisher_backpressure_wait(self, pub_id: UUID, ts_ns: int, wait_ns: int) -> None: + def publisher_backpressure_wait( + self, pub_id: UUID, ts_ns: int, wait_ns: int, msg_seq: int | None = None + ) -> None: self._expire_trace_control_if_needed(ts_ns) metric = self._publishers.get(pub_id) if metric is not None: - metric.record_backpressure_wait(ts_ns, wait_ns) + metric.record_backpressure_wait(ts_ns, wait_ns, msg_seq) def publisher_sample_inflight(self, pub_id: UUID, ts_ns: int, inflight: int) -> None: metric = self._publishers.get(pub_id) @@ -347,19 +372,25 @@ def subscriber_receive( ts_ns: int, lease_ns: int, channel_kind: ProfileChannelType, + msg_seq: int | None = None, ) -> None: self._expire_trace_control_if_needed(ts_ns) metric = self._subscribers.get(sub_id) if metric is not None: - metric.record_receive(ts_ns, lease_ns, channel_kind) + metric.record_receive(ts_ns, lease_ns, channel_kind, msg_seq) def subscriber_user_span( - self, sub_id: UUID, ts_ns: int, span_ns: int, label: str | None + self, + sub_id: UUID, + ts_ns: int, + span_ns: int, + label: str | None, + msg_seq: int | None = None, ) -> None: self._expire_trace_control_if_needed(ts_ns) metric = self._subscribers.get(sub_id) if metric is not None: - metric.record_user_span(ts_ns, span_ns, label) + metric.record_user_span(ts_ns, span_ns, label, msg_seq) def subscriber_attributed_backpressure( self, @@ -367,11 +398,14 @@ def subscriber_attributed_backpressure( ts_ns: int, duration_ns: int, channel_kind: ProfileChannelType, + msg_seq: int | None = None, ) -> None: self._expire_trace_control_if_needed(ts_ns) metric = self._subscribers.get(sub_id) if metric is not None: - metric.record_attributed_backpressure(ts_ns, duration_ns, channel_kind) + metric.record_attributed_backpressure( + ts_ns, duration_ns, channel_kind, msg_seq + ) def snapshot(self) -> ProcessProfilingSnapshot: return ProcessProfilingSnapshot( diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index abc97641..cf5f2cee 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -447,7 +447,10 @@ async def broadcast(self, obj: Any) -> None: await self._backpressure.wait(buf_idx) wait_end_ns = PROFILE_TIME() PROFILES.publisher_backpressure_wait( - self.id, wait_end_ns, wait_end_ns - wait_start_ns + self.id, + wait_end_ns, + wait_end_ns - wait_start_ns, + msg_seq=self._msg_id, ) # Get local channel and put variable there for local tx @@ -526,5 +529,10 @@ async def broadcast(self, obj: Any) -> None: ) now_ns = PROFILE_TIME() - PROFILES.publisher_publish(self.id, now_ns, self._backpressure.pressure) + PROFILES.publisher_publish( + self.id, + now_ns, + self._backpressure.pressure, + msg_seq=self._msg_id, + ) self._msg_id += 1 diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index 0f650421..4b9f8719 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -130,6 +130,7 @@ def __init__( self._graph_address = graph_address self._channels = dict() + self._active_msg_seq: int | None = None if self.leaky: self._incoming = LeakyQueue( 1 if max_queue is None else max_queue, self._handle_dropped_notification @@ -301,17 +302,27 @@ async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]: channel = self._channels[pub_id] channel_kind = getattr(channel, "channel_kind", ProfileChannelType.UNKNOWN) - start_ns = PROFILE_TIME() - with channel.get(msg_id, self.id) as msg: - yield msg - end_ns = PROFILE_TIME() - PROFILES.subscriber_receive( - self.id, end_ns, end_ns - start_ns, channel_kind - ) + self._active_msg_seq = msg_id + try: + start_ns = PROFILE_TIME() + with channel.get(msg_id, self.id) as msg: + yield msg + end_ns = PROFILE_TIME() + PROFILES.subscriber_receive( + self.id, end_ns, end_ns - start_ns, channel_kind, msg_seq=msg_id + ) + finally: + self._active_msg_seq = None def begin_profile(self) -> int: return PROFILE_TIME() def end_profile(self, start_ns: int, label: str | None = None) -> None: end_ns = PROFILE_TIME() - PROFILES.subscriber_user_span(self.id, end_ns, end_ns - start_ns, label) + PROFILES.subscriber_user_span( + self.id, + end_ns, + end_ns - start_ns, + label, + msg_seq=self._active_msg_seq, + ) From 690e12ea50ec0ed5ce7e6bd13afe823e366738d8 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Fri, 20 Mar 2026 11:07:46 -0400 Subject: [PATCH 29/52] Improve high-rate profiling trace throughput and fairness --- examples/profiling_tui.py | 4 +- src/ezmsg/core/processclient.py | 4 +- src/ezmsg/core/profiling.py | 52 ++++++++++++--- tests/test_profiling_api.py | 109 ++++++++++++++++++++++++++++++++ 4 files changed, 155 insertions(+), 14 deletions(-) diff --git a/examples/profiling_tui.py b/examples/profiling_tui.py index 17582968..3f49556f 100644 --- a/examples/profiling_tui.py +++ b/examples/profiling_tui.py @@ -404,13 +404,13 @@ def _build_parser() -> argparse.ArgumentParser: parser.add_argument( "--trace-interval", type=float, - default=0.05, + default=0.02, help="Seconds between GraphServer trace stream batches", ) parser.add_argument( "--max-samples", type=int, - default=512, + default=5000, help="Max samples per process per streamed batch", ) parser.add_argument( diff --git a/src/ezmsg/core/processclient.py b/src/ezmsg/core/processclient.py index e84f1e46..97c12cec 100644 --- a/src/ezmsg/core/processclient.py +++ b/src/ezmsg/core/processclient.py @@ -66,10 +66,10 @@ def __init__(self, graph_address: AddressType | None = None) -> None: self._owned_units = set() self._trace_push_task = None self._trace_push_interval_s = float( - os.environ.get("EZMSG_PROFILE_TRACE_PUSH_INTERVAL_S", "0.05") + os.environ.get("EZMSG_PROFILE_TRACE_PUSH_INTERVAL_S", "0.02") ) self._trace_push_max_samples = int( - os.environ.get("EZMSG_PROFILE_TRACE_PUSH_MAX_SAMPLES", "1000") + os.environ.get("EZMSG_PROFILE_TRACE_PUSH_MAX_SAMPLES", "5000") ) def _require_client_id(self) -> UUID: diff --git a/src/ezmsg/core/profiling.py b/src/ezmsg/core/profiling.py index 77530242..e38ce891 100644 --- a/src/ezmsg/core/profiling.py +++ b/src/ezmsg/core/profiling.py @@ -1,6 +1,7 @@ import os import socket import time +import heapq from collections import deque from dataclasses import dataclass, field from typing import Callable, TypeAlias @@ -425,6 +426,9 @@ def snapshot(self) -> ProcessProfilingSnapshot: ) def set_trace_control(self, control: ProfilingTraceControl) -> None: + # Changing filters/mode should start from a clean trace buffer so new + # consumers do not receive stale samples from an old control scope. + self._clear_trace_samples() self._default_trace_control = control if control.enabled and control.ttl_seconds is not None: self._trace_control_expires_ns = PROFILE_TIME() + max( @@ -442,17 +446,39 @@ def set_trace_control(self, control: ProfilingTraceControl) -> None: def trace_batch(self, max_samples: int = 1000) -> ProcessProfilingTraceBatch: self._expire_trace_control_if_needed() samples: list[ProfilingTraceSample] = [] + limit = max(1, int(max_samples)) + + queues: list[deque[ProfilingTraceSample]] = [] for metric in self._publishers.values(): - while metric.trace_samples and len(samples) < max_samples: - samples.append(metric.trace_samples.popleft()) - if len(samples) >= max_samples: - break - if len(samples) < max_samples: - for metric in self._subscribers.values(): - while metric.trace_samples and len(samples) < max_samples: - samples.append(metric.trace_samples.popleft()) - if len(samples) >= max_samples: - break + if metric.trace_samples: + queues.append(metric.trace_samples) + for metric in self._subscribers.values(): + if metric.trace_samples: + queues.append(metric.trace_samples) + + if len(queues) == 1: + queue = queues[0] + while queue and len(samples) < limit: + samples.append(queue.popleft()) + elif len(queues) > 1: + heap: list[tuple[float, int, int]] = [] + for idx, queue in enumerate(queues): + sample = queue[0] + # Include sample_seq to keep deterministic ordering when timestamps tie. + seq = sample.sample_seq if sample.sample_seq is not None else -1 + heapq.heappush(heap, (sample.timestamp, seq, idx)) + + while heap and len(samples) < limit: + _timestamp, _seq, queue_idx = heapq.heappop(heap) + queue = queues[queue_idx] + if not queue: + continue + sample = queue.popleft() + samples.append(sample) + if queue: + nxt = queue[0] + nxt_seq = nxt.sample_seq if nxt.sample_seq is not None else -1 + heapq.heappush(heap, (nxt.timestamp, nxt_seq, queue_idx)) return ProcessProfilingTraceBatch( process_id=self._process_id, @@ -509,5 +535,11 @@ def _apply_trace_control_to_subscriber(self, metric: _SubscriberMetrics) -> None metric.trace_sample_mod = sample_mod metric.trace_metrics = trace_metrics + def _clear_trace_samples(self) -> None: + for metric in self._publishers.values(): + metric.trace_samples.clear() + for metric in self._subscribers.values(): + metric.trace_samples.clear() + PROFILES = ProfileRegistry() diff --git a/tests/test_profiling_api.py b/tests/test_profiling_api.py index fe00d732..2b07dd0a 100644 --- a/tests/test_profiling_api.py +++ b/tests/test_profiling_api.py @@ -428,3 +428,112 @@ async def test_process_profiling_trace_subscription_does_not_starve_peer_subscri await process.close() await ctx.__aexit__(None, None, None) graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_profiling_trace_batch_interleaves_publisher_and_subscriber_samples(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + await process.register(["SYS/U8"]) + + pub = await ctx.publisher("TOPIC_TRACE_MIX") + sub = await ctx.subscriber("TOPIC_TRACE_MIX") + + try: + response = await ctx.process_set_profiling_trace( + "SYS/U8", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_topics=["TOPIC_TRACE_MIX"], + subscriber_topics=["TOPIC_TRACE_MIX"], + metrics=["publish_delta_ns", "lease_time_ns"], + ), + timeout=1.0, + ) + assert response.ok + + for idx in range(64): + await pub.broadcast(idx) + async with sub.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + batch = await ctx.process_profiling_trace_batch( + "SYS/U8", max_samples=32, timeout=1.0 + ) + metrics = {sample.metric for sample in batch.samples} + assert "publish_delta_ns" in metrics + assert "lease_time_ns" in metrics + finally: + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_profiling_trace_control_change_clears_stale_trace_samples(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + await process.register(["SYS/U9"]) + + pub_old = await ctx.publisher("TOPIC_TRACE_OLD") + sub_old = await ctx.subscriber("TOPIC_TRACE_OLD") + pub_new = await ctx.publisher("TOPIC_TRACE_NEW") + sub_new = await ctx.subscriber("TOPIC_TRACE_NEW") + + try: + old_response = await ctx.process_set_profiling_trace( + "SYS/U9", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_topics=["TOPIC_TRACE_OLD"], + metrics=["publish_delta_ns"], + ), + timeout=1.0, + ) + assert old_response.ok + + for idx in range(12): + await pub_old.broadcast(idx) + async with sub_old.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + new_response = await ctx.process_set_profiling_trace( + "SYS/U9", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_topics=["TOPIC_TRACE_NEW"], + metrics=["publish_delta_ns"], + ), + timeout=1.0, + ) + assert new_response.ok + + for idx in range(8): + await pub_new.broadcast(idx) + async with sub_new.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + batch = await ctx.process_profiling_trace_batch( + "SYS/U9", max_samples=256, timeout=1.0 + ) + assert len(batch.samples) > 0 + assert all(sample.topic == "TOPIC_TRACE_NEW" for sample in batch.samples) + finally: + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() From 553a353d428676dcdfad249bf7443c69ae233b9b Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 23 Mar 2026 17:10:43 -0400 Subject: [PATCH 30/52] address PROCESS_SETTINGS_UPDATE leakage --- src/ezmsg/core/graphserver.py | 102 ++++++++++++++++++++++++++- tests/test_settings_api.py | 129 ++++++++++++++++++++++++++++++---- 2 files changed, 215 insertions(+), 16 deletions(-) diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index a1d4e76e..db76aefb 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -96,6 +96,7 @@ class GraphServer(threading.Thread): _command_lock: asyncio.Lock _settings_current: dict[str, SettingsSnapshotValue] _settings_source_session: dict[str, UUID | None] + _settings_source_process: dict[str, UUID | None] _settings_events: list[SettingsChangedEvent] _settings_event_seq: int _settings_owned_by_session: dict[UUID, set[str]] @@ -126,6 +127,7 @@ def __init__(self, **kwargs) -> None: self._address = None self._settings_current = {} self._settings_source_session = {} + self._settings_source_process = {} self._settings_events = [] self._settings_event_seq = 0 self._settings_owned_by_session = {} @@ -707,6 +709,7 @@ async def _handle_process( process_id=self._process_key(process_client_id), ) ) + self._remove_settings_for_process_locked(process_client_id) if process_info is not None: source_process_id = self._process_key(process_client_id) self._profiling_trace_buffers.pop(source_process_id, None) @@ -1126,15 +1129,46 @@ async def _handle_process_settings_update_request( process_info = self._process_info(process_client_id) if process_info is None: return Command.ERROR.value + if update.component_address not in process_info.units: + metadata_owner = self._session_owner_for_component_locked( + update.component_address + ) + known_owner = self._process_owner_for_unit(update.component_address) + allow_startup_race = ( + len(process_info.units) == 0 and metadata_owner is not None + ) + if known_owner == process_client_id or allow_startup_race: + pass + else: + logger.warning( + "Process control %s settings update rejected for unowned component: %s", + process_client_id, + update.component_address, + ) + return Command.ERROR.value + else: + metadata_owner = self._session_owner_for_component_locked( + update.component_address + ) + + if metadata_owner is None: + source_session_id = self._settings_source_session.get( + update.component_address + ) + else: + source_session_id = metadata_owner source_process_id = self._process_key(process_client_id) self._settings_current[update.component_address] = update.value - self._settings_source_session[update.component_address] = None + self._settings_source_session[update.component_address] = source_session_id + self._settings_source_process[update.component_address] = source_process_id self._append_settings_event_locked( event_type=SettingsEventType.SETTINGS_UPDATED, component_address=update.component_address, value=update.value, - source_session_id=None, + source_session_id=( + str(source_session_id) if source_session_id is not None else None + ), source_process_id=source_process_id, timestamp=update.timestamp, ) @@ -1386,6 +1420,69 @@ def _remove_settings_for_session_locked(self, session_id: UUID) -> None: if self._settings_source_session.get(component_address) == session_id: self._settings_current.pop(component_address, None) self._settings_source_session.pop(component_address, None) + self._settings_source_process.pop(component_address, None) + + def _session_owner_for_component_locked(self, component_address: str) -> UUID | None: + for client_id, info in self.clients.items(): + if not isinstance(info, SessionInfo): + continue + if info.metadata is None: + continue + if component_address in info.metadata.components: + return client_id + return None + + def _initial_settings_for_component_locked( + self, session_id: UUID, component_address: str + ) -> SettingsSnapshotValue | None: + session = self._session_info(session_id) + if session is None or session.metadata is None: + return None + component = session.metadata.components.get(component_address) + if component is None: + return None + initial_repr = component.initial_settings[1] + return SettingsSnapshotValue( + serialized=component.initial_settings[0], + repr_value=initial_repr, + structured_value=initial_repr if isinstance(initial_repr, dict) else None, + settings_schema=component.settings_schema, + ) + + def _remove_settings_for_process_locked(self, process_client_id: UUID) -> None: + source_process_id = self._process_key(process_client_id) + component_addresses = [ + component_address + for component_address, owner_process_id in self._settings_source_process.items() + if owner_process_id == source_process_id + ] + + for component_address in component_addresses: + source_session_id = self._settings_source_session.get(component_address) + if source_session_id is None: + self._settings_current.pop(component_address, None) + self._settings_source_session.pop(component_address, None) + self._settings_source_process.pop(component_address, None) + continue + + restored = self._initial_settings_for_component_locked( + source_session_id, component_address + ) + if restored is None: + self._settings_current.pop(component_address, None) + self._settings_source_session.pop(component_address, None) + self._settings_source_process.pop(component_address, None) + continue + + self._settings_current[component_address] = restored + self._settings_source_process[component_address] = None + self._append_settings_event_locked( + event_type=SettingsEventType.SETTINGS_UPDATED, + component_address=component_address, + value=restored, + source_session_id=str(source_session_id), + source_process_id=None, + ) def _apply_session_metadata_settings_locked( self, session_id: UUID, metadata: GraphMetadata @@ -1401,6 +1498,7 @@ def _apply_session_metadata_settings_locked( ) self._settings_current[component.address] = value self._settings_source_session[component.address] = session_id + self._settings_source_process[component.address] = None session_components.add(component.address) self._append_settings_event_locked( event_type=SettingsEventType.INITIAL_SETTINGS, diff --git a/tests/test_settings_api.py b/tests/test_settings_api.py index 55f4b8f6..effbe08a 100644 --- a/tests/test_settings_api.py +++ b/tests/test_settings_api.py @@ -128,26 +128,27 @@ async def observe() -> None: try: settings = await observer.settings_snapshot() sink_address = "SYS/SINK" - assert sink_address in settings - assert settings[sink_address].repr_value == {"gain": 7} - assert settings[sink_address].structured_value == {"gain": 7} - assert settings[sink_address].settings_schema is not None - schema = settings[sink_address].settings_schema - assert schema is not None - assert schema.provider == "dataclass" - assert any( - field.name == "gain" and "int" in field.field_type.lower() - for field in schema.fields - ) + # Process-owned settings are cleaned up when the process exits. + assert sink_address not in settings events = await observer.settings_events(after_seq=0) matching = [ event for event in events if event.component_address == sink_address - and event.event_type == SettingsEventType.SETTINGS_UPDATED - ] + and event.event_type == SettingsEventType.SETTINGS_UPDATED + and event.value.repr_value == {"gain": 7} + ] assert matching + latest = matching[-1].value + assert latest.structured_value == {"gain": 7} + assert latest.settings_schema is not None + schema = latest.settings_schema + assert schema.provider == "dataclass" + assert any( + field.name == "gain" and "int" in field.field_type.lower() + for field in schema.fields + ) finally: await observer.__aexit__(None, None, None) @@ -187,7 +188,17 @@ async def test_graphcontext_update_settings_via_input_settings_topic(): await asyncio.wait_for(run_task, timeout=5.0) settings = await observer.settings_snapshot() - assert settings["SYS/SINK"].repr_value == {"gain": 11} + assert "SYS/SINK" not in settings + + events = await observer.settings_events(after_seq=0) + matching = [ + event + for event in events + if event.component_address == "SYS/SINK" + and event.event_type == SettingsEventType.SETTINGS_UPDATED + and event.value.repr_value == {"gain": 11} + ] + assert matching finally: if not run_task.done(): @@ -365,3 +376,93 @@ async def test_metadata_registration_rejects_component_address_collision(): await owner_a.__aexit__(None, None, None) await owner_b.__aexit__(None, None, None) graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_owned_settings_removed_when_process_disconnects_without_session_owner(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + await process.register(["SYS/UNIT_ORPHAN"]) + + try: + await process.report_settings_update( + component_address="SYS/UNIT_ORPHAN", + value=SettingsSnapshotValue(serialized=None, repr_value={"gain": 5}), + ) + settings = await observer.settings_snapshot() + assert settings["SYS/UNIT_ORPHAN"].repr_value == {"gain": 5} + finally: + await process.close() + + await asyncio.sleep(0.05) + settings = await observer.settings_snapshot() + assert "SYS/UNIT_ORPHAN" not in settings + + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_disconnect_restores_metadata_initial_settings(): + graph_server = GraphService().create_server() + address = graph_server.address + + owner = GraphContext(address, auto_start=False) + observer = GraphContext(address, auto_start=False) + await owner.__aenter__() + await observer.__aenter__() + await owner.register_metadata(_metadata_with_component("SYS/UNIT_RESTORE")) + + process = ProcessControlClient(address) + await process.connect() + await process.register(["SYS/UNIT_RESTORE"]) + + try: + await process.report_settings_update( + component_address="SYS/UNIT_RESTORE", + value=SettingsSnapshotValue(serialized=None, repr_value={"alpha": 9}), + ) + settings = await observer.settings_snapshot() + assert settings["SYS/UNIT_RESTORE"].repr_value == {"alpha": 9} + finally: + await process.close() + + await asyncio.sleep(0.05) + settings = await observer.settings_snapshot() + assert settings["SYS/UNIT_RESTORE"].repr_value == {"alpha": 1} + + await owner.__aexit__(None, None, None) + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_settings_update_rejected_for_unowned_component(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + await process.register(["SYS/UNIT_OWNED"]) + + try: + with pytest.raises(RuntimeError, match="Process control command failed"): + await process.report_settings_update( + component_address="SYS/UNIT_UNOWNED", + value=SettingsSnapshotValue(serialized=None, repr_value={"gain": 7}), + ) + settings = await observer.settings_snapshot() + assert "SYS/UNIT_UNOWNED" not in settings + finally: + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop() From 7efabd604e7aebb7e2085b51de358d0b1c166519 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 24 Mar 2026 12:46:13 -0400 Subject: [PATCH 31/52] profiling age-out on snapshot --- src/ezmsg/core/profiling.py | 57 +++++++++++++++++++++++++------------ tests/test_profiling_api.py | 40 ++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 18 deletions(-) diff --git a/src/ezmsg/core/profiling.py b/src/ezmsg/core/profiling.py index e38ce891..2bce4f8c 100644 --- a/src/ezmsg/core/profiling.py +++ b/src/ezmsg/core/profiling.py @@ -39,7 +39,7 @@ class _Rolling: max_value: list[int] = field(default_factory=list) _num_buckets: int = 0 _bucket_ns: int = 0 - _last_bucket: int | None = None + _last_bucket_tick: int | None = None def __post_init__(self) -> None: self._num_buckets = max(1, int(self.window_seconds / self.bucket_seconds)) @@ -48,28 +48,39 @@ def __post_init__(self) -> None: self.value_sum = [0 for _ in range(self._num_buckets)] self.max_value = [0 for _ in range(self._num_buckets)] + def _bucket_tick(self, ts_ns: int) -> int: + return ts_ns // self._bucket_ns + def _bucket(self, ts_ns: int) -> int: - return (ts_ns // self._bucket_ns) % self._num_buckets + return self._bucket_tick(ts_ns) % self._num_buckets + + def _reset_bucket(self, idx: int) -> None: + self.count[idx] = 0 + self.value_sum[idx] = 0 + self.max_value[idx] = 0 def _advance(self, ts_ns: int) -> int: - bucket = self._bucket(ts_ns) - if self._last_bucket is None: - self._last_bucket = bucket + bucket_tick = self._bucket_tick(ts_ns) + bucket = bucket_tick % self._num_buckets + if self._last_bucket_tick is None: + self._last_bucket_tick = bucket_tick return bucket - if bucket == self._last_bucket: + if bucket_tick <= self._last_bucket_tick: return bucket - idx = (self._last_bucket + 1) % self._num_buckets - while idx != bucket: - self.count[idx] = 0 - self.value_sum[idx] = 0 - self.max_value[idx] = 0 - idx = (idx + 1) % self._num_buckets - self.count[bucket] = 0 - self.value_sum[bucket] = 0 - self.max_value[bucket] = 0 - self._last_bucket = bucket + elapsed_buckets = bucket_tick - self._last_bucket_tick + if elapsed_buckets >= self._num_buckets: + for idx in range(self._num_buckets): + self._reset_bucket(idx) + else: + previous_bucket = self._last_bucket_tick % self._num_buckets + for step in range(1, elapsed_buckets + 1): + self._reset_bucket((previous_bucket + step) % self._num_buckets) + self._last_bucket_tick = bucket_tick return bucket + def advance_to(self, ts_ns: int) -> None: + self._advance(ts_ns) + def add(self, ts_ns: int, value: int) -> None: idx = self._advance(ts_ns) self.count[idx] += 1 @@ -167,6 +178,11 @@ def sample_inflight(self, ts_ns: int, inflight: int) -> None: self._inflight.add(ts_ns, inflight) def snapshot(self) -> PublisherProfileSnapshot: + now_ns = PROFILE_TIME() + self._publish_delta.advance_to(now_ns) + self._publish_count.advance_to(now_ns) + self._backpressure_wait.advance_to(now_ns) + self._inflight.advance_to(now_ns) window_msgs = self._publish_count.count_total() return PublisherProfileSnapshot( endpoint_id=self.endpoint_id, @@ -180,7 +196,7 @@ def snapshot(self) -> PublisherProfileSnapshot: inflight_messages_peak_window=self._inflight.max_total(), backpressure_wait_ns_total=self.backpressure_wait_ns_total, backpressure_wait_ns_window=self._backpressure_wait.sum_total(), - timestamp=float(PROFILE_TIME()), + timestamp=float(now_ns), ) @@ -282,6 +298,11 @@ def record_attributed_backpressure( ) def snapshot(self) -> SubscriberProfileSnapshot: + now_ns = PROFILE_TIME() + self._recv_count.advance_to(now_ns) + self._lease_time.advance_to(now_ns) + self._user_span.advance_to(now_ns) + self._attrib_bp.advance_to(now_ns) recv_count = self._recv_count.count_total() user_count = self._user_span.count_total() return SubscriberProfileSnapshot( @@ -301,7 +322,7 @@ def snapshot(self) -> SubscriberProfileSnapshot: attributable_backpressure_ns_window=self._attrib_bp.sum_total(), attributable_backpressure_events_total=self.attributable_backpressure_events_total, channel_kind_last=self.channel_kind_last, - timestamp=float(PROFILE_TIME()), + timestamp=float(now_ns), ) diff --git a/tests/test_profiling_api.py b/tests/test_profiling_api.py index 2b07dd0a..317faa8b 100644 --- a/tests/test_profiling_api.py +++ b/tests/test_profiling_api.py @@ -2,6 +2,7 @@ import pytest +from ezmsg.core import profiling as profiling_core from ezmsg.core.graphcontext import GraphContext from ezmsg.core.graphmeta import ( ProcessControlErrorCode, @@ -12,6 +13,45 @@ from ezmsg.core.processclient import ProcessControlClient +def test_profiling_windows_age_out_during_idle_snapshots(monkeypatch: pytest.MonkeyPatch): + publisher = profiling_core._PublisherMetrics( + topic="TOPIC_IDLE", + endpoint_id="TOPIC_IDLE:ep1", + num_buffers=4, + ) + subscriber = profiling_core._SubscriberMetrics( + topic="TOPIC_IDLE", + endpoint_id="TOPIC_IDLE:sub1", + ) + + publisher.record_publish(0, inflight=0) + publisher.record_publish(int(0.1e9), inflight=0) + subscriber.record_receive( + int(0.1e9), + lease_ns=int(0.2e6), + channel_kind=profiling_core.ProfileChannelType.LOCAL, + ) + + now_ns = {"value": int(0.2e9)} + monkeypatch.setattr(profiling_core, "PROFILE_TIME", lambda: now_ns["value"]) + + active_pub = publisher.snapshot() + active_sub = subscriber.snapshot() + assert active_pub.messages_published_window == 2 + assert active_pub.publish_rate_hz_window > 0.0 + assert active_sub.messages_received_window == 1 + + now_ns["value"] = int(30e9) + idle_pub = publisher.snapshot() + idle_sub = subscriber.snapshot() + assert idle_pub.messages_published_window == 0 + assert idle_pub.publish_rate_hz_window == 0.0 + assert idle_pub.backpressure_wait_ns_window == 0 + assert idle_sub.messages_received_window == 0 + assert idle_sub.attributable_backpressure_ns_window == 0 + assert idle_sub.lease_time_ns_avg_window == 0.0 + + @pytest.mark.asyncio async def test_process_profiling_snapshot_collects_pub_sub_metrics(): graph_server = GraphService().create_server() From 32924aa8c4e544e6fa4f45bd442eeddd8be4ce51 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 30 Mar 2026 11:28:23 -0400 Subject: [PATCH 32/52] better global topic output --- examples/ezmsg_toy.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/ezmsg_toy.py b/examples/ezmsg_toy.py index 69f7141c..154d401c 100644 --- a/examples/ezmsg_toy.py +++ b/examples/ezmsg_toy.py @@ -178,6 +178,8 @@ class TestSystemSettings(ez.Settings): class TestSystem(ez.Collection): SETTINGS = TestSystemSettings + OUTPUT_PING = ez.OutputTopic(str) + # Publishers PING = MessageGenerator() FOO = MessageGenerator() @@ -199,6 +201,7 @@ def configure(self) -> None: # Define Connections def network(self) -> ez.NetworkDefinition: return ( + (self.PING.OUTPUT, self.OUTPUT_PING), (self.PING.OUTPUT, self.PINGSUB1.INPUT), (self.PING.OUTPUT, self.MODIFIER_COLLECTION.INPUT), (self.MODIFIER_COLLECTION.OUTPUT, self.PINGSUB2.INPUT), @@ -219,7 +222,7 @@ def process_components(self): ez.run( SYSTEM=system, connections=[ - # Make PING.OUTPUT available on a topic ezmsg_attach.py - (system.PING.OUTPUT, "GLOBAL_PING_TOPIC"), + # Make a system output available on a topic ezmsg_attach.py + (system.OUTPUT_PING, "GLOBAL_PING_TOPIC"), ], ) From 1b84fed151e8726eef27d32376d4710f2e861e78 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 1 Apr 2026 12:01:10 -0400 Subject: [PATCH 33/52] quicker/better hotpath and a/b testing --- scripts/perf_ab.py | 5 + src/ezmsg/core/backendprocess.py | 1 + src/ezmsg/core/pubclient.py | 95 +++++-- src/ezmsg/core/stream.py | 20 +- src/ezmsg/util/perf/ab.py | 446 +++++++++++++++++++++++++++++++ src/ezmsg/util/perf/command.py | 4 + src/ezmsg/util/perf/hotpath.py | 437 ++++++++++++++++++++++++++++++ tests/test_perf_ab.py | 62 +++++ tests/test_perf_hotpath.py | 45 ++++ tests/test_pubclient.py | 203 ++++++++++++++ 10 files changed, 1294 insertions(+), 24 deletions(-) create mode 100644 scripts/perf_ab.py create mode 100644 src/ezmsg/util/perf/ab.py create mode 100644 src/ezmsg/util/perf/hotpath.py create mode 100644 tests/test_perf_ab.py create mode 100644 tests/test_perf_hotpath.py create mode 100644 tests/test_pubclient.py diff --git a/scripts/perf_ab.py b/scripts/perf_ab.py new file mode 100644 index 00000000..5e7a8b8d --- /dev/null +++ b/scripts/perf_ab.py @@ -0,0 +1,5 @@ +from ezmsg.util.perf.ab import main + + +if __name__ == "__main__": + main() diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 8ef22a72..f3ab935d 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -299,6 +299,7 @@ async def setup_state(): buf_size=stream.buf_size, start_paused=True, force_tcp=stream.force_tcp, + allow_local=stream.allow_local, ), loop=loop, ).result() diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index f9c42952..86cf320e 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -37,6 +37,36 @@ BACKPRESSURE_WARNING = "EZMSG_DISABLE_BACKPRESSURE_WARNING" not in os.environ BACKPRESSURE_REFRACTORY = 5.0 # sec +ALLOW_LOCAL_ENV = "EZMSG_ALLOW_LOCAL" +FORCE_TCP_ENV = "EZMSG_FORCE_TCP" + + +def _process_allow_local_default() -> bool: + value = os.environ.get(ALLOW_LOCAL_ENV, "") + if value == "": + return True + return value.lower() in ("1", "true", "yes", "on") + + +def _process_force_tcp_default() -> bool: + value = os.environ.get(FORCE_TCP_ENV, "") + if value == "": + return False + return value.lower() in ("1", "true", "yes", "on") + + +def _resolve_force_tcp(force_tcp: bool | None) -> bool: + if force_tcp is None: + return _process_force_tcp_default() + return force_tcp + + +def _resolve_allow_local(force_tcp: bool, allow_local: bool | None) -> bool: + resolved = _process_allow_local_default() if allow_local is None else allow_local + if force_tcp and resolved: + logger.info("force_tcp=True disables local delivery for this publisher") + return False + return resolved # Publisher needs a bit more information about connected channels @@ -75,6 +105,7 @@ class Publisher: _msg_id: int _shm: SHMContext _force_tcp: bool + _allow_local: bool _last_backpressure_event: float _graph_address: AddressType | None @@ -99,7 +130,8 @@ async def create( buf_size: int = DEFAULT_SHM_SIZE, num_buffers: int = 32, start_paused: bool = False, - force_tcp: bool = False, + force_tcp: bool | None = None, + allow_local: bool | None = None, ) -> "Publisher": """ Create a new Publisher instance and register it with the graph server. @@ -116,6 +148,16 @@ async def create( :type port: int | None :param buf_size: Size of shared memory buffers. :type buf_size: int + :param force_tcp: Whether to force TCP transport instead of shared memory. + If None, inherit the process default from ``EZMSG_FORCE_TCP`` which + defaults to disabled. + :type force_tcp: bool | None + :param allow_local: Whether to allow the in-process fast path when available. + If None, inherit the process default from ``EZMSG_ALLOW_LOCAL`` which + defaults to enabled. Set to False to bypass local delivery and + characterize same-process SHM or TCP. When ``force_tcp=True``, local + delivery is disabled regardless of this value. + :type allow_local: bool | None :param kwargs: Additional keyword arguments for Publisher constructor. :return: Initialized and registered Publisher instance. :rtype: Publisher @@ -127,6 +169,8 @@ async def create( writer.write(Command.PUBLISH.value) writer.write(encode_str(topic)) + resolved_force_tcp = _resolve_force_tcp(force_tcp) + pub_id = UUID(await read_str(reader)) pub = cls( id=pub_id, @@ -135,7 +179,8 @@ async def create( graph_address=graph_address, num_buffers=num_buffers, start_paused=start_paused, - force_tcp=force_tcp, + force_tcp=resolved_force_tcp, + allow_local=allow_local, _guard=cls._SENTINEL, ) @@ -189,7 +234,8 @@ def __init__( graph_address: AddressType | None = None, num_buffers: int = 32, start_paused: bool = False, - force_tcp: bool = False, + force_tcp: bool | None = None, + allow_local: bool | None = None, _guard = None ) -> None: """ @@ -207,7 +253,12 @@ def __init__( :param start_paused: Whether to start in paused state. :type start_paused: bool :param force_tcp: Whether to force TCP transport instead of shared memory. - :type force_tcp: bool + If None, inherit the process default from ``EZMSG_FORCE_TCP``. + :type force_tcp: bool | None + :param allow_local: Whether to allow the direct in-process fast path when available. + If None, inherit the process default from ``EZMSG_ALLOW_LOCAL``. + When ``force_tcp=True``, local delivery is disabled regardless of this value. + :type allow_local: bool | None """ if _guard is not self._SENTINEL: raise TypeError( @@ -227,7 +278,8 @@ def __init__( self._running.set() self._num_buffers = num_buffers self._backpressure = Backpressure(num_buffers) - self._force_tcp = force_tcp + self._force_tcp = _resolve_force_tcp(force_tcp) + self._allow_local = _resolve_allow_local(self._force_tcp, allow_local) self._last_backpressure_event = -1 self._graph_address = graph_address @@ -436,12 +488,10 @@ async def broadcast(self, obj: Any) -> None: self._last_backpressure_event = time.time() await self._backpressure.wait(buf_idx) - # Get local channel and put variable there for local tx - self._local_channel.put_local(self._msg_id, obj) + if self._should_use_local_fast_path(): + self._local_channel.put_local(self._msg_id, obj) - if self._force_tcp or any( - ch.pid != self.pid or not ch.shm_ok for ch in self._channels.values() - ): + if any(not self._can_deliver_locally(ch) for ch in self._channels.values()): with MessageMarshal.serialize(self._msg_id, obj) as ( total_size, header, @@ -449,9 +499,7 @@ async def broadcast(self, obj: Any) -> None: ): total_size_bytes = uint64_to_bytes(total_size) - if not self._force_tcp and any( - ch.pid != self.pid and ch.shm_ok for ch in self._channels.values() - ): + if any(self._can_deliver_via_shm(ch) for ch in self._channels.values()): if self._shm.buf_size < total_size: new_shm = await GraphService(self._graph_address).create_shm( self._num_buffers, total_size * 2 @@ -475,14 +523,10 @@ async def broadcast(self, obj: Any) -> None: for channel in self._channels.values(): msg: bytes = b"" - if self.pid == channel.pid and channel.shm_ok: + if self._can_deliver_locally(channel): continue # Local transmission handled by channel.put - elif ( - (not self._force_tcp) - and self.pid != channel.pid - and channel.shm_ok - ): + elif self._can_deliver_via_shm(channel): msg = ( Command.TX_SHM.value + msg_id_bytes @@ -509,3 +553,16 @@ async def broadcast(self, obj: Any) -> None: ) self._msg_id += 1 + + def _should_use_local_fast_path(self) -> bool: + return any(self._can_deliver_locally(ch) for ch in self._channels.values()) + + def _can_deliver_locally(self, channel: PubChannelInfo) -> bool: + return self._allow_local and self.pid == channel.pid and channel.shm_ok + + def _can_deliver_via_shm(self, channel: PubChannelInfo) -> bool: + return ( + (not self._force_tcp) + and channel.shm_ok + and not self._can_deliver_locally(channel) + ) diff --git a/src/ezmsg/core/stream.py b/src/ezmsg/core/stream.py index c719c92a..79089d86 100644 --- a/src/ezmsg/core/stream.py +++ b/src/ezmsg/core/stream.py @@ -124,15 +124,20 @@ class OutputStream(Stream): :type num_buffers: int :param buf_size: Size of each message buffer in bytes :type buf_size: int - :param force_tcp: Whether to force TCP transport instead of shared memory - :type force_tcp: bool + :param force_tcp: Whether to force TCP transport instead of shared memory. + If None, inherit the process default from ``EZMSG_FORCE_TCP``. + :type force_tcp: bool | None + :param allow_local: Whether to allow the in-process fast path when available. + If None, inherit the process default from ``EZMSG_ALLOW_LOCAL``. + :type allow_local: bool | None """ host: str | None port: int | None num_buffers: int buf_size: int - force_tcp: bool + force_tcp: bool | None + allow_local: bool | None def __init__( self, @@ -141,7 +146,8 @@ def __init__( port: int | None = None, num_buffers: int = 32, buf_size: int = DEFAULT_SHM_SIZE, - force_tcp: bool = False, + force_tcp: bool | None = None, + allow_local: bool | None = None, ) -> None: super().__init__(msg_type) self.host = host @@ -149,7 +155,11 @@ def __init__( self.num_buffers = num_buffers self.buf_size = buf_size self.force_tcp = force_tcp + self.allow_local = allow_local def __repr__(self) -> str: preamble = f"Output{super().__repr__()}" - return f"{preamble}({self.num_buffers=}, {self.force_tcp=})" + return ( + f"{preamble}({self.num_buffers=}, {self.force_tcp=}, " + f"{self.allow_local=})" + ) diff --git a/src/ezmsg/util/perf/ab.py b/src/ezmsg/util/perf/ab.py new file mode 100644 index 00000000..5a2da621 --- /dev/null +++ b/src/ezmsg/util/perf/ab.py @@ -0,0 +1,446 @@ +from __future__ import annotations + +import argparse +import contextlib +import json +import os +import random +import shutil +import subprocess +import sys +import tempfile + +from dataclasses import asdict, dataclass +from pathlib import Path + + +DEFAULT_PAIR_SEED = 0 + + +@dataclass(frozen=True) +class ABCaseSummary: + case_id: str + a_us_per_message_median: float + b_us_per_message_median: float + delta_pct_median: float + delta_pct_mean: float + pair_count: int + b_faster_pairs: int + + +@dataclass(frozen=True) +class ABRunSummary: + ref_a: str + ref_b: str + rounds: int + seed: int + cases: list[ABCaseSummary] + + +def build_pair_order(rounds: int, seed: int) -> list[tuple[str, str]]: + base = [("A", "B"), ("B", "A")] * ((rounds + 1) // 2) + order = base[:rounds] + random.Random(seed).shuffle(order) + return order + + +def _hotpath_json_arg(path: Path) -> list[str]: + return ["--json-out", str(path)] + + +def build_hotpath_command( + output_path: Path, + count: int, + warmup: int, + payload_sizes: list[int], + transports: list[str], + apis: list[str], + num_buffers: int, + quiet: bool, +) -> list[str]: + cmd = [ + "uv", + "run", + "python", + "-m", + "ezmsg.util.perf.hotpath", + "--count", + str(count), + "--warmup", + str(warmup), + "--samples", + "1", + "--num-buffers", + str(num_buffers), + "--payload-sizes", + *[str(payload_size) for payload_size in payload_sizes], + "--transports", + *transports, + "--apis", + *apis, + *_hotpath_json_arg(output_path), + ] + if quiet: + cmd.append("--quiet") + return cmd + + +def load_hotpath_summary(path: Path) -> dict[str, float]: + payload = json.loads(path.read_text()) + return { + entry["case_id"]: float(entry["summary"]["us_per_message_median"]) + for entry in payload["results"] + } + + +def summarize_ab_results( + ref_a: str, + ref_b: str, + rounds: int, + seed: int, + paired_runs: list[tuple[dict[str, float], dict[str, float]]], +) -> ABRunSummary: + case_ids = sorted(paired_runs[0][0].keys()) + cases: list[ABCaseSummary] = [] + + for case_id in case_ids: + a_values = [pair[0][case_id] for pair in paired_runs] + b_values = [pair[1][case_id] for pair in paired_runs] + deltas = [((b / a) - 1.0) * 100.0 for a, b in zip(a_values, b_values)] + cases.append( + ABCaseSummary( + case_id=case_id, + a_us_per_message_median=_median(a_values), + b_us_per_message_median=_median(b_values), + delta_pct_median=_median(deltas), + delta_pct_mean=sum(deltas) / len(deltas), + pair_count=len(deltas), + b_faster_pairs=sum(1 for a, b in zip(a_values, b_values) if b < a), + ) + ) + + return ABRunSummary( + ref_a=ref_a, + ref_b=ref_b, + rounds=rounds, + seed=seed, + cases=cases, + ) + + +def _median(values: list[float]) -> float: + ordered = sorted(values) + mid = len(ordered) // 2 + if len(ordered) % 2: + return ordered[mid] + return (ordered[mid - 1] + ordered[mid]) / 2.0 + + +def _run_checked(cmd: list[str], cwd: Path) -> None: + env = os.environ.copy() + env.pop("VIRTUAL_ENV", None) + completed = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, env=env) + if completed.returncode == 0: + return + + raise RuntimeError( + f"Command failed in {cwd}:\n" + f"$ {' '.join(cmd)}\n\n" + f"stdout:\n{completed.stdout}\n" + f"stderr:\n{completed.stderr}" + ) + + +def _is_current_ref(ref: str) -> bool: + return ref.upper() == "CURRENT" + + +@contextlib.contextmanager +def _provision_tree( + repo_root: Path, + ref: str, + label: str, + keep: bool, +) -> Path: + if _is_current_ref(ref): + yield repo_root + return + + parent = Path(tempfile.mkdtemp(prefix=f"ezmsg-perf-{label.lower()}-")) + tree_path = parent / "tree" + _run_checked( + ["git", "worktree", "add", "--detach", str(tree_path), ref], + cwd=repo_root, + ) + + try: + yield tree_path + finally: + if keep: + return + try: + _run_checked(["git", "worktree", "remove", "--force", str(tree_path)], cwd=repo_root) + finally: + shutil.rmtree(parent, ignore_errors=True) + + +def _maybe_sync(tree: Path) -> None: + _run_checked(["uv", "sync", "--group", "dev"], cwd=tree) + + +def _mirror_hotpath_module(source_root: Path, target_tree: Path) -> None: + source = source_root / "src" / "ezmsg" / "util" / "perf" / "hotpath.py" + target = target_tree / "src" / "ezmsg" / "util" / "perf" / "hotpath.py" + shutil.copy2(source, target) + + +def _ensure_json_files_match( + left: dict[str, float], + right: dict[str, float], + label_left: str, + label_right: str, +) -> None: + if left.keys() == right.keys(): + return + + raise RuntimeError( + f"Benchmark cases differ between {label_left} and {label_right}: " + f"{sorted(left.keys())} != {sorted(right.keys())}" + ) + + +def _print_summary(summary: ABRunSummary) -> None: + print( + f"Interleaved hot-path comparison: A={summary.ref_a}, " + f"B={summary.ref_b}, rounds={summary.rounds}, seed={summary.seed}" + ) + for case in summary.cases: + sign = "regression" if case.delta_pct_median > 0 else "improvement" + print( + f"{case.case_id:<36} " + f"A={case.a_us_per_message_median:>10.2f} us/msg " + f"B={case.b_us_per_message_median:>10.2f} us/msg " + f"delta={case.delta_pct_median:>7.2f}% ({sign}) " + f"wins={case.b_faster_pairs}/{case.pair_count}" + ) + + +def dump_ab_json(summary: ABRunSummary, path: Path) -> None: + payload = { + "suite": "hotpath-ab", + "ref_a": summary.ref_a, + "ref_b": summary.ref_b, + "rounds": summary.rounds, + "seed": summary.seed, + "cases": [asdict(case) for case in summary.cases], + } + path.write_text(json.dumps(payload, indent=2) + "\n") + + +def perf_ab( + ref_a: str, + ref_b: str, + rounds: int, + count: int, + warmup: int, + prewarm: int, + payload_sizes: list[int], + transports: list[str], + apis: list[str], + num_buffers: int, + seed: int, + json_out: Path | None, + keep_worktrees: bool, + sync: bool, + quiet: bool, +) -> None: + if rounds <= 0: + raise ValueError("rounds must be > 0") + + repo_root = Path( + subprocess.run( + ["git", "rev-parse", "--show-toplevel"], + check=True, + capture_output=True, + text=True, + ).stdout.strip() + ) + pair_order = build_pair_order(rounds, seed) + + with _provision_tree(repo_root, ref_a, "A", keep_worktrees) as tree_a: + with _provision_tree(repo_root, ref_b, "B", keep_worktrees) as tree_b: + if tree_a != repo_root: + _mirror_hotpath_module(repo_root, tree_a) + if tree_b != repo_root: + _mirror_hotpath_module(repo_root, tree_b) + + if sync: + _maybe_sync(tree_a) + if tree_b != tree_a: + _maybe_sync(tree_b) + + with tempfile.TemporaryDirectory(prefix="ezmsg-perf-ab-runs-") as tmpdir_name: + tmpdir = Path(tmpdir_name) + cmd_by_label = { + "A": lambda path: build_hotpath_command( + path, + count=count, + warmup=warmup, + payload_sizes=payload_sizes, + transports=transports, + apis=apis, + num_buffers=num_buffers, + quiet=quiet, + ), + "B": lambda path: build_hotpath_command( + path, + count=count, + warmup=warmup, + payload_sizes=payload_sizes, + transports=transports, + apis=apis, + num_buffers=num_buffers, + quiet=quiet, + ), + } + tree_by_label = {"A": tree_a, "B": tree_b} + + for idx in range(prewarm): + for label in ("A", "B"): + if label == "B" and tree_b == tree_a: + continue + warm_path = tmpdir / f"warm-{label}-{idx}.json" + _run_checked(cmd_by_label[label](warm_path), cwd=tree_by_label[label]) + + paired_runs: list[tuple[dict[str, float], dict[str, float]]] = [] + for round_idx, (first, second) in enumerate(pair_order, start=1): + outputs: dict[str, dict[str, float]] = {} + for label in (first, second): + output_path = tmpdir / f"round-{round_idx:02d}-{label}.json" + _run_checked(cmd_by_label[label](output_path), cwd=tree_by_label[label]) + outputs[label] = load_hotpath_summary(output_path) + + _ensure_json_files_match(outputs["A"], outputs["B"], ref_a, ref_b) + paired_runs.append((outputs["A"], outputs["B"])) + + summary = summarize_ab_results(ref_a, ref_b, rounds, seed, paired_runs) + _print_summary(summary) + if json_out is not None: + dump_ab_json(summary, json_out) + print(f"Wrote JSON results to {json_out}") + + +def setup_ab_cmdline(subparsers: argparse._SubParsersAction) -> None: + p_ab = subparsers.add_parser( + "ab", + help="run interleaved A/B hot-path comparisons using git worktrees", + ) + p_ab.add_argument("--ref-a", default="dev", help="baseline git ref or CURRENT") + p_ab.add_argument("--ref-b", default="CURRENT", help="candidate git ref or CURRENT") + p_ab.add_argument( + "--rounds", + type=int, + default=6, + help="number of A/B pairs to run (default = 6)", + ) + p_ab.add_argument( + "--count", + type=int, + default=2_000, + help="messages per hot-path sample (default = 2000)", + ) + p_ab.add_argument( + "--warmup", + type=int, + default=200, + help="warmup messages per hot-path sample (default = 200)", + ) + p_ab.add_argument( + "--prewarm", + type=int, + default=1, + help="unmeasured warmup invocations per side (default = 1)", + ) + p_ab.add_argument( + "--payload-sizes", + nargs="*", + type=int, + default=[64, 4096], + help="payload sizes in bytes (default = [64, 4096])", + ) + p_ab.add_argument( + "--transports", + nargs="*", + choices=["local", "shm", "tcp"], + default=["local", "shm", "tcp"], + help="transports to compare (default = ['local', 'shm', 'tcp'])", + ) + p_ab.add_argument( + "--apis", + nargs="*", + choices=["async", "sync"], + default=["async"], + help="apis to compare (default = ['async'])", + ) + p_ab.add_argument( + "--num-buffers", + type=int, + default=1, + help="publisher buffers (default = 1)", + ) + p_ab.add_argument( + "--seed", + type=int, + default=DEFAULT_PAIR_SEED, + help="pair-order shuffle seed (default = 0)", + ) + p_ab.add_argument( + "--json-out", + type=Path, + default=None, + help="optional JSON output path", + ) + p_ab.add_argument( + "--keep-worktrees", + action="store_true", + help="leave auto-provisioned worktrees on disk for inspection", + ) + p_ab.add_argument( + "--sync", + action="store_true", + help="run 'uv sync --group dev' in each provisioned worktree first", + ) + p_ab.add_argument( + "--quiet", + action="store_true", + help="suppress ezmsg runtime logs in child benchmark runs", + ) + p_ab.set_defaults( + _handler=lambda ns: perf_ab( + ref_a=ns.ref_a, + ref_b=ns.ref_b, + rounds=ns.rounds, + count=ns.count, + warmup=ns.warmup, + prewarm=ns.prewarm, + payload_sizes=ns.payload_sizes, + transports=ns.transports, + apis=ns.apis, + num_buffers=ns.num_buffers, + seed=ns.seed, + json_out=ns.json_out, + keep_worktrees=ns.keep_worktrees, + sync=ns.sync, + quiet=ns.quiet, + ) + ) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Run interleaved ezmsg hot-path A/B comparisons." + ) + subparsers = parser.add_subparsers(dest="command", required=True) + setup_ab_cmdline(subparsers) + ns = parser.parse_args(["ab", *sys.argv[1:]]) + ns._handler(ns) diff --git a/src/ezmsg/util/perf/command.py b/src/ezmsg/util/perf/command.py index 21fed7eb..9dab1f8e 100644 --- a/src/ezmsg/util/perf/command.py +++ b/src/ezmsg/util/perf/command.py @@ -1,6 +1,8 @@ import argparse +from .ab import setup_ab_cmdline from .analysis import setup_summary_cmdline +from .hotpath import setup_hotpath_cmdline from .run import setup_run_cmdline @@ -9,6 +11,8 @@ def command() -> None: subparsers = parser.add_subparsers(dest="command", required=True) setup_run_cmdline(subparsers) + setup_hotpath_cmdline(subparsers) + setup_ab_cmdline(subparsers) setup_summary_cmdline(subparsers) ns = parser.parse_args() diff --git a/src/ezmsg/util/perf/hotpath.py b/src/ezmsg/util/perf/hotpath.py new file mode 100644 index 00000000..74585981 --- /dev/null +++ b/src/ezmsg/util/perf/hotpath.py @@ -0,0 +1,437 @@ +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import inspect +import json +import logging +import random +import statistics +import sys +import time + +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Literal +from uuid import uuid4 + +import ezmsg.core as ez + +from ezmsg.core.graphserver import GraphServer + +from .util import coef_var, median_of_means, stable_perf + +ApiName = Literal["async", "sync"] +TransportName = Literal["local", "shm", "tcp"] + +DEFAULT_APIS: tuple[ApiName, ...] = ("async",) +DEFAULT_TRANSPORTS: tuple[TransportName, ...] = ("local", "shm", "tcp") +DEFAULT_PAYLOAD_SIZES = (64, 4096) + + +def _supports_same_process_transport_selection() -> bool: + return "allow_local" in inspect.signature(ez.Publisher.create).parameters + + +def _validate_transport_support(case: "HotPathCase") -> None: + if case.transport == "local": + return + if _supports_same_process_transport_selection(): + return + raise RuntimeError( + "This ref does not support bypassing the same-process local fast path, " + f"so '{case.case_id}' cannot be benchmarked here. Compare against a ref " + "that includes the allow_local transport-selection support, or limit the " + "run to '--transports local'." + ) + + +def _publisher_transport_kwargs(case: "HotPathCase", host: str) -> dict[str, object]: + kwargs: dict[str, object] = { + "host": host, + "num_buffers": case.num_buffers, + "force_tcp": (case.transport == "tcp"), + } + if _supports_same_process_transport_selection(): + kwargs["allow_local"] = case.transport == "local" + return kwargs + + +@dataclass(frozen=True) +class HotPathCase: + api: ApiName + transport: TransportName + payload_size: int + num_buffers: int + + @property + def case_id(self) -> str: + return ( + f"{self.api}/{self.transport}/payload={self.payload_size}" + f"/buffers={self.num_buffers}" + ) + + +@dataclass(frozen=True) +class HotPathSummary: + samples: int + seconds_median: float + seconds_mean: float + seconds_min: float + seconds_max: float + seconds_median_of_means: float + seconds_cv: float + us_per_message_median: float + us_per_message_mean: float + messages_per_second_median: float + messages_per_second_mean: float + + +@dataclass(frozen=True) +class HotPathCaseResult: + case: HotPathCase + count: int + warmup: int + samples_seconds: list[float] + summary: HotPathSummary + + +@dataclass(frozen=True) +class HotPathSuiteResult: + seed: int + count: int + warmup: int + results: list[HotPathCaseResult] + + +@contextlib.contextmanager +def _quiet_ezmsg_logs(enabled: bool): + if not enabled: + yield + return + + old_level = ez.logger.level + ez.logger.setLevel(logging.WARNING) + try: + yield + finally: + ez.logger.setLevel(old_level) + + +def build_cases( + apis: list[str], + transports: list[str], + payload_sizes: list[int], + num_buffers: int, +) -> list[HotPathCase]: + cases = [ + HotPathCase( + api=api, # type: ignore[arg-type] + transport=transport, # type: ignore[arg-type] + payload_size=payload_size, + num_buffers=num_buffers, + ) + for api in apis + for transport in transports + for payload_size in payload_sizes + if api in ("async", "sync") and transport in ("local", "shm", "tcp") + ] + return sorted(cases, key=lambda case: case.case_id) + + +def summarize_samples(samples: list[float], count: int) -> HotPathSummary: + us_per_message = [(sample / count) * 1e6 for sample in samples] + rates = [count / sample for sample in samples] + return HotPathSummary( + samples=len(samples), + seconds_median=statistics.median(samples), + seconds_mean=statistics.fmean(samples), + seconds_min=min(samples), + seconds_max=max(samples), + seconds_median_of_means=median_of_means(samples), + seconds_cv=coef_var(samples), + us_per_message_median=statistics.median(us_per_message), + us_per_message_mean=statistics.fmean(us_per_message), + messages_per_second_median=statistics.median(rates), + messages_per_second_mean=statistics.fmean(rates), + ) + + +async def _async_roundtrip( + case: HotPathCase, + count: int, + warmup: int, + graph_address: tuple[str, int], +) -> float: + _validate_transport_support(case) + topic = f"/EZMSG/PERF/HOTPATH/{uuid4().hex}" + payload = bytes(case.payload_size) + + async with ez.GraphContext(graph_address, auto_start=False) as ctx: + pub = await ctx.publisher(topic, **_publisher_transport_kwargs(case, graph_address[0])) + sub = await ctx.subscriber(topic) + + for _ in range(warmup): + await pub.broadcast(payload) + await sub.recv() + + start = time.perf_counter() + for _ in range(count): + await pub.broadcast(payload) + await sub.recv() + return time.perf_counter() - start + + +def _sync_roundtrip( + case: HotPathCase, + count: int, + warmup: int, + graph_address: tuple[str, int], +) -> float: + _validate_transport_support(case) + topic = f"/EZMSG/PERF/HOTPATH/{uuid4().hex}" + payload = bytes(case.payload_size) + + with ez.sync.init(graph_address, auto_start=False) as ctx: + pub = ctx.create_publisher(topic, **_publisher_transport_kwargs(case, graph_address[0])) + sub = ctx.create_subscription(topic) + + for _ in range(warmup): + pub.publish(payload) + sub.recv() + + start = time.perf_counter() + for _ in range(count): + pub.publish(payload) + sub.recv() + return time.perf_counter() - start + + +def run_hotpath_case( + case: HotPathCase, + count: int, + warmup: int, + samples: int, + graph_address: tuple[str, int], +) -> HotPathCaseResult: + results: list[float] = [] + + for _ in range(samples): + with stable_perf(): + if case.api == "async": + elapsed = asyncio.run( + _async_roundtrip(case, count, warmup, graph_address) + ) + else: + elapsed = _sync_roundtrip(case, count, warmup, graph_address) + results.append(elapsed) + + return HotPathCaseResult( + case=case, + count=count, + warmup=warmup, + samples_seconds=results, + summary=summarize_samples(results, count), + ) + + +def _format_case_result(result: HotPathCaseResult) -> str: + summary = result.summary + return ( + f"{result.case.case_id:<36} " + f"{summary.us_per_message_median:>10.2f} us/msg " + f"{summary.messages_per_second_median:>12,.0f} msg/s " + f"cv={summary.seconds_cv:>5.3f}" + ) + + +def run_hotpath_suite( + count: int, + warmup: int, + samples: int, + apis: list[str], + transports: list[str], + payload_sizes: list[int], + num_buffers: int, + seed: int, + quiet: bool, +) -> HotPathSuiteResult: + rng = random.Random(seed) + cases = build_cases(apis, transports, payload_sizes, num_buffers) + rng.shuffle(cases) + + graph_server = GraphServer() + graph_server.start(("127.0.0.1", 0)) + try: + with _quiet_ezmsg_logs(quiet): + results = [ + run_hotpath_case( + case, + count=count, + warmup=warmup, + samples=samples, + graph_address=graph_server.address, + ) + for case in cases + ] + finally: + graph_server.stop() + + results.sort(key=lambda result: result.case.case_id) + return HotPathSuiteResult(seed=seed, count=count, warmup=warmup, results=results) + + +def dump_suite_json(result: HotPathSuiteResult, path: Path) -> None: + payload = { + "suite": "hotpath", + "seed": result.seed, + "count": result.count, + "warmup": result.warmup, + "results": [ + { + "case": asdict(case_result.case), + "case_id": case_result.case.case_id, + "count": case_result.count, + "warmup": case_result.warmup, + "samples_seconds": case_result.samples_seconds, + "summary": asdict(case_result.summary), + } + for case_result in result.results + ], + } + path.write_text(json.dumps(payload, indent=2) + "\n") + + +def perf_hotpath( + count: int, + warmup: int, + samples: int, + apis: list[str], + transports: list[str], + payload_sizes: list[int], + num_buffers: int, + seed: int, + json_out: Path | None, + quiet: bool, +) -> None: + result = run_hotpath_suite( + count=count, + warmup=warmup, + samples=samples, + apis=apis, + transports=transports, + payload_sizes=payload_sizes, + num_buffers=num_buffers, + seed=seed, + quiet=quiet, + ) + + print("Hot-path roundtrip benchmark") + print( + f"count={count}, warmup={warmup}, samples={samples}, " + f"payload_sizes={payload_sizes}, transports={transports}, apis={apis}" + ) + for case_result in result.results: + print(_format_case_result(case_result)) + + if json_out is not None: + dump_suite_json(result, json_out) + print(f"Wrote JSON results to {json_out}") + + +def setup_hotpath_cmdline(subparsers: argparse._SubParsersAction) -> None: + p_hotpath = subparsers.add_parser( + "hotpath", + help="run fast, focused hot-path roundtrip benchmarks", + ) + p_hotpath.add_argument( + "--count", + type=int, + default=5_000, + help="messages per sample (default = 5000)", + ) + p_hotpath.add_argument( + "--warmup", + type=int, + default=500, + help="warmup messages per sample (default = 500)", + ) + p_hotpath.add_argument( + "--samples", + type=int, + default=5, + help="timed samples per case (default = 5)", + ) + p_hotpath.add_argument( + "--apis", + nargs="*", + choices=DEFAULT_APIS + ("sync",), + default=list(DEFAULT_APIS), + help=f"apis to benchmark (default = {list(DEFAULT_APIS)})", + ) + p_hotpath.add_argument( + "--transports", + nargs="*", + choices=DEFAULT_TRANSPORTS, + default=list(DEFAULT_TRANSPORTS), + help=f"transports to benchmark (default = {list(DEFAULT_TRANSPORTS)})", + ) + p_hotpath.add_argument( + "--payload-sizes", + nargs="*", + type=int, + default=list(DEFAULT_PAYLOAD_SIZES), + help=f"payload sizes in bytes (default = {list(DEFAULT_PAYLOAD_SIZES)})", + ) + p_hotpath.add_argument( + "--num-buffers", + type=int, + default=1, + help="publisher buffers (default = 1)", + ) + p_hotpath.add_argument( + "--seed", + type=int, + default=0, + help="shuffle seed for case order (default = 0)", + ) + p_hotpath.add_argument( + "--json-out", + type=Path, + default=None, + help="optional JSON output path", + ) + p_hotpath.add_argument( + "--quiet", + action="store_true", + help="suppress ezmsg runtime logs during the benchmark", + ) + p_hotpath.set_defaults( + _handler=lambda ns: perf_hotpath( + count=ns.count, + warmup=ns.warmup, + samples=ns.samples, + apis=ns.apis, + transports=ns.transports, + payload_sizes=ns.payload_sizes, + num_buffers=ns.num_buffers, + seed=ns.seed, + json_out=ns.json_out, + quiet=ns.quiet, + ) + ) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Run ezmsg hot-path roundtrip benchmarks." + ) + subparsers = parser.add_subparsers(dest="command", required=True) + setup_hotpath_cmdline(subparsers) + ns = parser.parse_args(["hotpath", *sys.argv[1:]]) + ns._handler(ns) + + +if __name__ == "__main__": + main() diff --git a/tests/test_perf_ab.py b/tests/test_perf_ab.py new file mode 100644 index 00000000..3fbb935c --- /dev/null +++ b/tests/test_perf_ab.py @@ -0,0 +1,62 @@ +from ezmsg.util.perf.ab import ( + build_hotpath_command, + build_pair_order, + summarize_ab_results, +) + + +def test_build_pair_order_is_balanced_and_reproducible(): + first = build_pair_order(6, seed=123) + second = build_pair_order(6, seed=123) + + assert first == second + assert len(first) == 6 + assert first.count(("A", "B")) == 3 + assert first.count(("B", "A")) == 3 + + +def test_build_hotpath_command_contains_expected_args(tmp_path): + cmd = build_hotpath_command( + tmp_path / "out.json", + count=100, + warmup=10, + payload_sizes=[64, 256], + transports=["local", "shm"], + apis=["async", "sync"], + num_buffers=2, + quiet=True, + ) + + assert cmd[:5] == ["uv", "run", "python", "-m", "ezmsg.util.perf.hotpath"] + assert "--count" in cmd + assert "--payload-sizes" in cmd + assert "--quiet" in cmd + + +def test_summarize_ab_results_uses_b_vs_a_delta(): + paired_runs = [ + ( + {"async/shm/payload=64/buffers=1": 10.0}, + {"async/shm/payload=64/buffers=1": 12.0}, + ), + ( + {"async/shm/payload=64/buffers=1": 8.0}, + {"async/shm/payload=64/buffers=1": 9.0}, + ), + ] + + summary = summarize_ab_results( + ref_a="dev", + ref_b="CURRENT", + rounds=2, + seed=0, + paired_runs=paired_runs, + ) + + assert len(summary.cases) == 1 + case = summary.cases[0] + assert case.case_id == "async/shm/payload=64/buffers=1" + assert case.a_us_per_message_median == 9.0 + assert case.b_us_per_message_median == 10.5 + assert case.delta_pct_median > 0 + assert case.b_faster_pairs == 0 diff --git a/tests/test_perf_hotpath.py b/tests/test_perf_hotpath.py new file mode 100644 index 00000000..9e0fcc6f --- /dev/null +++ b/tests/test_perf_hotpath.py @@ -0,0 +1,45 @@ +import pytest + +from ezmsg.util.perf.hotpath import HotPathCase, build_cases, run_hotpath_case + +from ezmsg.core.graphserver import GraphServer + + +def test_build_cases_are_sorted_by_case_id(): + cases = build_cases( + apis=["sync", "async"], + transports=["local", "tcp", "shm"], + payload_sizes=[1024, 64], + num_buffers=1, + ) + assert [case.case_id for case in cases] == sorted(case.case_id for case in cases) + assert "async/shm/payload=64/buffers=1" in {case.case_id for case in cases} + assert "async/local/payload=64/buffers=1" in {case.case_id for case in cases} + + +def test_run_hotpath_case_smoke(): + server = GraphServer() + try: + server.start(("127.0.0.1", 0)) + except PermissionError: + pytest.skip("Local socket binding is unavailable in this environment") + try: + result = run_hotpath_case( + HotPathCase( + api="sync", + transport="tcp", + payload_size=64, + num_buffers=1, + ), + count=8, + warmup=2, + samples=2, + graph_address=server.address, + ) + finally: + server.stop() + + assert result.case.case_id == "sync/tcp/payload=64/buffers=1" + assert len(result.samples_seconds) == 2 + assert all(sample > 0 for sample in result.samples_seconds) + assert result.summary.us_per_message_median > 0 diff --git a/tests/test_pubclient.py b/tests/test_pubclient.py new file mode 100644 index 00000000..27048b4d --- /dev/null +++ b/tests/test_pubclient.py @@ -0,0 +1,203 @@ +from contextlib import contextmanager +from uuid import uuid4 + +import pytest + +from ezmsg.core.netprotocol import Address, Command +from ezmsg.core.pubclient import ( + ALLOW_LOCAL_ENV, + FORCE_TCP_ENV, + PubChannelInfo, + Publisher, +) + + +class DummyLocalChannel: + def __init__(self) -> None: + self.calls: list[tuple[int, object]] = [] + + def put_local(self, msg_id: int, obj: object) -> None: + self.calls.append((msg_id, obj)) + + +class DummyWriter: + def __init__(self) -> None: + self.buffer: list[bytes] = [] + + def write(self, data: bytes) -> None: + self.buffer.append(data) + + async def drain(self) -> None: + return None + + +class DummyShm: + def __init__(self, num_buffers: int, buf_size: int = 65536) -> None: + self.name = "dummy-shm" + self.buf_size = buf_size + self._buffers = [bytearray(buf_size) for _ in range(num_buffers)] + + @contextmanager + def buffer(self, idx: int, readonly: bool = False): + del readonly + yield memoryview(self._buffers[idx]) + + +def _make_publisher( + *, + force_tcp: bool | None, + allow_local: bool | None, + channel_pid: int, + shm_ok: bool = True, +) -> tuple[Publisher, DummyLocalChannel, DummyWriter]: + pub = Publisher( + id=uuid4(), + topic="/TEST", + shm=DummyShm(num_buffers=2), + graph_address=Address("127.0.0.1", 25978), + num_buffers=2, + force_tcp=force_tcp, + allow_local=allow_local, + _guard=Publisher._SENTINEL, + ) + pub._running.set() + + local_channel = DummyLocalChannel() + writer = DummyWriter() + channel = PubChannelInfo( + id=uuid4(), + writer=writer, + pub_id=pub.id, + pid=channel_pid, + shm_ok=shm_ok, + ) + pub._channels[channel.id] = channel + pub._local_channel = local_channel # type: ignore[assignment] + return pub, local_channel, writer + + +@pytest.mark.asyncio +async def test_broadcast_same_process_prefers_local_fast_path(): + pub, local_channel, writer = _make_publisher( + force_tcp=False, + allow_local=True, + channel_pid=0, + ) + pub.pid = 0 + + await pub.broadcast(b"payload") + + assert local_channel.calls == [(0, b"payload")] + assert writer.buffer == [] + + +@pytest.mark.asyncio +async def test_broadcast_same_process_can_force_shm_path(): + pub, local_channel, writer = _make_publisher( + force_tcp=False, + allow_local=False, + channel_pid=0, + ) + pub.pid = 0 + + await pub.broadcast(b"payload") + + assert local_channel.calls == [] + assert writer.buffer + assert writer.buffer[0].startswith(Command.TX_SHM.value) + + +@pytest.mark.asyncio +async def test_broadcast_same_process_can_force_tcp_path(): + pub, local_channel, writer = _make_publisher( + force_tcp=True, + allow_local=False, + channel_pid=0, + ) + pub.pid = 0 + + await pub.broadcast(b"payload") + + assert local_channel.calls == [] + assert writer.buffer + assert writer.buffer[0].startswith(Command.TX_TCP.value) + + +def test_force_tcp_disables_allow_local_from_env(monkeypatch, caplog): + monkeypatch.setenv(ALLOW_LOCAL_ENV, "1") + with caplog.at_level("INFO"): + pub, _, _ = _make_publisher( + force_tcp=True, + allow_local=None, + channel_pid=0, + ) + + assert pub._allow_local is False + assert "force_tcp=True disables local delivery" in caplog.text + + +def test_force_tcp_disables_explicit_allow_local(caplog): + with caplog.at_level("INFO"): + pub, _, _ = _make_publisher( + force_tcp=True, + allow_local=True, + channel_pid=0, + ) + + assert pub._allow_local is False + assert "force_tcp=True disables local delivery" in caplog.text + + +def test_force_tcp_uses_env_default_when_none(monkeypatch): + monkeypatch.setenv(FORCE_TCP_ENV, "1") + pub, _, _ = _make_publisher( + force_tcp=None, + allow_local=False, + channel_pid=0, + ) + + assert pub._force_tcp is True + + +def test_explicit_force_tcp_false_overrides_env(monkeypatch): + monkeypatch.setenv(FORCE_TCP_ENV, "1") + pub, _, _ = _make_publisher( + force_tcp=False, + allow_local=False, + channel_pid=0, + ) + + assert pub._force_tcp is False + + +@pytest.mark.asyncio +async def test_broadcast_same_process_uses_env_default_when_allow_local_is_none(monkeypatch): + monkeypatch.setenv(ALLOW_LOCAL_ENV, "0") + pub, local_channel, writer = _make_publisher( + force_tcp=False, + allow_local=None, + channel_pid=0, + ) + pub.pid = 0 + + await pub.broadcast(b"payload") + + assert local_channel.calls == [] + assert writer.buffer + assert writer.buffer[0].startswith(Command.TX_SHM.value) + + +@pytest.mark.asyncio +async def test_broadcast_same_process_explicit_allow_local_overrides_env(monkeypatch): + monkeypatch.setenv(ALLOW_LOCAL_ENV, "0") + pub, local_channel, writer = _make_publisher( + force_tcp=False, + allow_local=True, + channel_pid=0, + ) + pub.pid = 0 + + await pub.broadcast(b"payload") + + assert local_channel.calls == [(0, b"payload")] + assert writer.buffer == [] From e912e98556bbec308266b7201fb9c4b01c97fc9a Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 1 Apr 2026 12:09:37 -0400 Subject: [PATCH 34/52] less expensive bookkeeping --- src/ezmsg/core/messagechannel.py | 22 --- src/ezmsg/core/profiling.py | 256 +++++++++++++++++++------------ 2 files changed, 154 insertions(+), 124 deletions(-) diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 3c86fa58..130895f1 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -21,7 +21,6 @@ encode_str, close_stream_writer, ) -from .profiling import PROFILES, PROFILE_TIME from .graphmeta import ProfileChannelType logger = logging.getLogger("ezmsg") @@ -102,7 +101,6 @@ class Channel: _graph_address: AddressType | None _local_backpressure: Backpressure | None _channel_kind: ProfileChannelType - _lease_start: dict[tuple[UUID, int], int] def __init__( self, @@ -130,7 +128,6 @@ def __init__( self._graph_address = graph_address self._local_backpressure = None self._channel_kind = ProfileChannelType.UNKNOWN - self._lease_start = {} @classmethod async def create( @@ -340,12 +337,10 @@ def channel_kind(self) -> ProfileChannelType: def _notify_clients(self, msg_id: int) -> bool: """notify interested clients and return true if any were notified""" buf_idx = msg_id % self.num_buffers - now_ns = PROFILE_TIME() for client_id, queue in self.clients.items(): if queue is None: continue # queue is none if this is the pub self.backpressure.lease(client_id, buf_idx) - self._lease_start[(client_id, msg_id)] = now_ns queue.put_nowait((self.pub_id, msg_id)) return not self.backpressure.available(buf_idx) @@ -409,18 +404,6 @@ def _release_backpressure(self, msg_id: int, client_id: UUID) -> None: :param client_id: UUID of client releasing this message :type client_id: UUID """ - now_ns = PROFILE_TIME() - lease = self._lease_start.pop((client_id, msg_id), None) - if lease is not None: - start_ns = lease - PROFILES.subscriber_attributed_backpressure( - client_id, - now_ns, - now_ns - start_ns, - self._channel_kind, - msg_seq=msg_id, - ) - buf_idx = msg_id % self.num_buffers self.backpressure.free(client_id, buf_idx) if self.backpressure.buffers[buf_idx].is_empty: @@ -476,11 +459,6 @@ def unregister_client(self, client_id: UUID) -> None: queue.put_nowait((pub_id, msg_id)) self.backpressure.free(client_id) - stale = [ - key for key in self._lease_start.keys() if key[0] == client_id - ] - for key in stale: - self._lease_start.pop(key, None) elif client_id == self.pub_id and self._local_backpressure is not None: self._local_backpressure.free(self.id) diff --git a/src/ezmsg/core/profiling.py b/src/ezmsg/core/profiling.py index 2bce4f8c..2893c55b 100644 --- a/src/ezmsg/core/profiling.py +++ b/src/ezmsg/core/profiling.py @@ -31,77 +31,91 @@ def _endpoint_id(topic: str, id: UUID) -> str: @dataclass -class _Rolling: - window_seconds: float = WINDOW_SECONDS - bucket_seconds: float = BUCKET_SECONDS - count: list[int] = field(default_factory=list) - value_sum: list[int] = field(default_factory=list) - max_value: list[int] = field(default_factory=list) - _num_buckets: int = 0 - _bucket_ns: int = 0 - _last_bucket_tick: int | None = None - - def __post_init__(self) -> None: - self._num_buckets = max(1, int(self.window_seconds / self.bucket_seconds)) - self._bucket_ns = max(1, int(self.bucket_seconds * 1e9)) - self.count = [0 for _ in range(self._num_buckets)] - self.value_sum = [0 for _ in range(self._num_buckets)] - self.max_value = [0 for _ in range(self._num_buckets)] +class _PublisherBucket: + publish_count: int = 0 + publish_delta_sum: int = 0 + publish_delta_count: int = 0 + backpressure_wait_sum: int = 0 + inflight_peak: int = 0 + + def reset(self) -> None: + self.publish_count = 0 + self.publish_delta_sum = 0 + self.publish_delta_count = 0 + self.backpressure_wait_sum = 0 + self.inflight_peak = 0 + + +@dataclass +class _SubscriberBucket: + recv_count: int = 0 + lease_time_sum: int = 0 + user_span_sum: int = 0 + user_span_count: int = 0 + attributable_backpressure_sum: int = 0 + attributable_backpressure_count: int = 0 + + def reset(self) -> None: + self.recv_count = 0 + self.lease_time_sum = 0 + self.user_span_sum = 0 + self.user_span_count = 0 + self.attributable_backpressure_sum = 0 + self.attributable_backpressure_count = 0 + + +class _BucketWindow: + def __init__(self, bucket_type: type[_PublisherBucket | _SubscriberBucket]) -> None: + self._num_buckets = max(1, int(WINDOW_SECONDS / BUCKET_SECONDS)) + self._bucket_ns = max(1, int(BUCKET_SECONDS * 1e9)) + self._bucket_type = bucket_type + self._ticks = [-1 for _ in range(self._num_buckets)] + self._buckets = [bucket_type() for _ in range(self._num_buckets)] + self._last_tick: int | None = None def _bucket_tick(self, ts_ns: int) -> int: return ts_ns // self._bucket_ns - def _bucket(self, ts_ns: int) -> int: - return self._bucket_tick(ts_ns) % self._num_buckets - - def _reset_bucket(self, idx: int) -> None: - self.count[idx] = 0 - self.value_sum[idx] = 0 - self.max_value[idx] = 0 + def _clear_bucket(self, idx: int) -> None: + self._ticks[idx] = -1 + self._buckets[idx].reset() def _advance(self, ts_ns: int) -> int: - bucket_tick = self._bucket_tick(ts_ns) - bucket = bucket_tick % self._num_buckets - if self._last_bucket_tick is None: - self._last_bucket_tick = bucket_tick - return bucket - if bucket_tick <= self._last_bucket_tick: - return bucket - elapsed_buckets = bucket_tick - self._last_bucket_tick - if elapsed_buckets >= self._num_buckets: + tick = self._bucket_tick(ts_ns) + if self._last_tick is None: + self._last_tick = tick + return tick + if tick <= self._last_tick: + return tick + + elapsed = tick - self._last_tick + if elapsed >= self._num_buckets: for idx in range(self._num_buckets): - self._reset_bucket(idx) + self._clear_bucket(idx) else: - previous_bucket = self._last_bucket_tick % self._num_buckets - for step in range(1, elapsed_buckets + 1): - self._reset_bucket((previous_bucket + step) % self._num_buckets) - self._last_bucket_tick = bucket_tick - return bucket - - def advance_to(self, ts_ns: int) -> None: - self._advance(ts_ns) - - def add(self, ts_ns: int, value: int) -> None: - idx = self._advance(ts_ns) - self.count[idx] += 1 - self.value_sum[idx] += value - if value > self.max_value[idx]: - self.max_value[idx] = value - - def count_total(self) -> int: - return sum(self.count) - - def sum_total(self) -> int: - return sum(self.value_sum) - - def max_total(self) -> int: - return max(self.max_value) if self.max_value else 0 - - def avg(self) -> float: - c = self.count_total() - if c == 0: - return 0.0 - return float(self.sum_total()) / float(c) + previous_idx = self._last_tick % self._num_buckets + for step in range(1, elapsed + 1): + self._clear_bucket((previous_idx + step) % self._num_buckets) + + self._last_tick = tick + return tick + + def bucket(self, ts_ns: int) -> _PublisherBucket | _SubscriberBucket: + tick = self._advance(ts_ns) + idx = tick % self._num_buckets + if self._ticks[idx] != tick: + self._ticks[idx] = tick + self._buckets[idx].reset() + return self._buckets[idx] + + def buckets(self, ts_ns: int) -> list[_PublisherBucket | _SubscriberBucket]: + tick = self._advance(ts_ns) + min_tick = tick - self._num_buckets + 1 + return [ + bucket + for bucket_tick, bucket in zip(self._ticks, self._buckets) + if bucket_tick >= min_tick + ] @dataclass @@ -113,10 +127,9 @@ class _PublisherMetrics: backpressure_wait_ns_total: int = 0 inflight_messages_current: int = 0 _last_publish_ts_ns: int | None = None - _publish_delta: _Rolling = field(default_factory=_Rolling) - _publish_count: _Rolling = field(default_factory=lambda: _Rolling()) - _backpressure_wait: _Rolling = field(default_factory=_Rolling) - _inflight: _Rolling = field(default_factory=_Rolling) + _window: _BucketWindow = field( + default_factory=lambda: _BucketWindow(_PublisherBucket) + ) trace_enabled: bool = False trace_sample_mod: int = 1 trace_metrics: set[str] | None = None @@ -132,13 +145,18 @@ def record_publish( self, ts_ns: int, inflight: int, msg_seq: int | None = None ) -> None: self.messages_published_total += 1 - self._publish_count.add(ts_ns, 1) + bucket = self._window.bucket(ts_ns) + assert isinstance(bucket, _PublisherBucket) + bucket.publish_count += 1 publish_delta_ns = 0 if self._last_publish_ts_ns is not None: publish_delta_ns = ts_ns - self._last_publish_ts_ns - self._publish_delta.add(ts_ns, publish_delta_ns) + bucket.publish_delta_sum += publish_delta_ns + bucket.publish_delta_count += 1 self._last_publish_ts_ns = ts_ns - self.sample_inflight(ts_ns, inflight) + self.inflight_messages_current = inflight + if inflight > bucket.inflight_peak: + bucket.inflight_peak = inflight self._trace_counter += 1 if ( self.trace_enabled @@ -147,7 +165,7 @@ def record_publish( ): self.trace_samples.append( ProfilingTraceSample( - timestamp=float(PROFILE_TIME()), + timestamp=float(ts_ns), endpoint_id=self.endpoint_id, topic=self.topic, metric="publish_delta_ns", @@ -160,11 +178,13 @@ def record_backpressure_wait( self, ts_ns: int, wait_ns: int, msg_seq: int | None = None ) -> None: self.backpressure_wait_ns_total += wait_ns - self._backpressure_wait.add(ts_ns, wait_ns) + bucket = self._window.bucket(ts_ns) + assert isinstance(bucket, _PublisherBucket) + bucket.backpressure_wait_sum += wait_ns if self.trace_enabled and self._trace_metric_enabled("backpressure_wait_ns"): self.trace_samples.append( ProfilingTraceSample( - timestamp=float(PROFILE_TIME()), + timestamp=float(ts_ns), endpoint_id=self.endpoint_id, topic=self.topic, metric="backpressure_wait_ns", @@ -175,27 +195,43 @@ def record_backpressure_wait( def sample_inflight(self, ts_ns: int, inflight: int) -> None: self.inflight_messages_current = inflight - self._inflight.add(ts_ns, inflight) + bucket = self._window.bucket(ts_ns) + assert isinstance(bucket, _PublisherBucket) + if inflight > bucket.inflight_peak: + bucket.inflight_peak = inflight def snapshot(self) -> PublisherProfileSnapshot: now_ns = PROFILE_TIME() - self._publish_delta.advance_to(now_ns) - self._publish_count.advance_to(now_ns) - self._backpressure_wait.advance_to(now_ns) - self._inflight.advance_to(now_ns) - window_msgs = self._publish_count.count_total() + buckets = self._window.buckets(now_ns) + window_msgs = 0 + publish_delta_sum = 0 + publish_delta_count = 0 + backpressure_wait_sum = 0 + inflight_peak = 0 + for bucket in buckets: + assert isinstance(bucket, _PublisherBucket) + window_msgs += bucket.publish_count + publish_delta_sum += bucket.publish_delta_sum + publish_delta_count += bucket.publish_delta_count + backpressure_wait_sum += bucket.backpressure_wait_sum + if bucket.inflight_peak > inflight_peak: + inflight_peak = bucket.inflight_peak return PublisherProfileSnapshot( endpoint_id=self.endpoint_id, topic=self.topic, messages_published_total=self.messages_published_total, messages_published_window=window_msgs, - publish_delta_ns_avg_window=self._publish_delta.avg(), + publish_delta_ns_avg_window=( + float(publish_delta_sum) / float(publish_delta_count) + if publish_delta_count > 0 + else 0.0 + ), publish_rate_hz_window=float(window_msgs) / max(WINDOW_SECONDS, 1e-9), inflight_messages_current=self.inflight_messages_current, num_buffers=self.num_buffers, - inflight_messages_peak_window=self._inflight.max_total(), + inflight_messages_peak_window=inflight_peak, backpressure_wait_ns_total=self.backpressure_wait_ns_total, - backpressure_wait_ns_window=self._backpressure_wait.sum_total(), + backpressure_wait_ns_window=backpressure_wait_sum, timestamp=float(now_ns), ) @@ -210,10 +246,9 @@ class _SubscriberMetrics: attributable_backpressure_ns_total: int = 0 attributable_backpressure_events_total: int = 0 channel_kind_last: ProfileChannelType = ProfileChannelType.UNKNOWN - _recv_count: _Rolling = field(default_factory=lambda: _Rolling()) - _lease_time: _Rolling = field(default_factory=_Rolling) - _user_span: _Rolling = field(default_factory=_Rolling) - _attrib_bp: _Rolling = field(default_factory=_Rolling) + _window: _BucketWindow = field( + default_factory=lambda: _BucketWindow(_SubscriberBucket) + ) trace_enabled: bool = False trace_sample_mod: int = 1 trace_metrics: set[str] | None = None @@ -235,8 +270,10 @@ def record_receive( self.messages_received_total += 1 self.lease_time_ns_total += lease_ns self.channel_kind_last = channel_kind - self._recv_count.add(ts_ns, 1) - self._lease_time.add(ts_ns, lease_ns) + bucket = self._window.bucket(ts_ns) + assert isinstance(bucket, _SubscriberBucket) + bucket.recv_count += 1 + bucket.lease_time_sum += lease_ns self._trace_counter += 1 if ( self.trace_enabled @@ -245,7 +282,7 @@ def record_receive( ): self.trace_samples.append( ProfilingTraceSample( - timestamp=float(PROFILE_TIME()), + timestamp=float(ts_ns), endpoint_id=self.endpoint_id, topic=self.topic, metric="lease_time_ns", @@ -259,11 +296,14 @@ def record_user_span( self, ts_ns: int, span_ns: int, label: str | None, msg_seq: int | None = None ) -> None: self.user_span_ns_total += span_ns - self._user_span.add(ts_ns, span_ns) + bucket = self._window.bucket(ts_ns) + assert isinstance(bucket, _SubscriberBucket) + bucket.user_span_sum += span_ns + bucket.user_span_count += 1 if self.trace_enabled and self._trace_metric_enabled("user_span_ns"): self.trace_samples.append( ProfilingTraceSample( - timestamp=float(PROFILE_TIME()), + timestamp=float(ts_ns), endpoint_id=self.endpoint_id, topic=self.topic if label is None else f"{self.topic}:{label}", metric="user_span_ns", @@ -283,11 +323,14 @@ def record_attributed_backpressure( self.attributable_backpressure_ns_total += duration_ns self.attributable_backpressure_events_total += 1 self.channel_kind_last = channel_kind - self._attrib_bp.add(ts_ns, duration_ns) + bucket = self._window.bucket(ts_ns) + assert isinstance(bucket, _SubscriberBucket) + bucket.attributable_backpressure_sum += duration_ns + bucket.attributable_backpressure_count += 1 if self.trace_enabled and self._trace_metric_enabled("attributable_backpressure_ns"): self.trace_samples.append( ProfilingTraceSample( - timestamp=float(PROFILE_TIME()), + timestamp=float(ts_ns), endpoint_id=self.endpoint_id, topic=self.topic, metric="attributable_backpressure_ns", @@ -299,27 +342,36 @@ def record_attributed_backpressure( def snapshot(self) -> SubscriberProfileSnapshot: now_ns = PROFILE_TIME() - self._recv_count.advance_to(now_ns) - self._lease_time.advance_to(now_ns) - self._user_span.advance_to(now_ns) - self._attrib_bp.advance_to(now_ns) - recv_count = self._recv_count.count_total() - user_count = self._user_span.count_total() + buckets = self._window.buckets(now_ns) + recv_count = 0 + lease_time_sum = 0 + user_span_sum = 0 + user_count = 0 + attributable_backpressure_sum = 0 + for bucket in buckets: + assert isinstance(bucket, _SubscriberBucket) + recv_count += bucket.recv_count + lease_time_sum += bucket.lease_time_sum + user_span_sum += bucket.user_span_sum + user_count += bucket.user_span_count + attributable_backpressure_sum += bucket.attributable_backpressure_sum return SubscriberProfileSnapshot( endpoint_id=self.endpoint_id, topic=self.topic, messages_received_total=self.messages_received_total, messages_received_window=recv_count, lease_time_ns_total=self.lease_time_ns_total, - lease_time_ns_avg_window=self._lease_time.avg(), + lease_time_ns_avg_window=( + float(lease_time_sum) / float(recv_count) if recv_count > 0 else 0.0 + ), user_span_ns_total=self.user_span_ns_total, user_span_ns_avg_window=( - float(self._user_span.sum_total()) / float(user_count) + float(user_span_sum) / float(user_count) if user_count > 0 else 0.0 ), attributable_backpressure_ns_total=self.attributable_backpressure_ns_total, - attributable_backpressure_ns_window=self._attrib_bp.sum_total(), + attributable_backpressure_ns_window=attributable_backpressure_sum, attributable_backpressure_events_total=self.attributable_backpressure_events_total, channel_kind_last=self.channel_kind_last, timestamp=float(now_ns), From bf638b89443b024b3521a34672be0c98ed5d49f6 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 1 Apr 2026 13:36:09 -0400 Subject: [PATCH 35/52] profiling refactor for perf --- src/ezmsg/core/graphmeta.py | 11 - src/ezmsg/core/profiling.py | 480 +++++++++++------------------------- src/ezmsg/core/pubclient.py | 39 ++- src/ezmsg/core/subclient.py | 22 +- tests/test_profiling_api.py | 88 ++++--- 5 files changed, 214 insertions(+), 426 deletions(-) diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index c96e4d7d..7bb0973f 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -279,13 +279,9 @@ class PublisherProfileSnapshot: topic: str messages_published_total: int messages_published_window: int - publish_delta_ns_avg_window: float publish_rate_hz_window: float inflight_messages_current: int num_buffers: int - inflight_messages_peak_window: int - backpressure_wait_ns_total: int - backpressure_wait_ns_window: int timestamp: float @@ -295,13 +291,6 @@ class SubscriberProfileSnapshot: topic: str messages_received_total: int messages_received_window: int - lease_time_ns_total: int - lease_time_ns_avg_window: float - user_span_ns_total: int - user_span_ns_avg_window: float - attributable_backpressure_ns_total: int - attributable_backpressure_ns_window: int - attributable_backpressure_events_total: int channel_kind_last: ProfileChannelType timestamp: float diff --git a/src/ezmsg/core/profiling.py b/src/ezmsg/core/profiling.py index 2893c55b..9cf5b102 100644 --- a/src/ezmsg/core/profiling.py +++ b/src/ezmsg/core/profiling.py @@ -2,6 +2,7 @@ import socket import time import heapq + from collections import deque from dataclasses import dataclass, field from typing import Callable, TypeAlias @@ -18,8 +19,6 @@ ) -WINDOW_SECONDS = float(os.environ.get("EZMSG_PROFILE_WINDOW_SECONDS", "10.0")) -BUCKET_SECONDS = float(os.environ.get("EZMSG_PROFILE_BUCKET_SECONDS", "0.1")) TRACE_MAX_SAMPLES = int(os.environ.get("EZMSG_PROFILE_TRACE_MAX_SAMPLES", "10000")) # Must return monotonic nanoseconds so *_ns metrics remain unit-consistent. PROFILE_TIME_TYPE: TypeAlias = Callable[[], int] @@ -30,106 +29,15 @@ def _endpoint_id(topic: str, id: UUID) -> str: return f"{topic}:{id}" -@dataclass -class _PublisherBucket: - publish_count: int = 0 - publish_delta_sum: int = 0 - publish_delta_count: int = 0 - backpressure_wait_sum: int = 0 - inflight_peak: int = 0 - - def reset(self) -> None: - self.publish_count = 0 - self.publish_delta_sum = 0 - self.publish_delta_count = 0 - self.backpressure_wait_sum = 0 - self.inflight_peak = 0 - - -@dataclass -class _SubscriberBucket: - recv_count: int = 0 - lease_time_sum: int = 0 - user_span_sum: int = 0 - user_span_count: int = 0 - attributable_backpressure_sum: int = 0 - attributable_backpressure_count: int = 0 - - def reset(self) -> None: - self.recv_count = 0 - self.lease_time_sum = 0 - self.user_span_sum = 0 - self.user_span_count = 0 - self.attributable_backpressure_sum = 0 - self.attributable_backpressure_count = 0 - - -class _BucketWindow: - def __init__(self, bucket_type: type[_PublisherBucket | _SubscriberBucket]) -> None: - self._num_buckets = max(1, int(WINDOW_SECONDS / BUCKET_SECONDS)) - self._bucket_ns = max(1, int(BUCKET_SECONDS * 1e9)) - self._bucket_type = bucket_type - self._ticks = [-1 for _ in range(self._num_buckets)] - self._buckets = [bucket_type() for _ in range(self._num_buckets)] - self._last_tick: int | None = None - - def _bucket_tick(self, ts_ns: int) -> int: - return ts_ns // self._bucket_ns - - def _clear_bucket(self, idx: int) -> None: - self._ticks[idx] = -1 - self._buckets[idx].reset() - - def _advance(self, ts_ns: int) -> int: - tick = self._bucket_tick(ts_ns) - if self._last_tick is None: - self._last_tick = tick - return tick - if tick <= self._last_tick: - return tick - - elapsed = tick - self._last_tick - if elapsed >= self._num_buckets: - for idx in range(self._num_buckets): - self._clear_bucket(idx) - else: - previous_idx = self._last_tick % self._num_buckets - for step in range(1, elapsed + 1): - self._clear_bucket((previous_idx + step) % self._num_buckets) - - self._last_tick = tick - return tick - - def bucket(self, ts_ns: int) -> _PublisherBucket | _SubscriberBucket: - tick = self._advance(ts_ns) - idx = tick % self._num_buckets - if self._ticks[idx] != tick: - self._ticks[idx] = tick - self._buckets[idx].reset() - return self._buckets[idx] - - def buckets(self, ts_ns: int) -> list[_PublisherBucket | _SubscriberBucket]: - tick = self._advance(ts_ns) - min_tick = tick - self._num_buckets + 1 - return [ - bucket - for bucket_tick, bucket in zip(self._ticks, self._buckets) - if bucket_tick >= min_tick - ] - - @dataclass class _PublisherMetrics: topic: str endpoint_id: str num_buffers: int messages_published_total: int = 0 - backpressure_wait_ns_total: int = 0 inflight_messages_current: int = 0 + _last_snapshot_total: int = 0 _last_publish_ts_ns: int | None = None - _window: _BucketWindow = field( - default_factory=lambda: _BucketWindow(_PublisherBucket) - ) trace_enabled: bool = False trace_sample_mod: int = 1 trace_metrics: set[str] | None = None @@ -138,100 +46,81 @@ class _PublisherMetrics: default_factory=lambda: deque(maxlen=TRACE_MAX_SAMPLES) ) - def _trace_metric_enabled(self, metric: str) -> bool: - return self.trace_metrics is None or metric in self.trace_metrics + def trace_metric_enabled(self, metric: str) -> bool: + return self.trace_enabled and ( + self.trace_metrics is None or metric in self.trace_metrics + ) - def record_publish( - self, ts_ns: int, inflight: int, msg_seq: int | None = None - ) -> None: + def record_publish(self, inflight: int, msg_seq: int | None = None) -> None: self.messages_published_total += 1 - bucket = self._window.bucket(ts_ns) - assert isinstance(bucket, _PublisherBucket) - bucket.publish_count += 1 - publish_delta_ns = 0 - if self._last_publish_ts_ns is not None: - publish_delta_ns = ts_ns - self._last_publish_ts_ns - bucket.publish_delta_sum += publish_delta_ns - bucket.publish_delta_count += 1 - self._last_publish_ts_ns = ts_ns self.inflight_messages_current = inflight - if inflight > bucket.inflight_peak: - bucket.inflight_peak = inflight self._trace_counter += 1 - if ( - self.trace_enabled - and self._trace_metric_enabled("publish_delta_ns") - and (self._trace_counter % max(1, self.trace_sample_mod) == 0) - ): - self.trace_samples.append( - ProfilingTraceSample( - timestamp=float(ts_ns), - endpoint_id=self.endpoint_id, - topic=self.topic, - metric="publish_delta_ns", - value=float(publish_delta_ns), - sample_seq=msg_seq, - ) + + if not self.trace_metric_enabled("publish_delta_ns"): + return + if self._trace_counter % max(1, self.trace_sample_mod) != 0: + return + + now_ns = PROFILE_TIME() + publish_delta_ns = ( + 0 if self._last_publish_ts_ns is None else now_ns - self._last_publish_ts_ns + ) + self._last_publish_ts_ns = now_ns + self.trace_samples.append( + ProfilingTraceSample( + timestamp=float(now_ns), + endpoint_id=self.endpoint_id, + topic=self.topic, + metric="publish_delta_ns", + value=float(publish_delta_ns), + sample_seq=msg_seq, ) + ) - def record_backpressure_wait( - self, ts_ns: int, wait_ns: int, msg_seq: int | None = None - ) -> None: - self.backpressure_wait_ns_total += wait_ns - bucket = self._window.bucket(ts_ns) - assert isinstance(bucket, _PublisherBucket) - bucket.backpressure_wait_sum += wait_ns - if self.trace_enabled and self._trace_metric_enabled("backpressure_wait_ns"): - self.trace_samples.append( - ProfilingTraceSample( - timestamp=float(ts_ns), - endpoint_id=self.endpoint_id, - topic=self.topic, - metric="backpressure_wait_ns", - value=float(wait_ns), - sample_seq=msg_seq, - ) + def record_backpressure_wait(self, wait_ns: int, msg_seq: int | None = None) -> None: + if not self.trace_metric_enabled("backpressure_wait_ns"): + return + + now_ns = PROFILE_TIME() + self.trace_samples.append( + ProfilingTraceSample( + timestamp=float(now_ns), + endpoint_id=self.endpoint_id, + topic=self.topic, + metric="backpressure_wait_ns", + value=float(wait_ns), + sample_seq=msg_seq, ) + ) - def sample_inflight(self, ts_ns: int, inflight: int) -> None: + def sample_inflight(self, inflight: int) -> None: self.inflight_messages_current = inflight - bucket = self._window.bucket(ts_ns) - assert isinstance(bucket, _PublisherBucket) - if inflight > bucket.inflight_peak: - bucket.inflight_peak = inflight - def snapshot(self) -> PublisherProfileSnapshot: - now_ns = PROFILE_TIME() - buckets = self._window.buckets(now_ns) - window_msgs = 0 - publish_delta_sum = 0 - publish_delta_count = 0 - backpressure_wait_sum = 0 - inflight_peak = 0 - for bucket in buckets: - assert isinstance(bucket, _PublisherBucket) - window_msgs += bucket.publish_count - publish_delta_sum += bucket.publish_delta_sum - publish_delta_count += bucket.publish_delta_count - backpressure_wait_sum += bucket.backpressure_wait_sum - if bucket.inflight_peak > inflight_peak: - inflight_peak = bucket.inflight_peak + def snapshot( + self, + now_ns: int, + window_seconds: float, + *, + has_previous_snapshot: bool, + ) -> PublisherProfileSnapshot: + window_count = ( + self.messages_published_total - self._last_snapshot_total + if has_previous_snapshot + else 0 + ) + self._last_snapshot_total = self.messages_published_total return PublisherProfileSnapshot( endpoint_id=self.endpoint_id, topic=self.topic, messages_published_total=self.messages_published_total, - messages_published_window=window_msgs, - publish_delta_ns_avg_window=( - float(publish_delta_sum) / float(publish_delta_count) - if publish_delta_count > 0 + messages_published_window=window_count, + publish_rate_hz_window=( + float(window_count) / max(window_seconds, 1e-9) + if has_previous_snapshot and window_seconds > 0.0 else 0.0 ), - publish_rate_hz_window=float(window_msgs) / max(WINDOW_SECONDS, 1e-9), inflight_messages_current=self.inflight_messages_current, num_buffers=self.num_buffers, - inflight_messages_peak_window=inflight_peak, - backpressure_wait_ns_total=self.backpressure_wait_ns_total, - backpressure_wait_ns_window=backpressure_wait_sum, timestamp=float(now_ns), ) @@ -241,14 +130,8 @@ class _SubscriberMetrics: topic: str endpoint_id: str messages_received_total: int = 0 - lease_time_ns_total: int = 0 - user_span_ns_total: int = 0 - attributable_backpressure_ns_total: int = 0 - attributable_backpressure_events_total: int = 0 channel_kind_last: ProfileChannelType = ProfileChannelType.UNKNOWN - _window: _BucketWindow = field( - default_factory=lambda: _BucketWindow(_SubscriberBucket) - ) + _last_snapshot_total: int = 0 trace_enabled: bool = False trace_sample_mod: int = 1 trace_metrics: set[str] | None = None @@ -257,122 +140,75 @@ class _SubscriberMetrics: default_factory=lambda: deque(maxlen=TRACE_MAX_SAMPLES) ) - def _trace_metric_enabled(self, metric: str) -> bool: - return self.trace_metrics is None or metric in self.trace_metrics + def trace_metric_enabled(self, metric: str) -> bool: + return self.trace_enabled and ( + self.trace_metrics is None or metric in self.trace_metrics + ) def record_receive( self, - ts_ns: int, - lease_ns: int, channel_kind: ProfileChannelType, + lease_ns: int | None = None, msg_seq: int | None = None, ) -> None: self.messages_received_total += 1 - self.lease_time_ns_total += lease_ns self.channel_kind_last = channel_kind - bucket = self._window.bucket(ts_ns) - assert isinstance(bucket, _SubscriberBucket) - bucket.recv_count += 1 - bucket.lease_time_sum += lease_ns self._trace_counter += 1 - if ( - self.trace_enabled - and self._trace_metric_enabled("lease_time_ns") - and (self._trace_counter % max(1, self.trace_sample_mod) == 0) - ): - self.trace_samples.append( - ProfilingTraceSample( - timestamp=float(ts_ns), - endpoint_id=self.endpoint_id, - topic=self.topic, - metric="lease_time_ns", - value=float(lease_ns), - channel_kind=channel_kind, - sample_seq=msg_seq, - ) + + if lease_ns is None or not self.trace_metric_enabled("lease_time_ns"): + return + if self._trace_counter % max(1, self.trace_sample_mod) != 0: + return + + now_ns = PROFILE_TIME() + self.trace_samples.append( + ProfilingTraceSample( + timestamp=float(now_ns), + endpoint_id=self.endpoint_id, + topic=self.topic, + metric="lease_time_ns", + value=float(lease_ns), + channel_kind=channel_kind, + sample_seq=msg_seq, ) + ) def record_user_span( - self, ts_ns: int, span_ns: int, label: str | None, msg_seq: int | None = None + self, span_ns: int, label: str | None, msg_seq: int | None = None ) -> None: - self.user_span_ns_total += span_ns - bucket = self._window.bucket(ts_ns) - assert isinstance(bucket, _SubscriberBucket) - bucket.user_span_sum += span_ns - bucket.user_span_count += 1 - if self.trace_enabled and self._trace_metric_enabled("user_span_ns"): - self.trace_samples.append( - ProfilingTraceSample( - timestamp=float(ts_ns), - endpoint_id=self.endpoint_id, - topic=self.topic if label is None else f"{self.topic}:{label}", - metric="user_span_ns", - value=float(span_ns), - channel_kind=self.channel_kind_last, - sample_seq=msg_seq, - ) - ) + if not self.trace_metric_enabled("user_span_ns"): + return - def record_attributed_backpressure( - self, - ts_ns: int, - duration_ns: int, - channel_kind: ProfileChannelType, - msg_seq: int | None = None, - ) -> None: - self.attributable_backpressure_ns_total += duration_ns - self.attributable_backpressure_events_total += 1 - self.channel_kind_last = channel_kind - bucket = self._window.bucket(ts_ns) - assert isinstance(bucket, _SubscriberBucket) - bucket.attributable_backpressure_sum += duration_ns - bucket.attributable_backpressure_count += 1 - if self.trace_enabled and self._trace_metric_enabled("attributable_backpressure_ns"): - self.trace_samples.append( - ProfilingTraceSample( - timestamp=float(ts_ns), - endpoint_id=self.endpoint_id, - topic=self.topic, - metric="attributable_backpressure_ns", - value=float(duration_ns), - channel_kind=channel_kind, - sample_seq=msg_seq, - ) + now_ns = PROFILE_TIME() + self.trace_samples.append( + ProfilingTraceSample( + timestamp=float(now_ns), + endpoint_id=self.endpoint_id, + topic=self.topic if label is None else f"{self.topic}:{label}", + metric="user_span_ns", + value=float(span_ns), + channel_kind=self.channel_kind_last, + sample_seq=msg_seq, ) + ) - def snapshot(self) -> SubscriberProfileSnapshot: - now_ns = PROFILE_TIME() - buckets = self._window.buckets(now_ns) - recv_count = 0 - lease_time_sum = 0 - user_span_sum = 0 - user_count = 0 - attributable_backpressure_sum = 0 - for bucket in buckets: - assert isinstance(bucket, _SubscriberBucket) - recv_count += bucket.recv_count - lease_time_sum += bucket.lease_time_sum - user_span_sum += bucket.user_span_sum - user_count += bucket.user_span_count - attributable_backpressure_sum += bucket.attributable_backpressure_sum + def snapshot( + self, + now_ns: int, + *, + has_previous_snapshot: bool, + ) -> SubscriberProfileSnapshot: + window_count = ( + self.messages_received_total - self._last_snapshot_total + if has_previous_snapshot + else 0 + ) + self._last_snapshot_total = self.messages_received_total return SubscriberProfileSnapshot( endpoint_id=self.endpoint_id, topic=self.topic, messages_received_total=self.messages_received_total, - messages_received_window=recv_count, - lease_time_ns_total=self.lease_time_ns_total, - lease_time_ns_avg_window=( - float(lease_time_sum) / float(recv_count) if recv_count > 0 else 0.0 - ), - user_span_ns_total=self.user_span_ns_total, - user_span_ns_avg_window=( - float(user_span_sum) / float(user_count) - if user_count > 0 - else 0.0 - ), - attributable_backpressure_ns_total=self.attributable_backpressure_ns_total, - attributable_backpressure_ns_window=attributable_backpressure_sum, - attributable_backpressure_events_total=self.attributable_backpressure_events_total, + messages_received_window=window_count, channel_kind_last=self.channel_kind_last, timestamp=float(now_ns), ) @@ -387,6 +223,7 @@ def __init__(self) -> None: self._subscribers: dict[UUID, _SubscriberMetrics] = {} self._default_trace_control = ProfilingTraceControl(enabled=False) self._trace_control_expires_ns: int | None = None + self._last_snapshot_ts_ns: int | None = None def set_process_id(self, process_id: UUID, *, reset: bool = False) -> None: if reset: @@ -395,8 +232,9 @@ def set_process_id(self, process_id: UUID, *, reset: bool = False) -> None: self._default_trace_control = ProfilingTraceControl(enabled=False) self._trace_control_expires_ns = None self._process_id = process_id + self._last_snapshot_ts_ns = None - def register_publisher(self, pub_id: UUID, topic: str, num_buffers: int) -> None: + def register_publisher(self, pub_id: UUID, topic: str, num_buffers: int) -> _PublisherMetrics: metric = _PublisherMetrics( topic=topic, endpoint_id=_endpoint_id(topic, pub_id), @@ -404,96 +242,52 @@ def register_publisher(self, pub_id: UUID, topic: str, num_buffers: int) -> None ) self._publishers[pub_id] = metric self._apply_trace_control_to_publisher(metric) + return metric def unregister_publisher(self, pub_id: UUID) -> None: self._publishers.pop(pub_id, None) - def register_subscriber(self, sub_id: UUID, topic: str) -> None: + def register_subscriber(self, sub_id: UUID, topic: str) -> _SubscriberMetrics: metric = _SubscriberMetrics( topic=topic, endpoint_id=_endpoint_id(topic, sub_id), ) self._subscribers[sub_id] = metric self._apply_trace_control_to_subscriber(metric) + return metric def unregister_subscriber(self, sub_id: UUID) -> None: self._subscribers.pop(sub_id, None) - def publisher_publish( - self, pub_id: UUID, ts_ns: int, inflight: int, msg_seq: int | None = None - ) -> None: - self._expire_trace_control_if_needed(ts_ns) - metric = self._publishers.get(pub_id) - if metric is not None: - metric.record_publish(ts_ns, inflight, msg_seq) - - def publisher_backpressure_wait( - self, pub_id: UUID, ts_ns: int, wait_ns: int, msg_seq: int | None = None - ) -> None: - self._expire_trace_control_if_needed(ts_ns) - metric = self._publishers.get(pub_id) - if metric is not None: - metric.record_backpressure_wait(ts_ns, wait_ns, msg_seq) - - def publisher_sample_inflight(self, pub_id: UUID, ts_ns: int, inflight: int) -> None: - metric = self._publishers.get(pub_id) - if metric is not None: - metric.sample_inflight(ts_ns, inflight) - - def subscriber_receive( - self, - sub_id: UUID, - ts_ns: int, - lease_ns: int, - channel_kind: ProfileChannelType, - msg_seq: int | None = None, - ) -> None: - self._expire_trace_control_if_needed(ts_ns) - metric = self._subscribers.get(sub_id) - if metric is not None: - metric.record_receive(ts_ns, lease_ns, channel_kind, msg_seq) - - def subscriber_user_span( - self, - sub_id: UUID, - ts_ns: int, - span_ns: int, - label: str | None, - msg_seq: int | None = None, - ) -> None: - self._expire_trace_control_if_needed(ts_ns) - metric = self._subscribers.get(sub_id) - if metric is not None: - metric.record_user_span(ts_ns, span_ns, label, msg_seq) - - def subscriber_attributed_backpressure( - self, - sub_id: UUID, - ts_ns: int, - duration_ns: int, - channel_kind: ProfileChannelType, - msg_seq: int | None = None, - ) -> None: - self._expire_trace_control_if_needed(ts_ns) - metric = self._subscribers.get(sub_id) - if metric is not None: - metric.record_attributed_backpressure( - ts_ns, duration_ns, channel_kind, msg_seq - ) - def snapshot(self) -> ProcessProfilingSnapshot: + now_ns = PROFILE_TIME() + last_snapshot_ts_ns = self._last_snapshot_ts_ns + has_previous_snapshot = last_snapshot_ts_ns is not None + window_seconds = ( + float(now_ns - last_snapshot_ts_ns) / 1e9 + if has_previous_snapshot + else 0.0 + ) + self._last_snapshot_ts_ns = now_ns return ProcessProfilingSnapshot( process_id=self._process_id, pid=self._pid, host=self._host, - window_seconds=WINDOW_SECONDS, - timestamp=float(PROFILE_TIME()), + window_seconds=window_seconds, + timestamp=float(now_ns), publishers={ - metric.endpoint_id: metric.snapshot() + metric.endpoint_id: metric.snapshot( + now_ns, + window_seconds, + has_previous_snapshot=has_previous_snapshot, + ) for metric in self._publishers.values() }, subscribers={ - metric.endpoint_id: metric.snapshot() + metric.endpoint_id: metric.snapshot( + now_ns, + has_previous_snapshot=has_previous_snapshot, + ) for metric in self._subscribers.values() }, ) @@ -537,7 +331,6 @@ def trace_batch(self, max_samples: int = 1000) -> ProcessProfilingTraceBatch: heap: list[tuple[float, int, int]] = [] for idx, queue in enumerate(queues): sample = queue[0] - # Include sample_seq to keep deterministic ordering when timestamps tie. seq = sample.sample_seq if sample.sample_seq is not None else -1 heapq.heappush(heap, (sample.timestamp, seq, idx)) @@ -590,6 +383,8 @@ def _apply_trace_control_to_publisher(self, metric: _PublisherMetrics) -> None: metric.trace_enabled = enabled metric.trace_sample_mod = sample_mod metric.trace_metrics = trace_metrics + metric._trace_counter = 0 + metric._last_publish_ts_ns = None def _apply_trace_control_to_subscriber(self, metric: _SubscriberMetrics) -> None: control = self._default_trace_control @@ -607,6 +402,7 @@ def _apply_trace_control_to_subscriber(self, metric: _SubscriberMetrics) -> None metric.trace_enabled = enabled metric.trace_sample_mod = sample_mod metric.trace_metrics = trace_metrics + metric._trace_counter = 0 def _clear_trace_samples(self) -> None: for metric in self._publishers.values(): diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 0929e7b8..d7c89d56 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -108,6 +108,7 @@ class Publisher: _force_tcp: bool _allow_local: bool _last_backpressure_event: float + _profile: object _graph_address: AddressType | None @@ -283,7 +284,7 @@ def __init__( self._allow_local = _resolve_allow_local(self._force_tcp, allow_local) self._last_backpressure_event = -1 self._graph_address = graph_address - PROFILES.register_publisher(self.id, self.topic, self._num_buffers) + self._profile = PROFILES.register_publisher(self.id, self.topic, self._num_buffers) @property def log_name(self) -> str: @@ -424,18 +425,14 @@ async def _handle_channel( elif msg == Command.RX_ACK.value: msg_id = await read_int(reader) self._backpressure.free(info.id, msg_id % self._num_buffers) - PROFILES.publisher_sample_inflight( - self.id, PROFILE_TIME(), self._backpressure.pressure - ) + self._profile.sample_inflight(self._backpressure.pressure) except (ConnectionResetError, BrokenPipeError): logger.debug(f"Publisher {self.id}: Channel {info.id} connection fail") finally: self._backpressure.free(info.id) - PROFILES.publisher_sample_inflight( - self.id, PROFILE_TIME(), self._backpressure.pressure - ) + self._profile.sample_inflight(self._backpressure.pressure) await close_stream_writer(self._channels[info.id].writer) del self._channels[info.id] @@ -495,15 +492,15 @@ async def broadcast(self, obj: Any) -> None: if BACKPRESSURE_WARNING and (delta > BACKPRESSURE_REFRACTORY): logger.warning(f"{self.topic} under subscriber backpressure!") self._last_backpressure_event = time.time() - wait_start_ns = PROFILE_TIME() - await self._backpressure.wait(buf_idx) - wait_end_ns = PROFILE_TIME() - PROFILES.publisher_backpressure_wait( - self.id, - wait_end_ns, - wait_end_ns - wait_start_ns, - msg_seq=self._msg_id, + trace_backpressure = self._profile.trace_metric_enabled( + "backpressure_wait_ns" ) + wait_start_ns = PROFILE_TIME() if trace_backpressure else None + await self._backpressure.wait(buf_idx) + if trace_backpressure and wait_start_ns is not None: + self._profile.record_backpressure_wait( + PROFILE_TIME() - wait_start_ns, msg_seq=self._msg_id + ) if self._should_use_local_fast_path(): self._local_channel.put_local(self._msg_id, obj) @@ -563,22 +560,14 @@ async def broadcast(self, obj: Any) -> None: channel.writer.write(msg) await channel.writer.drain() self._backpressure.lease(channel.id, buf_idx) - PROFILES.publisher_sample_inflight( - self.id, PROFILE_TIME(), self._backpressure.pressure - ) + self._profile.sample_inflight(self._backpressure.pressure) except (ConnectionResetError, BrokenPipeError): logger.debug( f"Publisher {self.id}: Channel {channel.id} connection fail" ) - now_ns = PROFILE_TIME() - PROFILES.publisher_publish( - self.id, - now_ns, - self._backpressure.pressure, - msg_seq=self._msg_id, - ) + self._profile.record_publish(self._backpressure.pressure, msg_seq=self._msg_id) self._msg_id += 1 def _should_use_local_fast_path(self) -> bool: diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index 4b9f8719..694e6d4f 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -42,6 +42,7 @@ class Subscriber: _graph_address: AddressType | None _graph_task: asyncio.Task[None] _incoming: NotificationQueue + _profile: object # FIXME: This event allows Subscriber.create to block until # incoming initial connections (UPDATE) has completed. The @@ -138,7 +139,7 @@ def __init__( else: self._incoming = asyncio.Queue() self._initialized = asyncio.Event() - PROFILES.register_subscriber(self.id, self.topic) + self._profile = PROFILES.register_subscriber(self.id, self.topic) def _handle_dropped_notification( self, notification: typing.Tuple[UUID, int] @@ -304,24 +305,27 @@ async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]: channel_kind = getattr(channel, "channel_kind", ProfileChannelType.UNKNOWN) self._active_msg_seq = msg_id try: - start_ns = PROFILE_TIME() + trace_lease = self._profile.trace_metric_enabled("lease_time_ns") + start_ns = PROFILE_TIME() if trace_lease else None with channel.get(msg_id, self.id) as msg: yield msg - end_ns = PROFILE_TIME() - PROFILES.subscriber_receive( - self.id, end_ns, end_ns - start_ns, channel_kind, msg_seq=msg_id - ) + lease_ns = None + if trace_lease and start_ns is not None: + lease_ns = PROFILE_TIME() - start_ns + self._profile.record_receive(channel_kind, lease_ns, msg_seq=msg_id) finally: self._active_msg_seq = None def begin_profile(self) -> int: + if not self._profile.trace_metric_enabled("user_span_ns"): + return 0 return PROFILE_TIME() def end_profile(self, start_ns: int, label: str | None = None) -> None: + if start_ns <= 0: + return end_ns = PROFILE_TIME() - PROFILES.subscriber_user_span( - self.id, - end_ns, + self._profile.record_user_span( end_ns - start_ns, label, msg_seq=self._active_msg_seq, diff --git a/tests/test_profiling_api.py b/tests/test_profiling_api.py index 317faa8b..61efafff 100644 --- a/tests/test_profiling_api.py +++ b/tests/test_profiling_api.py @@ -2,10 +2,13 @@ import pytest +from uuid import uuid4 + from ezmsg.core import profiling as profiling_core from ezmsg.core.graphcontext import GraphContext from ezmsg.core.graphmeta import ( ProcessControlErrorCode, + ProfileChannelType, ProfilingStreamControl, ProfilingTraceControl, ) @@ -13,43 +16,47 @@ from ezmsg.core.processclient import ProcessControlClient -def test_profiling_windows_age_out_during_idle_snapshots(monkeypatch: pytest.MonkeyPatch): - publisher = profiling_core._PublisherMetrics( - topic="TOPIC_IDLE", - endpoint_id="TOPIC_IDLE:ep1", - num_buffers=4, - ) - subscriber = profiling_core._SubscriberMetrics( - topic="TOPIC_IDLE", - endpoint_id="TOPIC_IDLE:sub1", - ) - - publisher.record_publish(0, inflight=0) - publisher.record_publish(int(0.1e9), inflight=0) - subscriber.record_receive( - int(0.1e9), - lease_ns=int(0.2e6), - channel_kind=profiling_core.ProfileChannelType.LOCAL, - ) - - now_ns = {"value": int(0.2e9)} - monkeypatch.setattr(profiling_core, "PROFILE_TIME", lambda: now_ns["value"]) +def test_profiling_snapshot_uses_counter_deltas_between_snapshots( + monkeypatch: pytest.MonkeyPatch, +): + registry = profiling_core.ProfileRegistry() + publisher = registry.register_publisher(uuid4(), "TOPIC_IDLE", 4) + subscriber = registry.register_subscriber(uuid4(), "TOPIC_IDLE") - active_pub = publisher.snapshot() - active_sub = subscriber.snapshot() - assert active_pub.messages_published_window == 2 - assert active_pub.publish_rate_hz_window > 0.0 - assert active_sub.messages_received_window == 1 + now_ns = {"value": int(0.0)} + monkeypatch.setattr(profiling_core, "PROFILE_TIME", lambda: now_ns["value"]) - now_ns["value"] = int(30e9) - idle_pub = publisher.snapshot() - idle_sub = subscriber.snapshot() - assert idle_pub.messages_published_window == 0 - assert idle_pub.publish_rate_hz_window == 0.0 - assert idle_pub.backpressure_wait_ns_window == 0 - assert idle_sub.messages_received_window == 0 - assert idle_sub.attributable_backpressure_ns_window == 0 - assert idle_sub.lease_time_ns_avg_window == 0.0 + publisher.record_publish(inflight=0) + publisher.record_publish(inflight=1) + subscriber.record_receive(ProfileChannelType.LOCAL) + + first = registry.snapshot() + pub_first = next(iter(first.publishers.values())) + sub_first = next(iter(first.subscribers.values())) + assert first.window_seconds == 0.0 + assert pub_first.messages_published_total == 2 + assert pub_first.messages_published_window == 0 + assert pub_first.publish_rate_hz_window == 0.0 + assert pub_first.inflight_messages_current == 1 + assert sub_first.messages_received_total == 1 + assert sub_first.messages_received_window == 0 + assert sub_first.channel_kind_last == ProfileChannelType.LOCAL + + now_ns["value"] = int(1.0e9) + publisher.record_publish(inflight=0) + subscriber.record_receive(ProfileChannelType.SHM) + + second = registry.snapshot() + pub_second = next(iter(second.publishers.values())) + sub_second = next(iter(second.subscribers.values())) + assert second.window_seconds == pytest.approx(1.0) + assert pub_second.messages_published_total == 3 + assert pub_second.messages_published_window == 1 + assert pub_second.publish_rate_hz_window == pytest.approx(1.0) + assert pub_second.inflight_messages_current == 0 + assert sub_second.messages_received_total == 2 + assert sub_second.messages_received_window == 1 + assert sub_second.channel_kind_last == ProfileChannelType.SHM @pytest.mark.asyncio @@ -77,7 +84,7 @@ async def test_process_profiling_snapshot_collects_pub_sub_metrics(): snap = await ctx.process_profiling_snapshot("SYS/U1", timeout=1.0) assert snap.process_id == process_key - assert snap.window_seconds > 0 + assert snap.window_seconds >= 0.0 assert len(snap.publishers) >= 1 assert len(snap.subscribers) >= 1 @@ -85,14 +92,14 @@ async def test_process_profiling_snapshot_collects_pub_sub_metrics(): pub for pub in snap.publishers.values() if pub.topic == "TOPIC_PROF" ) assert pub_metrics.messages_published_total >= 8 - assert pub_metrics.publish_rate_hz_window >= 0.0 + assert pub_metrics.num_buffers > 0 + assert pub_metrics.inflight_messages_current >= 0 sub_metrics = next( sub for sub in snap.subscribers.values() if sub.topic == "TOPIC_PROF" ) assert sub_metrics.messages_received_total >= 8 - assert sub_metrics.lease_time_ns_total > 0 - assert sub_metrics.lease_time_ns_avg_window >= 0.0 + assert sub_metrics.channel_kind_last != ProfileChannelType.UNKNOWN finally: await process.close() await ctx.__aexit__(None, None, None) @@ -180,6 +187,9 @@ async def test_process_profiling_trace_control_and_batch(): ) assert batch.process_id == process_key assert len(batch.samples) > 0 + assert "publish_delta_ns" in {sample.metric for sample in batch.samples} + assert "lease_time_ns" in {sample.metric for sample in batch.samples} + assert "user_span_ns" in {sample.metric for sample in batch.samples} disable_response = await ctx.process_set_profiling_trace( "SYS/U2", From 91db6984c877e193056177c23dabba420bfe2582 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 1 Apr 2026 13:46:55 -0400 Subject: [PATCH 36/52] small enhancements to hotpath --- src/ezmsg/core/profiling.py | 38 +++++++++++++++++++++---------------- src/ezmsg/core/pubclient.py | 4 +--- src/ezmsg/core/subclient.py | 7 +++---- tests/test_subclient.py | 2 ++ 4 files changed, 28 insertions(+), 23 deletions(-) diff --git a/src/ezmsg/core/profiling.py b/src/ezmsg/core/profiling.py index 9cf5b102..4b8621ea 100644 --- a/src/ezmsg/core/profiling.py +++ b/src/ezmsg/core/profiling.py @@ -42,22 +42,19 @@ class _PublisherMetrics: trace_sample_mod: int = 1 trace_metrics: set[str] | None = None _trace_counter: int = 0 + _trace_publish_delta_enabled: bool = False + _trace_backpressure_wait_enabled: bool = False trace_samples: deque[ProfilingTraceSample] = field( default_factory=lambda: deque(maxlen=TRACE_MAX_SAMPLES) ) - def trace_metric_enabled(self, metric: str) -> bool: - return self.trace_enabled and ( - self.trace_metrics is None or metric in self.trace_metrics - ) - def record_publish(self, inflight: int, msg_seq: int | None = None) -> None: self.messages_published_total += 1 self.inflight_messages_current = inflight - self._trace_counter += 1 - if not self.trace_metric_enabled("publish_delta_ns"): + if not self._trace_publish_delta_enabled: return + self._trace_counter += 1 if self._trace_counter % max(1, self.trace_sample_mod) != 0: return @@ -78,7 +75,7 @@ def record_publish(self, inflight: int, msg_seq: int | None = None) -> None: ) def record_backpressure_wait(self, wait_ns: int, msg_seq: int | None = None) -> None: - if not self.trace_metric_enabled("backpressure_wait_ns"): + if not self._trace_backpressure_wait_enabled: return now_ns = PROFILE_TIME() @@ -136,15 +133,12 @@ class _SubscriberMetrics: trace_sample_mod: int = 1 trace_metrics: set[str] | None = None _trace_counter: int = 0 + _trace_lease_time_enabled: bool = False + _trace_user_span_enabled: bool = False trace_samples: deque[ProfilingTraceSample] = field( default_factory=lambda: deque(maxlen=TRACE_MAX_SAMPLES) ) - def trace_metric_enabled(self, metric: str) -> bool: - return self.trace_enabled and ( - self.trace_metrics is None or metric in self.trace_metrics - ) - def record_receive( self, channel_kind: ProfileChannelType, @@ -153,10 +147,10 @@ def record_receive( ) -> None: self.messages_received_total += 1 self.channel_kind_last = channel_kind - self._trace_counter += 1 - if lease_ns is None or not self.trace_metric_enabled("lease_time_ns"): + if lease_ns is None or not self._trace_lease_time_enabled: return + self._trace_counter += 1 if self._trace_counter % max(1, self.trace_sample_mod) != 0: return @@ -176,7 +170,7 @@ def record_receive( def record_user_span( self, span_ns: int, label: str | None, msg_seq: int | None = None ) -> None: - if not self.trace_metric_enabled("user_span_ns"): + if not self._trace_user_span_enabled: return now_ns = PROFILE_TIME() @@ -385,6 +379,12 @@ def _apply_trace_control_to_publisher(self, metric: _PublisherMetrics) -> None: metric.trace_metrics = trace_metrics metric._trace_counter = 0 metric._last_publish_ts_ns = None + metric._trace_publish_delta_enabled = enabled and ( + trace_metrics is None or "publish_delta_ns" in trace_metrics + ) + metric._trace_backpressure_wait_enabled = enabled and ( + trace_metrics is None or "backpressure_wait_ns" in trace_metrics + ) def _apply_trace_control_to_subscriber(self, metric: _SubscriberMetrics) -> None: control = self._default_trace_control @@ -403,6 +403,12 @@ def _apply_trace_control_to_subscriber(self, metric: _SubscriberMetrics) -> None metric.trace_sample_mod = sample_mod metric.trace_metrics = trace_metrics metric._trace_counter = 0 + metric._trace_lease_time_enabled = enabled and ( + trace_metrics is None or "lease_time_ns" in trace_metrics + ) + metric._trace_user_span_enabled = enabled and ( + trace_metrics is None or "user_span_ns" in trace_metrics + ) def _clear_trace_samples(self) -> None: for metric in self._publishers.values(): diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index d7c89d56..706c6c7a 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -492,9 +492,7 @@ async def broadcast(self, obj: Any) -> None: if BACKPRESSURE_WARNING and (delta > BACKPRESSURE_REFRACTORY): logger.warning(f"{self.topic} under subscriber backpressure!") self._last_backpressure_event = time.time() - trace_backpressure = self._profile.trace_metric_enabled( - "backpressure_wait_ns" - ) + trace_backpressure = self._profile._trace_backpressure_wait_enabled wait_start_ns = PROFILE_TIME() if trace_backpressure else None await self._backpressure.wait(buf_idx) if trace_backpressure and wait_start_ns is not None: diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index 694e6d4f..34a69799 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -10,7 +10,6 @@ from .channelmanager import CHANNELS from .messagechannel import NotificationQueue, LeakyQueue, Channel from .profiling import PROFILES, PROFILE_TIME -from .graphmeta import ProfileChannelType from .netprotocol import ( AddressType, @@ -302,10 +301,10 @@ async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]: # Stale notification from an unregistered publisher — skip. channel = self._channels[pub_id] - channel_kind = getattr(channel, "channel_kind", ProfileChannelType.UNKNOWN) + channel_kind = channel.channel_kind self._active_msg_seq = msg_id try: - trace_lease = self._profile.trace_metric_enabled("lease_time_ns") + trace_lease = self._profile._trace_lease_time_enabled start_ns = PROFILE_TIME() if trace_lease else None with channel.get(msg_id, self.id) as msg: yield msg @@ -317,7 +316,7 @@ async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]: self._active_msg_seq = None def begin_profile(self) -> int: - if not self._profile.trace_metric_enabled("user_span_ns"): + if not self._profile._trace_user_span_enabled: return 0 return PROFILE_TIME() diff --git a/tests/test_subclient.py b/tests/test_subclient.py index d4750790..20c5ea27 100644 --- a/tests/test_subclient.py +++ b/tests/test_subclient.py @@ -5,6 +5,7 @@ import pytest from ezmsg.core.subclient import Subscriber +from ezmsg.core.graphmeta import ProfileChannelType from ezmsg.core.netprotocol import Command, encode_str from ezmsg.core import channelmanager as channelmanager_module from ezmsg.core import subclient as subclient_module @@ -19,6 +20,7 @@ def __init__(self): self.waited = False self.topic = "test" self.num_buffers = 8 + self.channel_kind = ProfileChannelType.LOCAL def register_client(self, client_id, queue, local_backpressure=None): self.clients[client_id] = queue From d2cc166975a2c65e223ca6b5a390bcfcf9defe67 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 1 Apr 2026 19:46:33 -0400 Subject: [PATCH 37/52] fix metadata registration race --- src/ezmsg/core/backendprocess.py | 19 ++--- src/ezmsg/core/graphserver.py | 116 +++++++++++++++---------------- tests/test_settings_api.py | 70 +++++++++++++++++++ 3 files changed, 133 insertions(+), 72 deletions(-) diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 3212a174..64ec1bb2 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -319,7 +319,6 @@ def process(self, loop: asyncio.AbstractEventLoop) -> None: main_func = None context = GraphContext(self.graph_address) process_client = ProcessControlClient(self.graph_address) - process_register_future: concurrent.futures.Future[None] | None = None coro_callables: dict[str, Callable[[], Coroutine[Any, Any, None]]] = dict() settings_input_topics: dict[str, str] = {} current_settings: dict[str, object] = {} @@ -523,6 +522,11 @@ async def report_settings_update_cb( loop=loop, ).result() + asyncio.run_coroutine_threadsafe( + process_client.register([unit.address for unit in self.units]), + loop, + ).result() + except asyncio.CancelledError: pass @@ -535,17 +539,6 @@ async def report_settings_update_cb( logger.debug("Waiting at start barrier!") self.start_barrier.wait() - async def register_process_control() -> None: - try: - await process_client.register([unit.address for unit in self.units]) - except Exception as exc: - logger.warning(f"Process control registration failed: {exc}") - - process_register_future = asyncio.run_coroutine_threadsafe( - register_process_control(), - loop, - ) - for unit in self.units: for thread_fn in unit.threads.values(): loop.run_in_executor(None, thread_fn, unit) @@ -643,8 +636,6 @@ async def shutdown_units() -> None: loop=loop, ) with suppress(Exception): - if process_register_future is not None: - process_register_future.result(timeout=0.5) process_close_future.result() logger.debug(f"Remaining tasks in event loop = {asyncio.all_tasks(loop)}") diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index db76aefb..630cbfc2 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -9,6 +9,7 @@ from collections.abc import Sequence from contextlib import suppress from uuid import UUID, uuid1 +from dataclasses import dataclass from . import __version__ @@ -66,6 +67,13 @@ PERSISTENT_EDGE_OWNER = None +@dataclass +class _SettingsState: + value: SettingsSnapshotValue + metadata_session_id: UUID | None + source_process_id: UUID | None + + class GraphServer(threading.Thread): """ Pub-sub directed acyclic graph (DAG) server. @@ -94,9 +102,7 @@ class GraphServer(threading.Thread): _client_tasks: dict[UUID, "asyncio.Task[None]"] _command_lock: asyncio.Lock - _settings_current: dict[str, SettingsSnapshotValue] - _settings_source_session: dict[str, UUID | None] - _settings_source_process: dict[str, UUID | None] + _settings_state: dict[str, _SettingsState] _settings_events: list[SettingsChangedEvent] _settings_event_seq: int _settings_owned_by_session: dict[UUID, set[str]] @@ -125,9 +131,7 @@ def __init__(self, **kwargs) -> None: self._client_tasks = {} self.shms = {} self._address = None - self._settings_current = {} - self._settings_source_session = {} - self._settings_source_process = {} + self._settings_state = {} self._settings_events = [] self._settings_event_seq = 0 self._settings_owned_by_session = {} @@ -1130,44 +1134,32 @@ async def _handle_process_settings_update_request( if process_info is None: return Command.ERROR.value if update.component_address not in process_info.units: - metadata_owner = self._session_owner_for_component_locked( - update.component_address - ) - known_owner = self._process_owner_for_unit(update.component_address) - allow_startup_race = ( - len(process_info.units) == 0 and metadata_owner is not None - ) - if known_owner == process_client_id or allow_startup_race: - pass - else: - logger.warning( - "Process control %s settings update rejected for unowned component: %s", - process_client_id, - update.component_address, - ) - return Command.ERROR.value - else: - metadata_owner = self._session_owner_for_component_locked( - update.component_address + logger.warning( + "Process control %s settings update rejected for unowned component: %s", + process_client_id, + update.component_address, ) + return Command.ERROR.value - if metadata_owner is None: - source_session_id = self._settings_source_session.get( - update.component_address - ) - else: - source_session_id = metadata_owner + prior_state = self._settings_state.get(update.component_address) + metadata_session_id = self._session_owner_for_component_locked( + update.component_address + ) + if metadata_session_id is None and prior_state is not None: + metadata_session_id = prior_state.metadata_session_id source_process_id = self._process_key(process_client_id) - self._settings_current[update.component_address] = update.value - self._settings_source_session[update.component_address] = source_session_id - self._settings_source_process[update.component_address] = source_process_id + self._settings_state[update.component_address] = _SettingsState( + value=update.value, + metadata_session_id=metadata_session_id, + source_process_id=source_process_id, + ) self._append_settings_event_locked( event_type=SettingsEventType.SETTINGS_UPDATED, component_address=update.component_address, value=update.value, source_session_id=( - str(source_session_id) if source_session_id is not None else None + str(metadata_session_id) if metadata_session_id is not None else None ), source_process_id=source_process_id, timestamp=update.timestamp, @@ -1417,10 +1409,13 @@ def _append_topology_event_locked( def _remove_settings_for_session_locked(self, session_id: UUID) -> None: component_addresses = self._settings_owned_by_session.pop(session_id, set()) for component_address in component_addresses: - if self._settings_source_session.get(component_address) == session_id: - self._settings_current.pop(component_address, None) - self._settings_source_session.pop(component_address, None) - self._settings_source_process.pop(component_address, None) + state = self._settings_state.get(component_address) + if state is None or state.metadata_session_id != session_id: + continue + if state.source_process_id is None: + self._settings_state.pop(component_address, None) + else: + state.metadata_session_id = None def _session_owner_for_component_locked(self, component_address: str) -> UUID | None: for client_id, info in self.clients.items(): @@ -1453,34 +1448,33 @@ def _remove_settings_for_process_locked(self, process_client_id: UUID) -> None: source_process_id = self._process_key(process_client_id) component_addresses = [ component_address - for component_address, owner_process_id in self._settings_source_process.items() - if owner_process_id == source_process_id + for component_address, state in self._settings_state.items() + if state.source_process_id == source_process_id ] for component_address in component_addresses: - source_session_id = self._settings_source_session.get(component_address) - if source_session_id is None: - self._settings_current.pop(component_address, None) - self._settings_source_session.pop(component_address, None) - self._settings_source_process.pop(component_address, None) + state = self._settings_state.get(component_address) + if state is None: + continue + metadata_session_id = state.metadata_session_id + if metadata_session_id is None: + self._settings_state.pop(component_address, None) continue restored = self._initial_settings_for_component_locked( - source_session_id, component_address + metadata_session_id, component_address ) if restored is None: - self._settings_current.pop(component_address, None) - self._settings_source_session.pop(component_address, None) - self._settings_source_process.pop(component_address, None) + self._settings_state.pop(component_address, None) continue - self._settings_current[component_address] = restored - self._settings_source_process[component_address] = None + state.value = restored + state.source_process_id = None self._append_settings_event_locked( event_type=SettingsEventType.SETTINGS_UPDATED, component_address=component_address, value=restored, - source_session_id=str(source_session_id), + source_session_id=str(metadata_session_id), source_process_id=None, ) @@ -1496,9 +1490,15 @@ def _apply_session_metadata_settings_locked( structured_value=initial_repr if isinstance(initial_repr, dict) else None, settings_schema=component.settings_schema, ) - self._settings_current[component.address] = value - self._settings_source_session[component.address] = session_id - self._settings_source_process[component.address] = None + existing_state = self._settings_state.get(component.address) + if existing_state is not None and existing_state.source_process_id is not None: + existing_state.metadata_session_id = session_id + else: + self._settings_state[component.address] = _SettingsState( + value=value, + metadata_session_id=session_id, + source_process_id=None, + ) session_components.add(component.address) self._append_settings_event_locked( event_type=SettingsEventType.INITIAL_SETTINGS, @@ -1591,8 +1591,8 @@ async def _handle_session_settings_snapshot_request( ) -> None: async with self._command_lock: snapshot = { - component_address: self._settings_current[component_address] - for component_address in sorted(self._settings_current) + component_address: self._settings_state[component_address].value + for component_address in sorted(self._settings_state) } snapshot_bytes = pickle.dumps(snapshot) writer.write(uint64_to_bytes(len(snapshot_bytes))) diff --git a/tests/test_settings_api.py b/tests/test_settings_api.py index effbe08a..2eb273dc 100644 --- a/tests/test_settings_api.py +++ b/tests/test_settings_api.py @@ -442,6 +442,46 @@ async def test_process_disconnect_restores_metadata_initial_settings(): graph_server.stop() +@pytest.mark.asyncio +async def test_session_drop_preserves_live_process_owned_settings(): + graph_server = GraphService().create_server() + address = graph_server.address + + owner = GraphContext(address, auto_start=False) + observer = GraphContext(address, auto_start=False) + await owner.__aenter__() + await observer.__aenter__() + await owner.register_metadata(_metadata_with_component("SYS/UNIT_LIVE")) + + process = ProcessControlClient(address) + await process.connect() + await process.register(["SYS/UNIT_LIVE"]) + + try: + await process.report_settings_update( + component_address="SYS/UNIT_LIVE", + value=SettingsSnapshotValue(serialized=None, repr_value={"alpha": 9}), + ) + settings = await observer.settings_snapshot() + assert settings["SYS/UNIT_LIVE"].repr_value == {"alpha": 9} + + await owner._close_session() + await asyncio.sleep(0.05) + + settings = await observer.settings_snapshot() + assert settings["SYS/UNIT_LIVE"].repr_value == {"alpha": 9} + finally: + await process.close() + + await asyncio.sleep(0.05) + settings = await observer.settings_snapshot() + assert "SYS/UNIT_LIVE" not in settings + + await owner.__aexit__(None, None, None) + await observer.__aexit__(None, None, None) + graph_server.stop() + + @pytest.mark.asyncio async def test_process_settings_update_rejected_for_unowned_component(): graph_server = GraphService().create_server() @@ -466,3 +506,33 @@ async def test_process_settings_update_rejected_for_unowned_component(): await process.close() await observer.__aexit__(None, None, None) graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_settings_update_requires_completed_process_registration(): + graph_server = GraphService().create_server() + address = graph_server.address + + owner = GraphContext(address, auto_start=False) + observer = GraphContext(address, auto_start=False) + await owner.__aenter__() + await observer.__aenter__() + await owner.register_metadata(_metadata_with_component("SYS/UNIT_PENDING")) + + process = ProcessControlClient(address) + await process.connect() + + try: + with pytest.raises(RuntimeError, match="Process control command failed"): + await process.report_settings_update( + component_address="SYS/UNIT_PENDING", + value=SettingsSnapshotValue(serialized=None, repr_value={"alpha": 7}), + ) + + settings = await observer.settings_snapshot() + assert settings["SYS/UNIT_PENDING"].repr_value == {"alpha": 1} + finally: + await process.close() + await owner.__aexit__(None, None, None) + await observer.__aexit__(None, None, None) + graph_server.stop() From e25d249e198a219a41d52b02cb5e2d9045d82982 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 1 Apr 2026 20:11:02 -0400 Subject: [PATCH 38/52] fixed shutdown buffer error --- src/ezmsg/core/backendprocess.py | 51 ++++++++++++++--------------- src/ezmsg/core/graphserver.py | 56 ++++++++++++++++++++++++-------- src/ezmsg/core/profiling.py | 41 ++++++++++++++++++----- src/ezmsg/core/subclient.py | 13 ++++++-- src/ezmsg/util/perf/impl.py | 3 +- tests/test_perf_configs.py | 32 +++++++++++++++++- tests/test_profiling_api.py | 35 ++++++++++++++++++++ 7 files changed, 179 insertions(+), 52 deletions(-) diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 64ec1bb2..0216cab9 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -12,14 +12,22 @@ from abc import abstractmethod from dataclasses import dataclass, fields as dataclass_fields, is_dataclass, replace from collections import defaultdict -from collections.abc import Awaitable, Callable, Coroutine, Generator, Mapping, Sequence +from collections.abc import ( + AsyncGenerator, + Awaitable, + Callable, + Coroutine, + Generator, + Mapping, + Sequence, +) from functools import wraps, partial from concurrent.futures import ThreadPoolExecutor from concurrent.futures.thread import _worker from multiprocessing import Process from multiprocessing.synchronize import Event as EventType from multiprocessing.synchronize import Barrier as BarrierType -from contextlib import suppress, contextmanager +from contextlib import suppress, contextmanager, asynccontextmanager from concurrent.futures import TimeoutError from typing import Any @@ -751,14 +759,26 @@ async def handle_subscriber( # Non-leaky subscribers use recv_zero_copy() to hold backpressure during # processing, which provides zero-copy performance but applies backpressure. + @asynccontextmanager + async def next_message() -> AsyncGenerator[Any, None]: + if sub.leaky: + msg = await sub.recv() + try: + yield msg + finally: + del msg + return + + async with sub.recv_zero_copy() as msg: + yield msg + while True: if not callables: sub.close() await sub.wait_closed() break - if sub.leaky: - msg = await sub.recv() + async with next_message() as msg: try: if on_message is not None: try: @@ -780,29 +800,6 @@ async def handle_subscriber( callables.remove(callable) finally: del msg - else: - async with sub.recv_zero_copy() as msg: - try: - if on_message is not None: - try: - await on_message(msg) - except Exception as exc: - logger.warning( - f"Failed to report subscriber message metadata: {exc}" - ) - for callable in list(callables): - try: - span_start_ns = sub.begin_profile() - try: - await callable(msg) - finally: - sub.end_profile( - span_start_ns, getattr(callable, "__name__", None) - ) - except (Complete, NormalTermination): - callables.remove(callable) - finally: - del msg if len(callables) > 1: await asyncio.sleep(0) diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 630cbfc2..e8d96b1d 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -65,6 +65,9 @@ logger = logging.getLogger("ezmsg") PERSISTENT_EDGE_OWNER = None +SUBSCRIBER_UPDATE_TIMEOUT_SEC = float( + os.environ.get("EZMSG_SUBSCRIBER_UPDATE_TIMEOUT_SEC", "1.0") +) @dataclass @@ -74,6 +77,19 @@ class _SettingsState: source_process_id: UUID | None +@dataclass(frozen=True) +class _RetentionPolicy: + profiling_trace_buffer_limit: int = int( + os.environ.get("EZMSG_PROFILE_TRACE_BUFFER_LIMIT", "200000") + ) + settings_event_history_limit: int = int( + os.environ.get("EZMSG_SETTINGS_EVENT_HISTORY_LIMIT", "10000") + ) + topology_event_history_limit: int = int( + os.environ.get("EZMSG_TOPOLOGY_EVENT_HISTORY_LIMIT", "10000") + ) + + class GraphServer(threading.Thread): """ Pub-sub directed acyclic graph (DAG) server. @@ -116,6 +132,7 @@ class GraphServer(threading.Thread): _profiling_trace_buffers: dict[UUID, deque[tuple[int, ProfilingTraceSample]]] _profiling_trace_process_meta: dict[UUID, tuple[int, str]] _profiling_trace_seq: dict[UUID, int] + _retention_policy: _RetentionPolicy def __init__(self, **kwargs) -> None: super().__init__( @@ -143,6 +160,7 @@ def __init__(self, **kwargs) -> None: self._profiling_trace_buffers = {} self._profiling_trace_process_meta = {} self._profiling_trace_seq = {} + self._retention_policy = _RetentionPolicy() @property def address(self) -> Address: @@ -1014,7 +1032,8 @@ async def _handle_process_profiling_trace_update_request( async with self._command_lock: process_id = self._process_key(process_client_id) trace_buffer = self._profiling_trace_buffers.setdefault( - process_id, deque(maxlen=200_000) + process_id, + deque(maxlen=self._retention_policy.profiling_trace_buffer_limit), ) next_seq = self._profiling_trace_seq.get(process_id, 0) for sample in batch.samples: @@ -1376,7 +1395,7 @@ def _append_settings_event_locked( self._queue_stream_event(queue, event) # Bound memory growth for long-lived servers. - max_events = 10_000 + max_events = self._retention_policy.settings_event_history_limit if len(self._settings_events) > max_events: del self._settings_events[0 : len(self._settings_events) - max_events] @@ -1402,7 +1421,7 @@ def _append_topology_event_locked( for queue in self._topology_subscribers.values(): self._queue_stream_event(queue, event) - max_events = 10_000 + max_events = self._retention_policy.topology_event_history_limit if len(self._topology_events) > max_events: del self._topology_events[0 : len(self._topology_events) - max_events] @@ -1518,36 +1537,40 @@ async def _handle_session_edge_request( ) -> bytes: from_topic = await read_str(reader) to_topic = await read_str(reader) + should_notify = False async with self._command_lock: try: if req == Command.SESSION_CONNECT.value: - topology_changed = self._connect_owner( + should_notify = self._connect_owner( from_topic, to_topic, session_id ) else: - topology_changed = self._disconnect_owner( + should_notify = self._disconnect_owner( from_topic, to_topic, session_id ) except CyclicException: return Command.CYCLIC.value - if topology_changed: + if should_notify: self._append_topology_event_locked( event_type=TopologyEventType.GRAPH_CHANGED, changed_topics=[to_topic], source_session_id=str(session_id), source_process_id=None, ) - await self._notify_downstream_for_topic(to_topic) + + if should_notify: + await self._notify_downstream_for_topic(to_topic) return Command.COMPLETE.value async def _handle_session_clear_request(self, session_id: UUID) -> bytes: async with self._command_lock: notify_topics = self._clear_session_state(session_id) - for topic in notify_topics: - await self._notify_downstream_for_topic(topic) + + for topic in notify_topics: + await self._notify_downstream_for_topic(topic) return Command.COMPLETE.value async def _handle_session_register_request( @@ -1716,13 +1739,20 @@ async def _notify_subscriber(self, sub: SubscriberInfo) -> None: # Update requires us to read a 'COMPLETE' # This cannot be done from this context - async with sub.sync_writer() as writer: - notify_str = ",".join(pub_ids) - writer.write(Command.UPDATE.value) - writer.write(encode_str(notify_str)) + async with asyncio.timeout(SUBSCRIBER_UPDATE_TIMEOUT_SEC): + async with sub.sync_writer() as writer: + notify_str = ",".join(pub_ids) + writer.write(Command.UPDATE.value) + writer.write(encode_str(notify_str)) except (ConnectionResetError, BrokenPipeError) as e: logger.debug(f"Failed to update Subscriber {sub.id}: {e}") + except TimeoutError: + logger.warning( + "Timed out waiting for Subscriber %s to apply routing update for topic %s", + sub.id, + sub.topic, + ) def _publishers(self) -> list[PublisherInfo]: return [ diff --git a/src/ezmsg/core/profiling.py b/src/ezmsg/core/profiling.py index 4b8621ea..27b9cdf8 100644 --- a/src/ezmsg/core/profiling.py +++ b/src/ezmsg/core/profiling.py @@ -139,19 +139,39 @@ class _SubscriberMetrics: default_factory=lambda: deque(maxlen=TRACE_MAX_SAMPLES) ) + def begin_message(self, channel_kind: ProfileChannelType) -> bool: + self.messages_received_total += 1 + self.channel_kind_last = channel_kind + + if not (self._trace_lease_time_enabled or self._trace_user_span_enabled): + return False + + self._trace_counter += 1 + return self._trace_counter % max(1, self.trace_sample_mod) == 0 + def record_receive( self, channel_kind: ProfileChannelType, lease_ns: int | None = None, msg_seq: int | None = None, ) -> None: - self.messages_received_total += 1 - self.channel_kind_last = channel_kind + sampled = self.begin_message(channel_kind) + self.record_lease_time( + channel_kind, + lease_ns, + msg_seq=msg_seq, + sampled=sampled, + ) - if lease_ns is None or not self._trace_lease_time_enabled: - return - self._trace_counter += 1 - if self._trace_counter % max(1, self.trace_sample_mod) != 0: + def record_lease_time( + self, + channel_kind: ProfileChannelType, + lease_ns: int | None, + msg_seq: int | None = None, + *, + sampled: bool, + ) -> None: + if lease_ns is None or not self._trace_lease_time_enabled or not sampled: return now_ns = PROFILE_TIME() @@ -168,9 +188,14 @@ def record_receive( ) def record_user_span( - self, span_ns: int, label: str | None, msg_seq: int | None = None + self, + span_ns: int, + label: str | None, + msg_seq: int | None = None, + *, + sampled: bool, ) -> None: - if not self._trace_user_span_enabled: + if not self._trace_user_span_enabled or not sampled: return now_ns = PROFILE_TIME() diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index 34a69799..a3ad3246 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -131,6 +131,7 @@ def __init__( self._channels = dict() self._active_msg_seq: int | None = None + self._active_trace_sampled = False if self.leaky: self._incoming = LeakyQueue( 1 if max_queue is None else max_queue, self._handle_dropped_notification @@ -303,6 +304,7 @@ async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]: channel = self._channels[pub_id] channel_kind = channel.channel_kind self._active_msg_seq = msg_id + self._active_trace_sampled = self._profile.begin_message(channel_kind) try: trace_lease = self._profile._trace_lease_time_enabled start_ns = PROFILE_TIME() if trace_lease else None @@ -311,12 +313,18 @@ async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]: lease_ns = None if trace_lease and start_ns is not None: lease_ns = PROFILE_TIME() - start_ns - self._profile.record_receive(channel_kind, lease_ns, msg_seq=msg_id) + self._profile.record_lease_time( + channel_kind, + lease_ns, + msg_seq=msg_id, + sampled=self._active_trace_sampled, + ) finally: self._active_msg_seq = None + self._active_trace_sampled = False def begin_profile(self) -> int: - if not self._profile._trace_user_span_enabled: + if not self._profile._trace_user_span_enabled or not self._active_trace_sampled: return 0 return PROFILE_TIME() @@ -328,4 +336,5 @@ def end_profile(self, start_ns: int, label: str | None = None) -> None: end_ns - start_ns, label, msg_seq=self._active_msg_seq, + sampled=self._active_trace_sampled, ) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 6bc83d1d..55e2c30b 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -209,7 +209,8 @@ def fanin(config: ConfigSettings) -> Configuration: """many pubs to one sub""" connections: ez.NetworkDefinition = [(config.source.OUTPUT, config.sink.INPUT)] pubs = [LoadTestSource(config.settings) for _ in range(config.n_clients)] - expected_num_msgs = config.sink.SETTINGS.num_msgs * len(pubs) + total_publishers = 1 + len(pubs) + expected_num_msgs = config.sink.SETTINGS.num_msgs * total_publishers config.sink.SETTINGS = replace(config.sink.SETTINGS, num_msgs=expected_num_msgs) # type: ignore for pub in pubs: connections.append((pub.OUTPUT, config.sink.INPUT)) diff --git a/tests/test_perf_configs.py b/tests/test_perf_configs.py index fc7178aa..239a8868 100644 --- a/tests/test_perf_configs.py +++ b/tests/test_perf_configs.py @@ -1,4 +1,5 @@ import contextlib +from dataclasses import replace import os import tempfile from pathlib import Path @@ -6,7 +7,16 @@ import pytest from ezmsg.core.graphserver import GraphServer -from ezmsg.util.perf.impl import Communication, CONFIGS, perform_test +from ezmsg.util.perf.impl import ( + Communication, + CONFIGS, + ConfigSettings, + LoadTestSettings, + LoadTestSink, + LoadTestSource, + fanin, + perform_test, +) PERF_MAX_DURATION = 0.5 @@ -93,6 +103,26 @@ def test_fanin_perf(perf_graph_server, comm, msg_size): _run_perf_case("fanin", comm, msg_size, perf_graph_server) +def test_fanin_config_counts_all_publishers(): + settings = LoadTestSettings( + max_duration=1.0, + num_msgs=8, + dynamic_size=64, + buffers=2, + force_tcp=False, + ) + source = LoadTestSource(settings) + sink = LoadTestSink(settings) + + clients, connections = fanin( + ConfigSettings(n_clients=2, settings=settings, source=source, sink=sink) + ) + + assert len(clients) == 2 + assert len(connections) == 3 + assert sink.SETTINGS == replace(settings, num_msgs=24) + + @pytest.mark.parametrize("msg_size", PERF_MSG_SIZES, ids=lambda s: f"msg={s}") @pytest.mark.parametrize("comm", list(Communication), ids=lambda c: f"comm={c.value}") def test_relay_perf(perf_graph_server, comm, msg_size): diff --git a/tests/test_profiling_api.py b/tests/test_profiling_api.py index 61efafff..ffdd281d 100644 --- a/tests/test_profiling_api.py +++ b/tests/test_profiling_api.py @@ -59,6 +59,41 @@ def test_profiling_snapshot_uses_counter_deltas_between_snapshots( assert sub_second.channel_kind_last == ProfileChannelType.SHM +def test_subscriber_trace_sampling_uses_one_decision_per_message(): + registry = profiling_core.ProfileRegistry() + subscriber = registry.register_subscriber(uuid4(), "TOPIC_SAMPLE") + registry.set_trace_control( + ProfilingTraceControl( + enabled=True, + sample_mod=3, + subscriber_topics=["TOPIC_SAMPLE"], + metrics=["lease_time_ns", "user_span_ns"], + ) + ) + + for msg_seq in range(7): + sampled = subscriber.begin_message(ProfileChannelType.LOCAL) + subscriber.record_lease_time( + ProfileChannelType.LOCAL, + 100 + msg_seq, + msg_seq=msg_seq, + sampled=sampled, + ) + subscriber.record_user_span( + 200 + msg_seq, + "taskA", + msg_seq=msg_seq, + sampled=sampled, + ) + + batch = registry.trace_batch(max_samples=100) + lease_samples = [sample for sample in batch.samples if sample.metric == "lease_time_ns"] + user_span_samples = [sample for sample in batch.samples if sample.metric == "user_span_ns"] + + assert [sample.sample_seq for sample in lease_samples] == [2, 5] + assert [sample.sample_seq for sample in user_span_samples] == [2, 5] + + @pytest.mark.asyncio async def test_process_profiling_snapshot_collects_pub_sub_metrics(): graph_server = GraphService().create_server() From 0b97f75cda853219b82e2c3e6e4ee624c697aeb7 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 1 Apr 2026 21:28:49 -0400 Subject: [PATCH 39/52] Fix Python 3.10 subscriber update timeout --- src/ezmsg/core/graphserver.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index e8d96b1d..aa04d2b7 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -1739,15 +1739,18 @@ async def _notify_subscriber(self, sub: SubscriberInfo) -> None: # Update requires us to read a 'COMPLETE' # This cannot be done from this context - async with asyncio.timeout(SUBSCRIBER_UPDATE_TIMEOUT_SEC): + async def send_update() -> None: async with sub.sync_writer() as writer: notify_str = ",".join(pub_ids) writer.write(Command.UPDATE.value) writer.write(encode_str(notify_str)) + await asyncio.wait_for( + send_update(), timeout=SUBSCRIBER_UPDATE_TIMEOUT_SEC + ) except (ConnectionResetError, BrokenPipeError) as e: logger.debug(f"Failed to update Subscriber {sub.id}: {e}") - except TimeoutError: + except asyncio.TimeoutError: logger.warning( "Timed out waiting for Subscriber %s to apply routing update for topic %s", sub.id, From 71071b36972ec0881539088906567ee30d3173ab Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 2 Apr 2026 09:23:47 -0400 Subject: [PATCH 40/52] maybe fix windows-only(?) test failure --- tests/test_profiling_api.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_profiling_api.py b/tests/test_profiling_api.py index ffdd281d..5c7d2859 100644 --- a/tests/test_profiling_api.py +++ b/tests/test_profiling_api.py @@ -191,6 +191,9 @@ async def test_process_profiling_trace_control_and_batch(): assert process.client_id is not None process_key = process.client_id await process.register(["SYS/U2"]) + # Keep direct batch assertions deterministic by preventing the background + # trace-push loop from draining the local trace buffer during this test. + process._trace_push_interval_s = 60.0 pub = await ctx.publisher("TOPIC_TRACE") sub = await ctx.subscriber("TOPIC_TRACE") @@ -339,6 +342,9 @@ async def test_process_profiling_trace_control_endpoint_metric_and_ttl(): process = ProcessControlClient(address) await process.connect() await process.register(["SYS/U5"]) + # This test reads batches directly from the process, so disable the + # automatic push loop long enough that it cannot race and drain samples. + process._trace_push_interval_s = 60.0 pub_a = await ctx.publisher("TOPIC_A") sub_a = await ctx.subscriber("TOPIC_A") @@ -526,6 +532,9 @@ async def test_process_profiling_trace_batch_interleaves_publisher_and_subscribe process = ProcessControlClient(address) await process.connect() await process.register(["SYS/U8"]) + # Avoid races with the automatic graph-server push path when reading the + # process-local trace batch directly in this test. + process._trace_push_interval_s = 60.0 pub = await ctx.publisher("TOPIC_TRACE_MIX") sub = await ctx.subscriber("TOPIC_TRACE_MIX") @@ -572,6 +581,9 @@ async def test_process_profiling_trace_control_change_clears_stale_trace_samples process = ProcessControlClient(address) await process.connect() await process.register(["SYS/U9"]) + # This test validates local batch contents after a control change, so keep + # the background push loop from draining new samples before the assertion. + process._trace_push_interval_s = 60.0 pub_old = await ctx.publisher("TOPIC_TRACE_OLD") sub_old = await ctx.subscriber("TOPIC_TRACE_OLD") From b70dbde36d58c7667ab7f2c69b0bbed22421ce87 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 6 Apr 2026 13:44:39 -0400 Subject: [PATCH 41/52] command line TLC --- pyproject.toml | 1 - src/ezmsg/core/command.py | 233 ++++++---------------------- src/ezmsg/core/commands/__init__.py | 15 ++ src/ezmsg/core/commands/common.py | 25 +++ src/ezmsg/core/commands/graphviz.py | 19 +++ src/ezmsg/core/commands/mermaid.py | 77 +++++++++ src/ezmsg/core/commands/serve.py | 30 ++++ src/ezmsg/core/commands/shutdown.py | 26 ++++ src/ezmsg/core/commands/start.py | 36 +++++ src/ezmsg/util/perf/command.py | 27 +++- tests/test_command.py | 68 ++++++++ 11 files changed, 363 insertions(+), 194 deletions(-) create mode 100644 src/ezmsg/core/commands/__init__.py create mode 100644 src/ezmsg/core/commands/common.py create mode 100644 src/ezmsg/core/commands/graphviz.py create mode 100644 src/ezmsg/core/commands/mermaid.py create mode 100644 src/ezmsg/core/commands/serve.py create mode 100644 src/ezmsg/core/commands/shutdown.py create mode 100644 src/ezmsg/core/commands/start.py create mode 100644 tests/test_command.py diff --git a/pyproject.toml b/pyproject.toml index 754f647d..6d796f65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,6 @@ axisarray = [ [project.scripts] ezmsg = "ezmsg.core.command:cmdline" -ezmsg-perf = "ezmsg.util.perf.command:command" [project.optional-dependencies] axisarray = [ diff --git a/src/ezmsg/core/command.py b/src/ezmsg/core/command.py index 09ce2dc3..61e76f97 100644 --- a/src/ezmsg/core/command.py +++ b/src/ezmsg/core/command.py @@ -1,33 +1,30 @@ import argparse import asyncio -import base64 -import json -import logging -import subprocess -import sys -import webbrowser -import zlib +import inspect -from .graphserver import GraphService +from ezmsg.util.perf.command import setup_perf_cmdline + +from .commands import setup_core_cmdline +from .commands.graphviz import handle_graphviz +from .commands.mermaid import handle_mermaid, mermaid_url as mm +from .commands.serve import handle_serve +from .commands.shutdown import handle_shutdown +from .commands.start import handle_start from .netprotocol import ( Address, GRAPHSERVER_ADDR_ENV, GRAPHSERVER_PORT_DEFAULT, PUBLISHER_START_PORT_ENV, PUBLISHER_START_PORT_DEFAULT, - close_stream_writer, ) -logger = logging.getLogger("ezmsg") - -def cmdline() -> None: +def build_parser() -> argparse.ArgumentParser: """ - Command-line interface for ezmsg core server management. + Build the ezmsg core command-line parser. - Provides commands for starting, stopping, and managing ezmsg server - processes including GraphServer and SHMServer, as well as utilities - for graph visualization. + Each command gets its own subparser so command-specific options are not + shared globally across unrelated commands. """ parser = argparse.ArgumentParser( "ezmsg.core", @@ -38,63 +35,27 @@ def cmdline() -> None: Publishers will be assigned available ports starting from {PUBLISHER_START_PORT_DEFAULT}. (Change with ${PUBLISHER_START_PORT_ENV}) """, ) + subparsers = parser.add_subparsers(dest="command", required=True, help="command for ezmsg") - parser.add_argument( - "command", - help="command for ezmsg", - choices=["serve", "start", "shutdown", "graphviz", "mermaid"], - ) - - parser.add_argument("--address", help="Address for GraphServer", default=None) - - parser.add_argument( - "--target", - help="Target for mermaid output. Options are 'ink', 'live', and 'play'.", - default="live", - ) - - parser.add_argument( - "-c", - "--compact", - help="""Use compact graph representation. Only used when `cmd` is 'mermaid' or 'graphviz'. - Removes the lowest level of detail (typically streams). Can be stacked (eg. '-cc'). - Warning: this will also prune the graph of proxy topics (nodes that are both sources and targets). - """, - action="count", - ) - - parser.add_argument( - "-n", - "--nobrowser", - help="Do not automatically open the browser for mermaid output. `--target` value will be ignored.", - action="store_true", - ) - - class Args: - command: str - address: str | None - target: str - compact: int | None - nobrowser: bool + setup_core_cmdline(subparsers) + setup_perf_cmdline(subparsers) + return parser - args = parser.parse_args(namespace=Args) - graph_address = Address("127.0.0.1", GRAPHSERVER_PORT_DEFAULT) - if args.address is not None: - graph_address = Address.from_string(args.address) +def cmdline(argv: list[str] | None = None) -> None: + """ + Command-line interface for ezmsg core server management. - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + Provides commands for starting, stopping, and managing ezmsg server + processes including GraphServer and SHMServer, as well as utilities + for graph visualization. + """ + parser = build_parser() + args = parser.parse_args(args=argv) - loop.run_until_complete( - run_command( - args.command, - graph_address, - args.target, - args.compact, - args.nobrowser, - ) - ) + result = args._handler(args) + if inspect.isawaitable(result): + asyncio.run(result) async def run_command( @@ -104,122 +65,20 @@ async def run_command( compact: int | None = None, nobrowser: bool = False, ) -> None: - """ - Run an ezmsg command with the specified parameters. - - This function handles various ezmsg commands like 'serve', 'start', 'shutdown', etc. - and manages the graph and shared memory services. - - :param cmd: The command to execute ('serve', 'start', 'shutdown', 'graphviz', 'mermaid') - :type cmd: str - :param graph_address: Address of the graph service - :type graph_address: Address - :param target: Target for visualization commands (default: 'live') - :type target: str - :param compact: Compactification level for visualization commands - :type compact: int | None - :param nobrowser: Whether to suppress browser opening for visualization - :type nobrowser: bool - """ - graph_service = GraphService(graph_address) - - if cmd == "serve": - logger.info(f"GraphServer Address: {graph_address}") - - graph_server = graph_service.create_server() - - try: - logger.info("Servers running...") - graph_server.join() - - except KeyboardInterrupt: - logger.info("Interrupt detected; shutting down servers") - - finally: - if graph_server is not None: - graph_server.stop() - - elif cmd == "start": - popen = subprocess.Popen( - [sys.executable, "-m", "ezmsg.core", "serve", f"--address={graph_address}"] - ) - - while True: - try: - _, writer = await graph_service.open_connection() - await close_stream_writer(writer) - break - except ConnectionRefusedError: - await asyncio.sleep(0.1) - - logger.info(f"Forked ezmsg servers in PID: {popen.pid}") - - elif cmd == "shutdown": - try: - await graph_service.shutdown() - logger.info( - f"Issued shutdown command to GraphServer @ {graph_service.address}" - ) - - except ConnectionRefusedError: - logger.warning( - f"Could not issue shutdown command to GraphServer @ {graph_service.address}; server not running?" - ) - - elif cmd in ["graphviz", "mermaid"]: - graph_out = await graph_service.get_formatted_graph( - fmt=cmd, compact_level=compact - ) - print(graph_out) - if cmd == "mermaid": - if not nobrowser: - if target == "live": - print( - "%% If the graph does not render immediately, try toggling the 'Pan & Zoom' button." - ) - webbrowser.open(mm(graph_out, target=target)) - - -def mm(graph: str, target="live") -> str: - """ - Generate a Mermaid visualization URL for the given graph. - - :param graph: Graph representation string to visualize. - :type graph: str - :param target: Target platform ('live' or 'ink'). - :type target: str - :return: URL for graph visualization. - :rtype: str - """ - if target != "ink": - jdict = { - "code": graph, - "mermaid": {"theme": "default"}, - "updateDiagram": True, - "autoSync": True, - "rough": False, - } - graph = json.dumps(jdict) - graphbytes: bytes = graph.encode("utf8") - - if target != "ink": - compress = zlib.compressobj(9, zlib.DEFLATED, 15, 8, zlib.Z_DEFAULT_STRATEGY) - graphbytes = compress.compress(graphbytes) - graphbytes += compress.flush() - - base64_bytes = base64.b64encode(graphbytes) - base64_string = base64_bytes.decode("ascii") - - if target == "ink": - prefix = "https://mermaid.ink/img/" - elif target in ["live", "play"]: - type_str = "pako" # or "base64" if we skip compression above. - if target == "live": - prefix = f"https://mermaid.live/edit#{type_str}:" - else: # "play" - prefix = f"https://www.mermaidchart.com/play#{type_str}:" - else: - raise ValueError( - f"Unknown mermaid target '{target}'. Available options are 'ink', 'live', or 'play'." - ) - return prefix + base64_string + handlers = { + "serve": handle_serve, + "start": handle_start, + "shutdown": handle_shutdown, + "graphviz": handle_graphviz, + "mermaid": handle_mermaid, + } + if cmd not in handlers: + raise ValueError(f"Unknown ezmsg command '{cmd}'") + args = argparse.Namespace( + command=cmd, + address=str(graph_address), + target=target, + compact=compact, + nobrowser=nobrowser, + ) + await handlers[cmd](args) diff --git a/src/ezmsg/core/commands/__init__.py b/src/ezmsg/core/commands/__init__.py new file mode 100644 index 00000000..7177e520 --- /dev/null +++ b/src/ezmsg/core/commands/__init__.py @@ -0,0 +1,15 @@ +import argparse + +from .graphviz import setup_graphviz_cmdline +from .mermaid import setup_mermaid_cmdline +from .serve import setup_serve_cmdline +from .shutdown import setup_shutdown_cmdline +from .start import setup_start_cmdline + + +def setup_core_cmdline(subparsers: argparse._SubParsersAction) -> None: + setup_serve_cmdline(subparsers) + setup_start_cmdline(subparsers) + setup_shutdown_cmdline(subparsers) + setup_graphviz_cmdline(subparsers) + setup_mermaid_cmdline(subparsers) diff --git a/src/ezmsg/core/commands/common.py b/src/ezmsg/core/commands/common.py new file mode 100644 index 00000000..2b8eaf51 --- /dev/null +++ b/src/ezmsg/core/commands/common.py @@ -0,0 +1,25 @@ +import argparse + +from ..netprotocol import Address, GRAPHSERVER_PORT_DEFAULT + + +def add_address_argument(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--address", help="Address for GraphServer", default=None) + + +def add_compact_argument(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "-c", + "--compact", + help="""Use compact graph representation. + Removes the lowest level of detail (typically streams). Can be stacked (eg. '-cc'). + Warning: this will also prune the graph of proxy topics (nodes that are both sources and targets). + """, + action="count", + ) + + +def graph_address_from_args(args: argparse.Namespace) -> Address: + if args.address is None: + return Address("127.0.0.1", GRAPHSERVER_PORT_DEFAULT) + return Address.from_string(args.address) diff --git a/src/ezmsg/core/commands/graphviz.py b/src/ezmsg/core/commands/graphviz.py new file mode 100644 index 00000000..2c0c2e07 --- /dev/null +++ b/src/ezmsg/core/commands/graphviz.py @@ -0,0 +1,19 @@ +import argparse + +from ..graphserver import GraphService +from .common import add_address_argument, add_compact_argument, graph_address_from_args + + +async def handle_graphviz(args: argparse.Namespace) -> None: + graph_service = GraphService(graph_address_from_args(args)) + graph_out = await graph_service.get_formatted_graph( + fmt="graphviz", compact_level=args.compact + ) + print(graph_out) + + +def setup_graphviz_cmdline(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser("graphviz") + add_address_argument(parser) + add_compact_argument(parser) + parser.set_defaults(_handler=handle_graphviz) diff --git a/src/ezmsg/core/commands/mermaid.py b/src/ezmsg/core/commands/mermaid.py new file mode 100644 index 00000000..dd59ad76 --- /dev/null +++ b/src/ezmsg/core/commands/mermaid.py @@ -0,0 +1,77 @@ +import argparse +import base64 +import json +import webbrowser +import zlib + +from ..graphserver import GraphService +from .common import add_address_argument, add_compact_argument, graph_address_from_args + + +def mermaid_url(graph: str, target: str = "live") -> str: + if target != "ink": + graph = json.dumps( + { + "code": graph, + "mermaid": {"theme": "default"}, + "updateDiagram": True, + "autoSync": True, + "rough": False, + } + ) + + graphbytes = graph.encode("utf8") + + if target != "ink": + compress = zlib.compressobj(9, zlib.DEFLATED, 15, 8, zlib.Z_DEFAULT_STRATEGY) + graphbytes = compress.compress(graphbytes) + compress.flush() + + base64_string = base64.b64encode(graphbytes).decode("ascii") + + if target == "ink": + prefix = "https://mermaid.ink/img/" + elif target == "live": + prefix = "https://mermaid.live/edit#pako:" + elif target == "play": + prefix = "https://www.mermaidchart.com/play#pako:" + else: + raise ValueError( + f"Unknown mermaid target '{target}'. Available options are 'ink', 'live', or 'play'." + ) + + return prefix + base64_string + + +async def handle_mermaid(args: argparse.Namespace) -> None: + graph_service = GraphService(graph_address_from_args(args)) + graph_out = await graph_service.get_formatted_graph( + fmt="mermaid", compact_level=args.compact + ) + print(graph_out) + + if args.nobrowser: + return + + if args.target == "live": + print( + "%% If the graph does not render immediately, try toggling the 'Pan & Zoom' button." + ) + webbrowser.open(mermaid_url(graph_out, target=args.target)) + + +def setup_mermaid_cmdline(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser("mermaid") + add_address_argument(parser) + add_compact_argument(parser) + parser.add_argument( + "--target", + help="Target for mermaid output. Options are 'ink', 'live', and 'play'.", + default="live", + ) + parser.add_argument( + "-n", + "--nobrowser", + help="Do not automatically open the browser for mermaid output. `--target` value will be ignored.", + action="store_true", + ) + parser.set_defaults(_handler=handle_mermaid) diff --git a/src/ezmsg/core/commands/serve.py b/src/ezmsg/core/commands/serve.py new file mode 100644 index 00000000..075bad18 --- /dev/null +++ b/src/ezmsg/core/commands/serve.py @@ -0,0 +1,30 @@ +import argparse +import asyncio +import logging + +from ..graphserver import GraphService +from .common import add_address_argument, graph_address_from_args + +logger = logging.getLogger("ezmsg") + + +async def handle_serve(args: argparse.Namespace) -> None: + graph_address = graph_address_from_args(args) + graph_service = GraphService(graph_address) + + logger.info(f"GraphServer Address: {graph_address}") + graph_server = graph_service.create_server() + + try: + logger.info("Servers running...") + await asyncio.to_thread(graph_server.join) + except KeyboardInterrupt: + logger.info("Interrupt detected; shutting down servers") + finally: + graph_server.stop() + + +def setup_serve_cmdline(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser("serve") + add_address_argument(parser) + parser.set_defaults(_handler=handle_serve) diff --git a/src/ezmsg/core/commands/shutdown.py b/src/ezmsg/core/commands/shutdown.py new file mode 100644 index 00000000..0f4dc77f --- /dev/null +++ b/src/ezmsg/core/commands/shutdown.py @@ -0,0 +1,26 @@ +import argparse +import logging + +from ..graphserver import GraphService +from .common import add_address_argument, graph_address_from_args + +logger = logging.getLogger("ezmsg") + + +async def handle_shutdown(args: argparse.Namespace) -> None: + graph_address = graph_address_from_args(args) + graph_service = GraphService(graph_address) + + try: + await graph_service.shutdown() + logger.info(f"Issued shutdown command to GraphServer @ {graph_service.address}") + except ConnectionRefusedError: + logger.warning( + f"Could not issue shutdown command to GraphServer @ {graph_service.address}; server not running?" + ) + + +def setup_shutdown_cmdline(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser("shutdown") + add_address_argument(parser) + parser.set_defaults(_handler=handle_shutdown) diff --git a/src/ezmsg/core/commands/start.py b/src/ezmsg/core/commands/start.py new file mode 100644 index 00000000..8133b0db --- /dev/null +++ b/src/ezmsg/core/commands/start.py @@ -0,0 +1,36 @@ +import argparse +import asyncio +import logging +import subprocess +import sys + +from ..graphserver import GraphService +from ..netprotocol import close_stream_writer +from .common import add_address_argument, graph_address_from_args + +logger = logging.getLogger("ezmsg") + + +async def handle_start(args: argparse.Namespace) -> None: + graph_address = graph_address_from_args(args) + graph_service = GraphService(graph_address) + + popen = subprocess.Popen( + [sys.executable, "-m", "ezmsg.core", "serve", f"--address={graph_address}"] + ) + + while True: + try: + _, writer = await graph_service.open_connection() + await close_stream_writer(writer) + break + except ConnectionRefusedError: + await asyncio.sleep(0.1) + + logger.info(f"Forked ezmsg servers in PID: {popen.pid}") + + +def setup_start_cmdline(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser("start") + add_address_argument(parser) + parser.set_defaults(_handler=handle_start) diff --git a/src/ezmsg/util/perf/command.py b/src/ezmsg/util/perf/command.py index 9dab1f8e..31f70451 100644 --- a/src/ezmsg/util/perf/command.py +++ b/src/ezmsg/util/perf/command.py @@ -1,4 +1,5 @@ import argparse +import sys from .ab import setup_ab_cmdline from .analysis import setup_summary_cmdline @@ -6,16 +7,30 @@ from .run import setup_run_cmdline -def command() -> None: +def setup_perf_cmdline(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser("perf", help="performance test utilities") + perf_subparsers = parser.add_subparsers(dest="perf_command", required=True) + + setup_run_cmdline(perf_subparsers) + setup_hotpath_cmdline(perf_subparsers) + setup_ab_cmdline(perf_subparsers) + setup_summary_cmdline(perf_subparsers) + + +def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="ezmsg perf test utility") subparsers = parser.add_subparsers(dest="command", required=True) + setup_perf_cmdline(subparsers) + return parser + + +def command(argv: list[str] | None = None) -> None: + parser = build_parser() - setup_run_cmdline(subparsers) - setup_hotpath_cmdline(subparsers) - setup_ab_cmdline(subparsers) - setup_summary_cmdline(subparsers) + if argv is None: + argv = ["perf", *sys.argv[1:]] - ns = parser.parse_args() + ns = parser.parse_args(argv) ns._handler(ns) diff --git a/tests/test_command.py b/tests/test_command.py new file mode 100644 index 00000000..6a737253 --- /dev/null +++ b/tests/test_command.py @@ -0,0 +1,68 @@ +import pytest + +from ezmsg.core.command import build_parser + + +def test_mermaid_subparser_accepts_mermaid_specific_args(): + parser = build_parser() + + args = parser.parse_args( + [ + "mermaid", + "--address", + "127.0.0.1:4000", + "--target", + "ink", + "-cc", + "--nobrowser", + ] + ) + + assert args.command == "mermaid" + assert args.address == "127.0.0.1:4000" + assert args.target == "ink" + assert args.compact == 2 + assert args.nobrowser is True + + +def test_perf_subparser_accepts_nested_perf_args(): + parser = build_parser() + + args = parser.parse_args( + [ + "perf", + "hotpath", + "--count", + "10", + "--samples", + "2", + "--quiet", + ] + ) + + assert args.command == "perf" + assert args.perf_command == "hotpath" + assert args.count == 10 + assert args.samples == 2 + assert args.quiet is True + + +def test_graphviz_subparser_rejects_mermaid_only_args(): + parser = build_parser() + + with pytest.raises(SystemExit): + parser.parse_args(["graphviz", "--nobrowser"]) + + +def test_serve_subparser_rejects_visualization_args(): + parser = build_parser() + + with pytest.raises(SystemExit): + parser.parse_args(["serve", "--target", "play"]) + + +def test_perf_subparser_rejects_core_only_args(): + parser = build_parser() + + with pytest.raises(SystemExit): + parser.parse_args(["perf", "hotpath", "--address", "127.0.0.1:4000"]) From 75b65ebb40b98ebe480fe507b30cb9a59fc0bf7b Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 6 Apr 2026 15:34:51 -0400 Subject: [PATCH 42/52] refactored perf commandline and outputs --- src/ezmsg/core/command.py | 4 +- src/ezmsg/util/perf/analysis.py | 1244 +++++++++++++++++++++++-------- src/ezmsg/util/perf/command.py | 9 +- src/ezmsg/util/perf/run.py | 62 +- tests/test_command.py | 47 +- tests/test_perf_analysis.py | 218 ++++++ 6 files changed, 1236 insertions(+), 348 deletions(-) create mode 100644 tests/test_perf_analysis.py diff --git a/src/ezmsg/core/command.py b/src/ezmsg/core/command.py index 61e76f97..f9c753d3 100644 --- a/src/ezmsg/core/command.py +++ b/src/ezmsg/core/command.py @@ -2,8 +2,6 @@ import asyncio import inspect -from ezmsg.util.perf.command import setup_perf_cmdline - from .commands import setup_core_cmdline from .commands.graphviz import handle_graphviz from .commands.mermaid import handle_mermaid, mermaid_url as mm @@ -38,6 +36,8 @@ def build_parser() -> argparse.ArgumentParser: subparsers = parser.add_subparsers(dest="command", required=True, help="command for ezmsg") setup_core_cmdline(subparsers) + from ezmsg.util.perf.command import setup_perf_cmdline + setup_perf_cmdline(subparsers) return parser diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index 590e681f..ea3db477 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -1,7 +1,7 @@ -import json -import dataclasses import argparse +import dataclasses import html +import json import math import webbrowser @@ -9,12 +9,7 @@ from ..messagecodec import MessageDecoder from .envinfo import TestEnvironmentInfo, format_env_diff -from .run import get_datestamp -from .impl import ( - TestParameters, - Metrics, - TestLogEntry, -) +from .impl import Metrics, TestLogEntry, TestParameters import ezmsg.core as ez @@ -46,7 +41,7 @@ - shm_spread / tcp_spread: each client in its own process; comms via SHM / TCP respectively Variables: -- n_clients: pubs (fanin), subs (fanout), or relays (relay) +- n_clients: pubs (fanin), subs (fanout), or relays (relay) - msg_size: nominal message size (bytes) Metrics: @@ -55,10 +50,69 @@ - latency_mean: average send -> receive latency in seconds (lower = better) """ +KEY_COLUMNS = ["config", "comms", "n_clients", "msg_size"] +DISPLAY_METRICS = [ + "sample_rate_mean", + "sample_rate_median", + "data_rate", + "latency_mean", + "latency_median", +] +METRIC_LABELS = { + "sample_rate_mean": "sample_rate_mean", + "sample_rate_median": "sample_rate_median", + "data_rate": "data_rate", + "latency_mean": "latency_mean", + "latency_median": "latency_median", +} +METRIC_GROUPS = { + "sample_rate_mean": "throughput", + "sample_rate_median": "throughput", + "data_rate": "throughput", + "latency_mean": "latency", + "latency_median": "latency", +} +ABSOLUTE_UNITS = { + "sample_rate_mean": "msgs/s", + "sample_rate_median": "msgs/s", + "data_rate": "MB/s", + "latency_mean": "us", + "latency_median": "us", +} +RELATIVE_UNITS = {metric: "%" for metric in DISPLAY_METRICS} +NOISE_BAND_PCT = 10.0 + + +@dataclasses.dataclass +class MetricDelta: + metric: str + config: str + comms: str + n_clients: int + msg_size: int + value: float + score: float + + +@dataclasses.dataclass +class ReportBundle: + perf_path: Path + baseline_path: Path | None + info: TestEnvironmentInfo + baseline_info: TestEnvironmentInfo | None + env_diff: str | None + relative: bool + terminal_df: pd.DataFrame + candidate_df: pd.DataFrame + baseline_df: pd.DataFrame | None + relative_df: pd.DataFrame | None + delta_counts: dict[str, int] + top_improvements: list[MetricDelta] + top_regressions: list[MetricDelta] + def load_perf(perf: Path) -> xr.Dataset: all_results: dict[TestParameters, dict[int, list[Metrics]]] = dict() - run_idx = 0 with open(perf, "r") as perf_f: @@ -74,10 +128,10 @@ def load_perf(perf: Path) -> xr.Dataset: runs[run_idx] = metrics all_results[obj.params] = runs - n_clients_axis = list(sorted(set([p.n_clients for p in all_results.keys()]))) - msg_size_axis = list(sorted(set([p.msg_size for p in all_results.keys()]))) - comms_axis = list(sorted(set([p.comms for p in all_results.keys()]))) - config_axis = list(sorted(set([p.config for p in all_results.keys()]))) + n_clients_axis = list(sorted({p.n_clients for p in all_results})) + msg_size_axis = list(sorted({p.msg_size for p in all_results})) + comms_axis = list(sorted({p.comms for p in all_results})) + config_axis = list(sorted({p.config for p in all_results})) dims = ["n_clients", "msg_size", "comms", "config"] coords = { @@ -89,7 +143,7 @@ def load_perf(perf: Path) -> xr.Dataset: data_vars = {} for field in dataclasses.fields(Metrics): - m = ( + metric_values = ( np.zeros( ( len(n_clients_axis), @@ -100,400 +154,944 @@ def load_perf(perf: Path) -> xr.Dataset: ) * np.nan ) - for p, a in all_results.items(): - # tests are run multiple times; get the median of means - m[ - n_clients_axis.index(p.n_clients), - msg_size_axis.index(p.msg_size), - comms_axis.index(p.comms), - config_axis.index(p.config), + for params, runs in all_results.items(): + metric_values[ + n_clients_axis.index(params.n_clients), + msg_size_axis.index(params.msg_size), + comms_axis.index(params.comms), + config_axis.index(params.config), ] = np.median( - [np.mean([getattr(v, field.name) for v in r]) for r in a.values()] + [np.mean([getattr(v, field.name) for v in run]) for run in runs.values()] ) - data_vars[field.name] = xr.DataArray(m, dims=dims, coords=coords) + data_vars[field.name] = xr.DataArray(metric_values, dims=dims, coords=coords) - dataset = xr.Dataset(data_vars, attrs=dict(info=info)) - return dataset + return xr.Dataset(data_vars, attrs=dict(info=info)) -def _escape(s: str) -> str: - return html.escape(str(s), quote=True) +def default_report_html_path(perf_path: Path) -> Path: + return perf_path.with_suffix(".html") -def _env_block(title: str, body: str) -> str: - return f""" -
-

{_escape(title)}

-
{_escape(body).strip()}
-
- """ +def default_compare_html_path(perf_path: Path, baseline_path: Path) -> Path: + return perf_path.with_name(f"{perf_path.stem}.vs_{baseline_path.stem}.html") -def _legend_block() -> str: - return """ -
-

Legend

-
    -
  • Comparison mode: values are percentages (100 = no change).
  • -
  • Green: improvement (↑ sample/data rate, ↓ latency).
  • -
  • Red: regression (↓ sample/data rate, ↑ latency).
  • -
-
- """ +def _escape(value: object) -> str: + return html.escape(str(value), quote=True) + + +def _frame_from_dataset(dataset: xr.Dataset) -> pd.DataFrame: + frame = dataset.to_dataframe().dropna(how="all") + frame = frame.reset_index() + frame = frame.dropna(subset=DISPLAY_METRICS, how="all") + return frame.sort_values(KEY_COLUMNS).reset_index(drop=True) + + +def _display_frame(frame: pd.DataFrame, relative: bool) -> pd.DataFrame: + out = frame.copy() + out = out[KEY_COLUMNS + DISPLAY_METRICS] + if not relative: + out["data_rate"] = out["data_rate"] / 2**20 + out["latency_mean"] = out["latency_mean"] * 1e6 + out["latency_median"] = out["latency_median"] * 1e6 + return out + + +def _metric_score(metric: str, value: float) -> float: + if not (isinstance(value, (int, float)) and math.isfinite(value)): + return 0.0 + if "latency" in metric: + return 100.0 - value + return value - 100.0 + + +def _classify_metric(metric: str, value: float, noise_band_pct: float = NOISE_BAND_PCT) -> str: + score = _metric_score(metric, value) + if abs(score) <= noise_band_pct: + return "neutral" + return "improvement" if score > 0 else "regression" + + +def _collect_metric_deltas( + relative_df: pd.DataFrame | None, noise_band_pct: float = NOISE_BAND_PCT +) -> tuple[dict[str, int], list[MetricDelta], list[MetricDelta]]: + if relative_df is None: + return {"improvement": 0, "neutral": 0, "regression": 0}, [], [] + + counts = {"improvement": 0, "neutral": 0, "regression": 0} + improvements: list[MetricDelta] = [] + regressions: list[MetricDelta] = [] + + for _, row in relative_df.iterrows(): + for metric in DISPLAY_METRICS: + value = float(row[metric]) + classification = _classify_metric(metric, value, noise_band_pct=noise_band_pct) + counts[classification] += 1 + score = _metric_score(metric, value) + if classification == "improvement": + improvements.append( + MetricDelta( + metric=metric, + config=str(row["config"]), + comms=str(row["comms"]), + n_clients=int(row["n_clients"]), + msg_size=int(row["msg_size"]), + value=value, + score=score, + ) + ) + elif classification == "regression": + regressions.append( + MetricDelta( + metric=metric, + config=str(row["config"]), + comms=str(row["comms"]), + n_clients=int(row["n_clients"]), + msg_size=int(row["msg_size"]), + value=value, + score=score, + ) + ) + + improvements.sort(key=lambda item: item.score, reverse=True) + regressions.sort(key=lambda item: item.score) + return counts, improvements, regressions + + +def _format_terminal_value(metric: str, value: float, relative: bool) -> str: + if pd.isna(value): + return "nan" + if relative: + return f"{float(value):.1f}%" + if metric in {"latency_mean", "latency_median"}: + return f"{float(value):,.3f}" + if metric == "data_rate": + return f"{float(value):,.3f}" + return f"{float(value):,.2f}" + + +def _terminal_table(frame: pd.DataFrame, relative: bool) -> str: + formatted = frame.copy() + for metric in DISPLAY_METRICS: + formatted[metric] = formatted[metric].map( + lambda value, metric=metric: _format_terminal_value(metric, value, relative) + ) + + sections: list[str] = [] + for (config, comms), group in formatted.groupby(["config", "comms"], sort=False): + sections.append(f"{config} / {comms}") + sections.append( + group[["n_clients", "msg_size", *DISPLAY_METRICS]].to_string(index=False) + ) + sections.append("") + return "\n".join(sections).strip() + + +def _terminal_delta_summary( + counts: dict[str, int], + improvements: list[MetricDelta], + regressions: list[MetricDelta], + limit: int = 5, +) -> str: + lines = [ + "COMPARISON OVERVIEW", + ( + f" improvements: {counts['improvement']}, neutral: {counts['neutral']}, " + f"regressions: {counts['regression']}" + ), + "", + ] + + if regressions: + lines.append("BIGGEST REGRESSIONS") + for delta in regressions[:limit]: + lines.append( + " " + f"{delta.metric}: {delta.value:.1f}% " + f"({delta.config}/{delta.comms}, n_clients={delta.n_clients}, msg_size={delta.msg_size})" + ) + lines.append("") + + if improvements: + lines.append("BIGGEST IMPROVEMENTS") + for delta in improvements[:limit]: + lines.append( + " " + f"{delta.metric}: {delta.value:.1f}% " + f"({delta.config}/{delta.comms}, n_clients={delta.n_clients}, msg_size={delta.msg_size})" + ) + lines.append("") + + return "\n".join(lines).strip() + + +def build_report_bundle(perf_path: Path, baseline_path: Path | None = None) -> ReportBundle: + candidate = load_perf(perf_path) + info: TestEnvironmentInfo = candidate.attrs["info"] + candidate_frame = _display_frame(_frame_from_dataset(candidate), relative=False) + + baseline_info: TestEnvironmentInfo | None = None + env_diff: str | None = None + relative = baseline_path is not None + baseline_frame: pd.DataFrame | None = None + relative_frame: pd.DataFrame | None = None + + if baseline_path is not None: + baseline = load_perf(baseline_path) + baseline_info = baseline.attrs["info"] + env_diff = format_env_diff(info.diff(baseline_info)) + baseline_frame = _display_frame(_frame_from_dataset(baseline), relative=False) + relative_dataset = (candidate / baseline) * 100.0 + relative_dataset = relative_dataset.drop_vars(["latency_total", "num_msgs"]) + relative_frame = _display_frame(_frame_from_dataset(relative_dataset), relative=True) + terminal_df = relative_frame + else: + terminal_df = candidate_frame + + delta_counts, top_improvements, top_regressions = _collect_metric_deltas(relative_frame) + + return ReportBundle( + perf_path=perf_path, + baseline_path=baseline_path, + info=info, + baseline_info=baseline_info, + env_diff=env_diff, + relative=relative, + terminal_df=terminal_df, + candidate_df=candidate_frame, + baseline_df=baseline_frame, + relative_df=relative_frame, + delta_counts=delta_counts, + top_improvements=top_improvements, + top_regressions=top_regressions, + ) + + +def _build_terminal_output(bundle: ReportBundle) -> str: + lines = [str(bundle.info), ""] + + if bundle.relative: + lines.extend( + [ + "PERFORMANCE COMPARISON", + "", + bundle.env_diff or "No differences.", + "", + _terminal_delta_summary( + bundle.delta_counts, bundle.top_improvements, bundle.top_regressions + ), + "", + ] + ) + else: + lines.extend(["PERFORMANCE REPORT", ""]) + + lines.extend([_terminal_table(bundle.terminal_df, relative=bundle.relative), ""]) + return "\n".join(line for line in lines if line is not None).strip() + + +def _format_html_number(metric: str, value: float, mode: str) -> str: + if not (isinstance(value, (int, float)) and math.isfinite(value)): + return "n/a" + if mode == "relative": + return f"{value:.1f}%" + if metric in {"latency_mean", "latency_median"}: + return f"{value:,.3f}" + if metric == "data_rate": + return f"{value:,.3f}" + return f"{value:,.2f}" + + +def _color_for_comparison( + value: float, metric: str, noise_band_pct: float = NOISE_BAND_PCT +) -> str: + if not (isinstance(value, (int, float)) and math.isfinite(value)): + return "" + score = _metric_score(metric, value) + magnitude = abs(score) + if magnitude <= noise_band_pct: + return "" + + scale = max(0.0, min(1.0, (magnitude - noise_band_pct) / 45.0)) + hue = "var(--green)" if score > 0 else "var(--red)" + alpha = 0.15 + 0.35 * scale + return f"background-color: hsla({hue}, 70%, 45%, {alpha});" def _base_css() -> str: - # Minimal, print-friendly CSS + color scales for cells. return """ """ -def _color_for_comparison( - value: float, metric: str, noise_band_pct: float = 10.0 -) -> str: - """ - Returns inline CSS background for a comparison % value. - value: e.g., 97.3, 104.8, etc. - For sample_rate/data_rate: improvement > 100 (good). - For latency_mean: improvement < 100 (good). - Noise band ±10% around 100 is neutral. - """ - if not (isinstance(value, (int, float)) and math.isfinite(value)): - return "" +def _render_delta_list(title: str, deltas: list[MetricDelta], empty: str) -> str: + if not deltas: + return f"

{_escape(title)}

{_escape(empty)}
" - delta = value - 100.0 - # Determine direction: + is good for sample/data; - is good for latency - if "rate" in metric: - # positive delta good, negative bad - magnitude = abs(delta) - sign_good = delta > 0 - elif "latency" in metric: - # negative delta good (lower latency) - magnitude = abs(delta) - sign_good = delta < 0 - else: - return "" + items = [] + for delta in deltas[:5]: + items.append( + "
  • " + f"{_escape(delta.metric)}: {_escape(f'{delta.value:.1f}%')} " + f"({_escape(delta.config)}/{_escape(delta.comms)}, " + f"n_clients={delta.n_clients}, msg_size={delta.msg_size})" + "
  • " + ) + return ( + f"

    {_escape(title)}

    " + f"
      {''.join(items)}
    " + ) - # Noise band: keep neutral - if magnitude <= noise_band_pct: - return "" - # Scale 5%..50% across 0..1; clamp - scale = max(0.0, min(1.0, (magnitude - noise_band_pct) / 45.0)) +def _build_rows(bundle: ReportBundle) -> list[dict[str, object]]: + rows: list[dict[str, object]] = [] + baseline_lookup: dict[tuple[str, str, int, int], dict[str, object]] = {} + relative_lookup: dict[tuple[str, str, int, int], dict[str, object]] = {} - # Choose hue and lightness; use HSL with gentle saturation - hue = "var(--green)" if sign_good else "var(--red)" - # opacity via alpha blend on lightness via HSLa - # Use saturation ~70%, lightness around 40–50% blended with table bg - alpha = 0.15 + 0.35 * scale # 0.15..0.50 - return f"background-color: hsla({hue}, 70%, 45%, {alpha});" + if bundle.baseline_df is not None: + for row in bundle.baseline_df.to_dict("records"): + key = (str(row["config"]), str(row["comms"]), int(row["n_clients"]), int(row["msg_size"])) + baseline_lookup[key] = row + if bundle.relative_df is not None: + for row in bundle.relative_df.to_dict("records"): + key = (str(row["config"]), str(row["comms"]), int(row["n_clients"]), int(row["msg_size"])) + relative_lookup[key] = row -def _format_number(x) -> str: - if isinstance(x, (int,)) and not isinstance(x, bool): - return f"{x:d}" - try: - xf = float(x) - except Exception: - return _escape(str(x)) - # Heuristic: for comparison percentages, 1 decimal is nice; for absolute, 3 decimals for latency. - return f"{xf:.3f}" + for row in bundle.candidate_df.to_dict("records"): + key = (str(row["config"]), str(row["comms"]), int(row["n_clients"]), int(row["msg_size"])) + baseline_row = baseline_lookup.get(key) + relative_row = relative_lookup.get(key) + severity = 0.0 + if relative_row is not None: + severity = max(abs(_metric_score(metric, float(relative_row[metric]))) for metric in DISPLAY_METRICS) + rows.append( + { + "config": key[0], + "comms": key[1], + "n_clients": key[2], + "msg_size": key[3], + "candidate": row, + "baseline": baseline_row, + "relative": relative_row, + "severity": severity, + } + ) + return rows -def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> None: - """print perf test results and comparisons to the console""" +def _options(values: list[object], label: str) -> str: + options = [""] + for value in values: + options.append(f"") + return ( + f"" + ) - output = "" - perf = load_perf(perf_path) - info: TestEnvironmentInfo = perf.attrs["info"] - output += str(info) + "\n\n" +def _report_title(bundle: ReportBundle) -> str: + return "ezmsg Performance Comparison" if bundle.relative else "ezmsg Performance Report" - relative = False - env_diff = None - if baseline_path is not None: - relative = True - output += "PERFORMANCE COMPARISON\n\n" - baseline = load_perf(baseline_path) - perf = (perf / baseline) * 100.0 - baseline_info: TestEnvironmentInfo = baseline.attrs["info"] - env_diff = format_env_diff(info.diff(baseline_info)) - output += env_diff + "\n\n" - - # These raw stats are still valuable to have, but are confusing - # when making relative comparisons - perf = perf.drop_vars(["latency_total", "num_msgs"]) - - perf = perf.stack(params=["n_clients", "msg_size"]).dropna("params") - df = perf.squeeze().to_dataframe() - df = df.drop("n_clients", axis=1) - df = df.drop("msg_size", axis=1) - - for _, config_ds in perf.groupby("config"): - for _, comms_ds in config_ds.groupby("comms"): - output += str(comms_ds.squeeze().to_dataframe()) + "\n\n" - output += "\n" - - print(output) - - if html: - # Ensure expected columns exist - expected_cols = { - "sample_rate_mean", - "sample_rate_median", - "data_rate", - "latency_mean", - "latency_median", - } - missing = expected_cols - set(df.columns) - if missing: - raise ValueError(f"Missing expected columns in dataset: {missing}") - - # We'll render a table per (config, comms) group. - groups = ( - df.reset_index() - .sort_values(by=["config", "comms", "n_clients", "msg_size"]) - .groupby(["config", "comms"], sort=False) - ) - # Build HTML - parts: list[str] = [] - parts.append("") +def render_html_report(bundle: ReportBundle) -> str: + rows = _build_rows(bundle) + filters = { + "config": sorted(bundle.candidate_df["config"].unique().tolist()), + "comms": sorted(bundle.candidate_df["comms"].unique().tolist()), + "n_clients": sorted(bundle.candidate_df["n_clients"].unique().tolist()), + "msg_size": sorted(bundle.candidate_df["msg_size"].unique().tolist()), + } + mode = "relative" if bundle.relative else "candidate" + + parts: list[str] = [] + parts.append("") + parts.append( + "" + ) + parts.append(f"{_escape(_report_title(bundle))}") + parts.append(_base_css()) + parts.append("
    ") + + parts.append("
    ") + parts.append(f"

    {_escape(_report_title(bundle))}

    ") + sub = str(bundle.perf_path) + if bundle.baseline_path is not None: + sub += f" relative to {bundle.baseline_path}" + parts.append(f"
    {_escape(sub)}
    ") + parts.append("
    ") + + if bundle.relative: + parts.append("

    Comparison Overview

    ") + for label, value in [ + ("Improvements", bundle.delta_counts["improvement"]), + ("Neutral", bundle.delta_counts["neutral"]), + ("Regressions", bundle.delta_counts["regression"]), + ]: + parts.append( + "
    " + f"{_escape(label)}" + f"{_escape(value)}" + "
    " + ) + parts.append("
    ") + + parts.append("
    ") parts.append( - "" + _render_delta_list( + "Biggest Regressions", + bundle.top_regressions, + "No regressions outside the noise band.", + ) ) - parts.append("ezmsg perf report") - parts.append(_base_css()) - parts.append("
    ") - - parts.append("
    ") - parts.append("

    ezmsg Performance Report

    ") - sub = str(perf_path) - if baseline_path is not None: - sub += f" relative to {str(baseline_path)}" - parts.append(f"
    {_escape(sub)}
    ") - parts.append("
    ") - - if info is not None: - parts.append(_env_block("Test Environment", str(info))) - - parts.append(_env_block("Test Details", TEST_DESCRIPTION)) - - if env_diff is not None: - # Show diffs using your helper - parts.append("
    ") - parts.append("

    Environment Differences vs Baseline

    ") - parts.append(f"
    {_escape(env_diff)}
    ") - parts.append("
    ") - parts.append(_legend_block()) - - # Render each group - for (config, comms), g in groups: - # Keep only expected columns in order - cols = [ - "n_clients", - "msg_size", - "sample_rate_mean", - "sample_rate_median", - "data_rate", - "latency_mean", - "latency_median", - ] - g = g[cols].copy() + parts.append( + _render_delta_list( + "Biggest Improvements", + bundle.top_improvements, + "No improvements outside the noise band.", + ) + ) + parts.append("
    ") + + parts.append("
    ") + parts.append("

    Environment

    ") + parts.append(f"
    {_escape(str(bundle.info))}
    ") + parts.append("
    ") + + if bundle.env_diff is not None: + parts.append("
    ") + parts.append("

    Environment Differences vs Baseline

    ") + parts.append(f"
    {_escape(bundle.env_diff)}
    ") + parts.append("
    ") - # String format some columns (msg_size with separators) - g["msg_size"] = g["msg_size"].map( - lambda x: f"{int(x):,}" if pd.notna(x) else x + parts.append("

    Test Details

    ") + parts.append(f"
    {_escape(TEST_DESCRIPTION.strip())}
    ") + parts.append("
    ") + + parts.append("
    ") + parts.append("

    Explore

    ") + parts.append("
    ") + parts.append(_options(filters["config"], "Config")) + parts.append(_options(filters["comms"], "Comms")) + parts.append(_options(filters["n_clients"], "N_clients")) + parts.append(_options(filters["msg_size"], "Msg_size")) + parts.append( + "" + ) + parts.append("
    ") + parts.append("
    ") + for view, label in [("all", "All Metrics"), ("throughput", "Throughput"), ("latency", "Latency")]: + active = " active" if view == "all" else "" + parts.append( + f"" + ) + if bundle.relative: + parts.append("") + for display_mode, label in [ + ("relative", "Relative"), + ("candidate", "Candidate"), + ("baseline", "Baseline"), + ]: + active = " active" if display_mode == mode else "" + parts.append( + f"" ) + parts.append("
    ") + parts.append("
    ") - # Build table manually so we can inject inline cell styles easily - # (pandas Styler is great but produces bulky HTML; manual keeps it clean) - header = f""" - - - n_clients - msg_size {"" if relative else "(b)"} - sample_rate_mean {"" if relative else "(msgs/s)"} - sample_rate_median {"" if relative else "(msgs/s)"} - data_rate {"" if relative else "(MB/s)"} - latency_mean {"" if relative else "(us)"} - latency_median {"" if relative else "(us)"} - - - """ - body_rows: list[str] = [] - for _, row in g.iterrows(): - sr, srm, dr, lt, lm = ( - row["sample_rate_mean"], - row["sample_rate_median"], - row["data_rate"], - row["latency_mean"], - row["latency_median"], - ) - dr = dr if relative else dr / 2**20 - lt = lt if relative else lt * 1e6 - lm = lm if relative else lm * 1e6 - sr_style = ( - _color_for_comparison(sr, "sample_rate_mean") if relative else "" - ) - srm_style = ( - _color_for_comparison(srm, "sample_rate_median") if relative else "" - ) - dr_style = _color_for_comparison(dr, "data_rate") if relative else "" - lt_style = _color_for_comparison(lt, "latency_mean") if relative else "" - lm_style = ( - _color_for_comparison(lm, "latency_median") if relative else "" - ) + parts.append("

    Results

    ") + for column in ["config", "comms", "n_clients", "msg_size"]: + parts.append(f"") + for metric in DISPLAY_METRICS: + parts.append( + f"" + ) + parts.append("") - body_rows.append( - "" - f"" - f"" - f"" - f"" - f"" - f"" - f"" - "" - ) - table_html = f"
    {_escape(column)}
    {_format_number(row['n_clients'])}{_escape(row['msg_size'])}{_format_number(sr)}{_format_number(srm)}{_format_number(dr)}{_format_number(lt)}{_format_number(lm)}
    {header}{''.join(body_rows)}
    " + for row in rows: + parts.append( + "" + ) + for column in ["config", "comms", "n_clients", "msg_size"]: + parts.append(f"{_escape(row[column])}") + candidate = row["candidate"] + baseline = row["baseline"] + relative_row = row["relative"] + for metric in DISPLAY_METRICS: + candidate_value = float(candidate[metric]) + baseline_value = ( + float(baseline[metric]) if baseline is not None else float("nan") + ) + relative_value = ( + float(relative_row[metric]) if relative_row is not None else float("nan") + ) + style = ( + _color_for_comparison(relative_value, metric) + if bundle.relative and math.isfinite(relative_value) + else "" + ) + initial_value = ( + relative_value if bundle.relative else candidate_value + ) + initial_mode = "relative" if bundle.relative else "candidate" parts.append( - f"

    " - f"{_escape(config)}" - f"{_escape(comms)}" - f"

    {table_html}
    " + "" + f"{_escape(_format_html_number(metric, initial_value, initial_mode))}" + "" ) + parts.append("") - parts.append("
    ") - html_text = "".join(parts) + parts.append("") + parts.append( + """ + + """ + % ( + json.dumps(METRIC_LABELS), + json.dumps(ABSOLUTE_UNITS), + json.dumps(RELATIVE_UNITS), + json.dumps(mode), + ) + ) + parts.append("
    ") + return "".join(parts) + + +def write_html_report( + perf_path: Path, + baseline_path: Path | None = None, + output_path: Path | None = None, + open_browser: bool = False, +) -> Path: + bundle = build_report_bundle(perf_path, baseline_path=baseline_path) + if output_path is None: + output_path = ( + default_compare_html_path(perf_path, baseline_path) + if baseline_path is not None + else default_report_html_path(perf_path) + ) + html_text = render_html_report(bundle) + output_path.write_text(html_text, encoding="utf-8") + if open_browser: + webbrowser.open(output_path.resolve().as_uri()) + return output_path + + +def report( + perf_path: Path, + output_path: Path | None = None, + open_browser: bool = True, +) -> None: + bundle = build_report_bundle(perf_path) + print(_build_terminal_output(bundle)) + out_path = write_html_report( + perf_path=perf_path, + output_path=output_path, + open_browser=open_browser, + ) + print(f"\nHTML report: {out_path}") + + +def compare( + perf_path: Path, + baseline_path: Path, + output_path: Path | None = None, + open_browser: bool = True, +) -> None: + bundle = build_report_bundle(perf_path, baseline_path=baseline_path) + print(_build_terminal_output(bundle)) + out_path = write_html_report( + perf_path=perf_path, + baseline_path=baseline_path, + output_path=output_path, + open_browser=open_browser, + ) + print(f"\nComparison report: {out_path}") + + +def setup_report_cmdline(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser( + "report", help="render an absolute report for a benchmark file" + ) + parser.add_argument("perf", type=Path, help="benchmark file") + parser.add_argument( + "--output", type=Path, - help="perf test", + default=None, + help="optional explicit HTML output path", + ) + parser.add_argument( + "--no-browser", + action="store_true", + help="write HTML without opening it in a browser", + ) + parser.set_defaults( + _handler=lambda ns: report( + perf_path=ns.perf, + output_path=ns.output, + open_browser=not ns.no_browser, + ) + ) + + +def setup_compare_cmdline(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser( + "compare", help="compare one benchmark file against a baseline" ) - p_summary.add_argument( + parser.add_argument("perf", type=Path, help="candidate benchmark file") + parser.add_argument( "--baseline", "-b", type=Path, + required=True, + help="baseline benchmark file for comparison", + ) + parser.add_argument( + "--output", + type=Path, default=None, - help="baseline perf test for comparison", + help="optional explicit HTML output path", ) - p_summary.add_argument( - "--html", + parser.add_argument( + "--no-browser", action="store_true", - help="generate an html output file and render results in browser", + help="write HTML without opening it in a browser", ) - - p_summary.set_defaults( - _handler=lambda ns: summary( - perf_path=ns.perf, baseline_path=ns.baseline, html=ns.html + parser.set_defaults( + _handler=lambda ns: compare( + perf_path=ns.perf, + baseline_path=ns.baseline, + output_path=ns.output, + open_browser=not ns.no_browser, ) ) diff --git a/src/ezmsg/util/perf/command.py b/src/ezmsg/util/perf/command.py index 31f70451..b15200a4 100644 --- a/src/ezmsg/util/perf/command.py +++ b/src/ezmsg/util/perf/command.py @@ -2,19 +2,20 @@ import sys from .ab import setup_ab_cmdline -from .analysis import setup_summary_cmdline +from .analysis import setup_compare_cmdline, setup_report_cmdline from .hotpath import setup_hotpath_cmdline -from .run import setup_run_cmdline +from .run import setup_benchmark_cmdline def setup_perf_cmdline(subparsers: argparse._SubParsersAction) -> None: parser = subparsers.add_parser("perf", help="performance test utilities") perf_subparsers = parser.add_subparsers(dest="perf_command", required=True) - setup_run_cmdline(perf_subparsers) + setup_benchmark_cmdline(perf_subparsers) + setup_report_cmdline(perf_subparsers) + setup_compare_cmdline(perf_subparsers) setup_hotpath_cmdline(perf_subparsers) setup_ab_cmdline(perf_subparsers) - setup_summary_cmdline(perf_subparsers) def build_parser() -> argparse.ArgumentParser: diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index 3f794f21..1bebc86d 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -9,6 +9,7 @@ from datetime import datetime, timedelta from contextlib import contextmanager, redirect_stdout, redirect_stderr +from pathlib import Path import ezmsg.core as ez from ezmsg.core.graphserver import GraphServer @@ -98,7 +99,11 @@ def get_datestamp() -> str: return datetime.now().strftime("%Y%m%d_%H%M%S") -def perf_run( +def output_paths_for_name(name: str) -> tuple[Path, Path]: + return Path(f"perf_{name}.txt"), Path(f"report_{name}.html") + + +def benchmark( max_duration: float, num_msgs: int, num_buffers: int, @@ -110,17 +115,20 @@ def perf_run( configs: typing.Iterable[str] | None, grid: bool, warmup_dur: float, -) -> None: + name: str | None = None, + open_browser: bool = True, +) -> tuple[Path, Path | None]: if n_clients is None: n_clients = DEFAULT_N_CLIENTS if any(c < 0 for c in n_clients): ez.logger.error("All tests must have >=0 clients") - return + raise ValueError("All tests must have >=0 clients") if msg_sizes is None: msg_sizes = DEFAULT_MSG_SIZES if any(s < 0 for s in msg_sizes): ez.logger.error("All msg_sizes must be >=0 bytes") + raise ValueError("All msg_sizes must be >=0 bytes") if not grid and len(list(n_clients)) != len(list(msg_sizes)): ez.logger.warning( @@ -136,7 +144,7 @@ def perf_run( ez.logger.error( f"Invalid test communications requested. Valid communications: {', '.join([c.value for c in Communication])}" ) - return + raise ValueError("Invalid test communications requested") try: configurators = ( @@ -146,7 +154,7 @@ def perf_run( ez.logger.error( f"Invalid test configuration requested. Valid configurations: {', '.join([c for c in CONFIGS])}" ) - return + raise ValueError("Invalid test configuration requested") subitr = itertools.product if grid else zip @@ -177,12 +185,17 @@ def perf_run( quitting = False start_time = time.time() + if name is not None: + output_path, html_out = output_paths_for_name(name) + else: + output_path = Path(f"perf_{get_datestamp()}.txt") + html_out = None try: ez.logger.info(f"Warming up for {warmup_dur} seconds...") warmup(warmup_dur) - with open(f"perf_{get_datestamp()}.txt", "w") as out_f: + with open(output_path, "w") as out_f: for _ in range(repeats): out_f.write( json.dumps(TestEnvironmentInfo(), cls=MessageEncoder) + "\n" @@ -235,9 +248,25 @@ def perf_run( ) ez.logger.info(f"Tests concluded. Wallclock Runtime: {dur_str}s") + html_path = None + try: + from .analysis import write_html_report + + html_path = write_html_report( + perf_path=output_path, + output_path=html_out, + open_browser=open_browser, + ) + ez.logger.info(f"Wrote benchmark log to {output_path}") + ez.logger.info(f"Wrote benchmark report to {html_path}") + except ImportError: + ez.logger.warning("Could not generate benchmark HTML report; analysis dependencies are unavailable.") + + return output_path, html_path + -def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: - p_run = subparsers.add_parser("run", help="run performance test") +def setup_benchmark_cmdline(subparsers: argparse._SubParsersAction) -> None: + p_run = subparsers.add_parser("benchmark", help="run the legacy benchmark matrix") p_run.add_argument( "--max-duration", @@ -328,8 +357,21 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: help="warmup CPU with busy task for some number of seconds (default = 60.0)", ) + p_run.add_argument( + "--name", + type=str, + default=None, + help="optional short name used for perf_.txt and report_.html", + ) + + p_run.add_argument( + "--no-browser", + action="store_true", + help="write the generated HTML report without opening it in a browser", + ) + p_run.set_defaults( - _handler=lambda ns: perf_run( + _handler=lambda ns: benchmark( max_duration=ns.max_duration, num_msgs=ns.num_msgs, num_buffers=ns.num_buffers, @@ -341,5 +383,7 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: configs=ns.configs, grid=True, warmup_dur=ns.warmup, + name=ns.name, + open_browser=not ns.no_browser, ) ) diff --git a/tests/test_command.py b/tests/test_command.py index 6a737253..5cf9630d 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -31,20 +31,47 @@ def test_perf_subparser_accepts_nested_perf_args(): args = parser.parse_args( [ "perf", - "hotpath", - "--count", - "10", - "--samples", + "benchmark", + "--name", + "smoke", + "--num-msgs", "2", - "--quiet", + "--repeats", + "1", + "--no-browser", ] ) assert args.command == "perf" - assert args.perf_command == "hotpath" - assert args.count == 10 - assert args.samples == 2 - assert args.quiet is True + assert args.perf_command == "benchmark" + assert args.name == "smoke" + assert args.num_msgs == 2 + assert args.repeats == 1 + assert args.no_browser is True + + +def test_perf_compare_subparser_accepts_baseline_args(): + parser = build_parser() + + args = parser.parse_args( + [ + "perf", + "compare", + "candidate.txt", + "--baseline", + "baseline.txt", + "--output", + "diff.html", + "--no-browser", + ] + ) + + assert args.command == "perf" + assert args.perf_command == "compare" + assert str(args.perf) == "candidate.txt" + assert str(args.baseline) == "baseline.txt" + assert str(args.output) == "diff.html" + assert args.no_browser is True def test_graphviz_subparser_rejects_mermaid_only_args(): @@ -65,4 +92,4 @@ def test_perf_subparser_rejects_core_only_args(): parser = build_parser() with pytest.raises(SystemExit): - parser.parse_args(["perf", "hotpath", "--address", "127.0.0.1:4000"]) + parser.parse_args(["perf", "benchmark", "--address", "127.0.0.1:4000"]) diff --git a/tests/test_perf_analysis.py b/tests/test_perf_analysis.py new file mode 100644 index 00000000..12c00dc2 --- /dev/null +++ b/tests/test_perf_analysis.py @@ -0,0 +1,218 @@ +import json +from pathlib import Path + +from ezmsg.util.messagecodec import MessageEncoder +from ezmsg.util.perf.analysis import ( + build_report_bundle, + default_compare_html_path, + default_report_html_path, + write_html_report, +) +from ezmsg.util.perf.envinfo import TestEnvironmentInfo as PerfEnvironmentInfo +from ezmsg.util.perf.impl import ( + Metrics as PerfMetrics, + TestLogEntry as PerfLogEntry, + TestParameters as PerfParameters, +) +from ezmsg.util.perf.run import benchmark, output_paths_for_name + + +def _write_perf_log( + path: Path, info: PerfEnvironmentInfo, entries: list[PerfLogEntry] +) -> None: + with open(path, "w") as handle: + handle.write(json.dumps(info, cls=MessageEncoder) + "\n") + for entry in entries: + handle.write(json.dumps(entry, cls=MessageEncoder) + "\n") + + +def _entry( + *, + config: str, + comms: str, + n_clients: int, + msg_size: int, + sample_rate: float, + latency: float, + data_rate: float, +) -> PerfLogEntry: + return PerfLogEntry( + params=PerfParameters( + msg_size=msg_size, + num_msgs=128, + n_clients=n_clients, + config=config, + comms=comms, + max_duration=1.0, + num_buffers=1, + ), + results=PerfMetrics( + num_msgs=128, + sample_rate_mean=sample_rate, + sample_rate_median=sample_rate * 0.98, + latency_mean=latency, + latency_median=latency * 0.97, + latency_total=latency * 128, + data_rate=data_rate, + ), + ) + + +def test_report_and_compare_html_paths(tmp_path): + perf = tmp_path / "perf_20260406_120000.txt" + baseline = tmp_path / "perf_20260405_120000.txt" + + assert default_report_html_path(perf) == tmp_path / "perf_20260406_120000.html" + assert default_compare_html_path(perf, baseline) == ( + tmp_path / "perf_20260406_120000.vs_perf_20260405_120000.html" + ) + assert output_paths_for_name("smoke") == ( + Path("perf_smoke.txt"), + Path("report_smoke.html"), + ) + + +def test_build_report_bundle_and_html_report(tmp_path): + candidate = tmp_path / "candidate.txt" + baseline = tmp_path / "baseline.txt" + + candidate_info = PerfEnvironmentInfo(git_branch="candidate", git_commit="abc123") + baseline_info = PerfEnvironmentInfo(git_branch="baseline", git_commit="def456") + + _write_perf_log( + candidate, + candidate_info, + [ + _entry( + config="fanin", + comms="local", + n_clients=1, + msg_size=64, + sample_rate=1200.0, + latency=0.0009, + data_rate=2_400_000.0, + ), + _entry( + config="relay", + comms="tcp", + n_clients=2, + msg_size=256, + sample_rate=700.0, + latency=0.0024, + data_rate=1_200_000.0, + ), + ], + ) + _write_perf_log( + baseline, + baseline_info, + [ + _entry( + config="fanin", + comms="local", + n_clients=1, + msg_size=64, + sample_rate=1000.0, + latency=0.0012, + data_rate=2_000_000.0, + ), + _entry( + config="relay", + comms="tcp", + n_clients=2, + msg_size=256, + sample_rate=900.0, + latency=0.0018, + data_rate=1_700_000.0, + ), + ], + ) + + bundle = build_report_bundle(candidate, baseline_path=baseline) + assert bundle.relative is True + assert bundle.delta_counts["improvement"] > 0 + assert bundle.delta_counts["regression"] > 0 + assert bundle.top_improvements + assert bundle.top_regressions + + out_path = write_html_report(candidate, baseline_path=baseline, open_browser=False) + html = out_path.read_text(encoding="utf-8") + + assert "Comparison Overview" in html + assert "Biggest Regressions" in html + assert "data-filter='config'" in html + assert "display-mode" in html + assert "metric-view" in html + assert "candidate.vs_baseline.html" == out_path.name + + +def test_benchmark_writes_raw_output_and_html_report(tmp_path, monkeypatch): + output_path = tmp_path / "perf_smoke.txt" + html_path = tmp_path / "report_smoke.html" + monkeypatch.chdir(tmp_path) + + class FakeGraphServer: + def __init__(self): + self.address = ("127.0.0.1", 0) + + def start(self): + return None + + def stop(self): + return None + + html_calls: list[tuple[Path, Path | None, bool]] = [] + + monkeypatch.setattr("ezmsg.util.perf.run.GraphServer", FakeGraphServer) + monkeypatch.setattr("ezmsg.util.perf.run.warmup", lambda _: None) + monkeypatch.setattr( + "ezmsg.util.perf.run.perform_test", + lambda **_: PerfMetrics( + num_msgs=8, + sample_rate_mean=1000.0, + sample_rate_median=950.0, + latency_mean=0.001, + latency_median=0.0009, + latency_total=0.008, + data_rate=1_000_000.0, + ), + ) + + def _fake_write_html_report( + perf_path: Path, + baseline_path: Path | None = None, + output_path: Path | None = None, + open_browser: bool = False, + ) -> Path: + html_calls.append((perf_path, output_path, open_browser)) + target = output_path or perf_path.with_suffix(".html") + target.write_text("", encoding="utf-8") + return target + + monkeypatch.setattr( + "ezmsg.util.perf.analysis.write_html_report", + _fake_write_html_report, + ) + + raw_path, report_path = benchmark( + max_duration=0.01, + num_msgs=8, + num_buffers=1, + iters=1, + repeats=1, + msg_sizes=[64], + n_clients=[1], + comms=["local"], + configs=["fanin"], + grid=True, + warmup_dur=0.0, + name="smoke", + open_browser=False, + ) + + assert raw_path.name == output_path.name + assert report_path is not None + assert report_path.name == html_path.name + assert output_path.exists() + assert html_path.exists() + assert html_calls == [(raw_path, report_path, False)] From 9e93d0574e8dab4e4f356efdd60dcb247f93e6e9 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 6 Apr 2026 16:30:34 -0400 Subject: [PATCH 43/52] refactored ab testing to support shared env and remove --- src/ezmsg/util/perf/ab.py | 549 ++++++++++++++++++++++++++++++++------ tests/test_command.py | 32 +++ tests/test_perf_ab.py | 51 +++- 3 files changed, 552 insertions(+), 80 deletions(-) diff --git a/src/ezmsg/util/perf/ab.py b/src/ezmsg/util/perf/ab.py index 5a2da621..e4696f57 100644 --- a/src/ezmsg/util/perf/ab.py +++ b/src/ezmsg/util/perf/ab.py @@ -2,6 +2,7 @@ import argparse import contextlib +import hashlib import json import os import random @@ -15,6 +16,27 @@ DEFAULT_PAIR_SEED = 0 +DEFAULT_REF_A = "dev" +DEFAULT_REF_B = "CURRENT" +VENV_DIR_CANDIDATES = (".venv", "venv", ".env", "env") + + +@dataclass(frozen=True) +class ABEnvironmentInfo: + label: str + ref: str + tree: str + python: str + python_version: str + ezmsg_version: str + numpy_version: str + git_commit: str + git_branch: str + dirty: bool + env_mode: str + pyproject_hash: str | None + uv_lock_hash: str | None + env_overrides: dict[str, str] @dataclass(frozen=True) @@ -34,6 +56,9 @@ class ABRunSummary: ref_b: str rounds: int seed: int + env_a: ABEnvironmentInfo + env_b: ABEnvironmentInfo + warnings: list[str] cases: list[ABCaseSummary] @@ -49,6 +74,7 @@ def _hotpath_json_arg(path: Path) -> list[str]: def build_hotpath_command( + python: str | Path, output_path: Path, count: int, warmup: int, @@ -59,9 +85,7 @@ def build_hotpath_command( quiet: bool, ) -> list[str]: cmd = [ - "uv", - "run", - "python", + str(python), "-m", "ezmsg.util.perf.hotpath", "--count", @@ -85,6 +109,18 @@ def build_hotpath_command( return cmd +def parse_env_assignments(values: list[str]) -> dict[str, str]: + assignments: dict[str, str] = {} + for value in values: + if "=" not in value: + raise ValueError(f"Environment override must use KEY=VALUE format: {value}") + key, env_value = value.split("=", 1) + if not key: + raise ValueError(f"Environment override is missing a key: {value}") + assignments[key] = env_value + return assignments + + def load_hotpath_summary(path: Path) -> dict[str, float]: payload = json.loads(path.read_text()) return { @@ -99,6 +135,9 @@ def summarize_ab_results( rounds: int, seed: int, paired_runs: list[tuple[dict[str, float], dict[str, float]]], + env_a: ABEnvironmentInfo | None = None, + env_b: ABEnvironmentInfo | None = None, + warnings: list[str] | None = None, ) -> ABRunSummary: case_ids = sorted(paired_runs[0][0].keys()) cases: list[ABCaseSummary] = [] @@ -119,11 +158,30 @@ def summarize_ab_results( ) ) + placeholder = ABEnvironmentInfo( + label="?", + ref="unknown", + tree="unknown", + python="unknown", + python_version="unknown", + ezmsg_version="unknown", + numpy_version="unknown", + git_commit="unknown", + git_branch="unknown", + dirty=False, + env_mode="unknown", + pyproject_hash=None, + uv_lock_hash=None, + env_overrides={}, + ) return ABRunSummary( ref_a=ref_a, ref_b=ref_b, rounds=rounds, seed=seed, + env_a=env_a or placeholder, + env_b=env_b or placeholder, + warnings=warnings or [], cases=cases, ) @@ -136,9 +194,7 @@ def _median(values: list[float]) -> float: return (ordered[mid - 1] + ordered[mid]) / 2.0 -def _run_checked(cmd: list[str], cwd: Path) -> None: - env = os.environ.copy() - env.pop("VIRTUAL_ENV", None) +def _run_checked(cmd: list[str], cwd: Path, env: dict[str, str] | None = None) -> None: completed = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, env=env) if completed.returncode == 0: return @@ -151,6 +207,18 @@ def _run_checked(cmd: list[str], cwd: Path) -> None: ) +def _run_json(cmd: list[str], cwd: Path, env: dict[str, str]) -> dict[str, str]: + completed = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, env=env) + if completed.returncode != 0: + raise RuntimeError( + f"Command failed in {cwd}:\n" + f"$ {' '.join(cmd)}\n\n" + f"stdout:\n{completed.stdout}\n" + f"stderr:\n{completed.stderr}" + ) + return json.loads(completed.stdout) + + def _is_current_ref(ref: str) -> bool: return ref.upper() == "CURRENT" @@ -184,10 +252,6 @@ def _provision_tree( shutil.rmtree(parent, ignore_errors=True) -def _maybe_sync(tree: Path) -> None: - _run_checked(["uv", "sync", "--group", "dev"], cwd=tree) - - def _mirror_hotpath_module(source_root: Path, target_tree: Path) -> None: source = source_root / "src" / "ezmsg" / "util" / "perf" / "hotpath.py" target = target_tree / "src" / "ezmsg" / "util" / "perf" / "hotpath.py" @@ -209,11 +273,182 @@ def _ensure_json_files_match( ) +def _python_from_venv(venv_dir: Path) -> Path | None: + candidates = [venv_dir / "bin" / "python", venv_dir / "Scripts" / "python.exe"] + for candidate in candidates: + if candidate.exists(): + return candidate + return None + + +def discover_tree_python(tree: Path) -> Path | None: + for dirname in VENV_DIR_CANDIDATES: + python = _python_from_venv(tree / dirname) + if python is not None: + return python + return None + + +def _virtual_env_from_python(python: Path) -> str | None: + parent = python.parent + if parent.name in {"bin", "Scripts"}: + return str(parent.parent) + return None + + +def _tree_env( + base_env: dict[str, str], + tree: Path, + python: Path, + overrides: dict[str, str], +) -> dict[str, str]: + env = base_env.copy() + env.update(overrides) + path_parts = [str(tree / "src")] + if env.get("PYTHONPATH"): + path_parts.append(env["PYTHONPATH"]) + env["PYTHONPATH"] = os.pathsep.join(path_parts) + venv = _virtual_env_from_python(python) + if venv is not None: + env["VIRTUAL_ENV"] = venv + return env + + +def _hash_file(path: Path) -> str | None: + if not path.exists(): + return None + return hashlib.sha256(path.read_bytes()).hexdigest()[:12] + + +def _git_value(tree: Path, args: list[str], default: str = "unknown") -> str: + completed = subprocess.run( + ["git", *args], + cwd=tree, + capture_output=True, + text=True, + ) + if completed.returncode != 0: + return default + return completed.stdout.strip() + + +def _git_dirty(tree: Path) -> bool: + completed = subprocess.run( + ["git", "status", "--porcelain"], + cwd=tree, + capture_output=True, + text=True, + ) + if completed.returncode != 0: + return False + return bool(completed.stdout.strip()) + + +def _runtime_probe_env( + tree: Path, + python: Path, + overrides: dict[str, str], +) -> dict[str, str]: + return _tree_env(os.environ.copy(), tree, python, overrides) + + +def _probe_runtime(tree: Path, python: Path, overrides: dict[str, str]) -> dict[str, str]: + script = ( + "import json, sys\n" + "payload = {\n" + " 'python': sys.executable,\n" + " 'python_version': sys.version.replace('\\n', ' '),\n" + " 'ezmsg_version': 'unknown',\n" + " 'numpy_version': 'unknown',\n" + "}\n" + "try:\n" + " import ezmsg.core as ez\n" + " payload['ezmsg_version'] = ez.__version__\n" + "except Exception:\n" + " pass\n" + "try:\n" + " import numpy as np\n" + " payload['numpy_version'] = np.__version__\n" + "except Exception:\n" + " pass\n" + "print(json.dumps(payload))\n" + ) + return _run_json( + [str(python), "-c", script], + cwd=tree, + env=_runtime_probe_env(tree, python, overrides), + ) + + +def _build_env_info( + label: str, + ref: str, + tree: Path, + python: Path, + env_mode: str, + overrides: dict[str, str], +) -> ABEnvironmentInfo: + runtime = _probe_runtime(tree, python, overrides) + return ABEnvironmentInfo( + label=label, + ref=ref, + tree=str(tree), + python=runtime["python"], + python_version=runtime["python_version"], + ezmsg_version=runtime["ezmsg_version"], + numpy_version=runtime["numpy_version"], + git_commit=_git_value(tree, ["rev-parse", "HEAD"]), + git_branch=_git_value(tree, ["rev-parse", "--abbrev-ref", "HEAD"]), + dirty=_git_dirty(tree), + env_mode=env_mode, + pyproject_hash=_hash_file(tree / "pyproject.toml"), + uv_lock_hash=_hash_file(tree / "uv.lock"), + env_overrides=overrides, + ) + + +def _shared_env_warnings(info_a: ABEnvironmentInfo, info_b: ABEnvironmentInfo) -> list[str]: + warnings: list[str] = [] + if info_a.pyproject_hash != info_b.pyproject_hash: + warnings.append( + "pyproject.toml differs between A and B while using a shared environment." + ) + if info_a.uv_lock_hash != info_b.uv_lock_hash: + warnings.append( + "uv.lock differs between A and B while using a shared environment." + ) + return warnings + + +def _print_env_info(env_info: ABEnvironmentInfo) -> None: + print( + f"{env_info.label}: ref={env_info.ref} tree={env_info.tree} " + f"python={env_info.python} env_mode={env_info.env_mode}" + ) + print( + f" python_version={env_info.python_version} " + f"ezmsg={env_info.ezmsg_version} numpy={env_info.numpy_version}" + ) + print( + f" branch={env_info.git_branch} commit={env_info.git_commit} dirty={env_info.dirty}" + ) + print( + f" pyproject_hash={env_info.pyproject_hash} uv_lock_hash={env_info.uv_lock_hash}" + ) + if env_info.env_overrides: + pairs = ", ".join(f"{key}={value}" for key, value in sorted(env_info.env_overrides.items())) + print(f" env_overrides={pairs}") + + def _print_summary(summary: ABRunSummary) -> None: print( f"Interleaved hot-path comparison: A={summary.ref_a}, " f"B={summary.ref_b}, rounds={summary.rounds}, seed={summary.seed}" ) + _print_env_info(summary.env_a) + _print_env_info(summary.env_b) + for warning in summary.warnings: + print(f"WARNING: {warning}") for case in summary.cases: sign = "regression" if case.delta_pct_median > 0 else "improvement" print( @@ -232,14 +467,68 @@ def dump_ab_json(summary: ABRunSummary, path: Path) -> None: "ref_b": summary.ref_b, "rounds": summary.rounds, "seed": summary.seed, + "env_a": asdict(summary.env_a), + "env_b": asdict(summary.env_b), + "warnings": summary.warnings, "cases": [asdict(case) for case in summary.cases], } path.write_text(json.dumps(payload, indent=2) + "\n") +def _default_ref(ref: str | None, fallback: str) -> str: + return fallback if ref is None else ref + + +def _resolve_python_for_shared( + python_a: str | None, + python_b: str | None, +) -> Path: + if python_a is not None and python_b is not None and python_a != python_b: + raise ValueError( + "shared env mode requires a single interpreter; --python-a and --python-b must match" + ) + if python_a is not None: + return Path(python_a) + if python_b is not None: + return Path(python_b) + return Path(sys.executable) + + +def _resolve_python_for_existing( + tree: Path, + label: str, + explicit_python: str | None, + repo_root: Path, +) -> Path: + if explicit_python is not None: + return Path(explicit_python) + + discovered = discover_tree_python(tree) + if discovered is not None: + return discovered + + if tree.resolve() == repo_root.resolve(): + return Path(sys.executable) + + raise ValueError( + f"Could not locate a Python interpreter for side {label} in {tree}. " + "Prepare the environment yourself and rerun with --env-mode existing " + "plus --python-a/--python-b or local .venv/venv directories." + ) + + def perf_ab( - ref_a: str, - ref_b: str, + ref_a: str | None, + ref_b: str | None, + dir_a: Path | None, + dir_b: Path | None, + python_a: str | None, + python_b: str | None, + env_mode: str, + force_shared_env: bool, + env: list[str], + env_a: list[str], + env_b: list[str], rounds: int, count: int, warmup: int, @@ -251,11 +540,23 @@ def perf_ab( seed: int, json_out: Path | None, keep_worktrees: bool, - sync: bool, quiet: bool, ) -> None: if rounds <= 0: raise ValueError("rounds must be > 0") + if dir_a is not None and ref_a is not None: + raise ValueError("Use either --ref-a or --dir-a, not both") + if dir_b is not None and ref_b is not None: + raise ValueError("Use either --ref-b or --dir-b, not both") + if force_shared_env and env_mode != "shared": + raise ValueError("--force-shared-env only applies to --env-mode shared") + + shared_overrides = parse_env_assignments(env) + env_overrides_a = {**shared_overrides, **parse_env_assignments(env_a)} + env_overrides_b = {**shared_overrides, **parse_env_assignments(env_b)} + + resolved_ref_a = dir_a.name if dir_a is not None else _default_ref(ref_a, DEFAULT_REF_A) + resolved_ref_b = dir_b.name if dir_b is not None else _default_ref(ref_b, DEFAULT_REF_B) repo_root = Path( subprocess.run( @@ -267,63 +568,117 @@ def perf_ab( ) pair_order = build_pair_order(rounds, seed) - with _provision_tree(repo_root, ref_a, "A", keep_worktrees) as tree_a: - with _provision_tree(repo_root, ref_b, "B", keep_worktrees) as tree_b: - if tree_a != repo_root: - _mirror_hotpath_module(repo_root, tree_a) - if tree_b != repo_root: - _mirror_hotpath_module(repo_root, tree_b) - - if sync: - _maybe_sync(tree_a) - if tree_b != tree_a: - _maybe_sync(tree_b) - - with tempfile.TemporaryDirectory(prefix="ezmsg-perf-ab-runs-") as tmpdir_name: - tmpdir = Path(tmpdir_name) - cmd_by_label = { - "A": lambda path: build_hotpath_command( - path, - count=count, - warmup=warmup, - payload_sizes=payload_sizes, - transports=transports, - apis=apis, - num_buffers=num_buffers, - quiet=quiet, - ), - "B": lambda path: build_hotpath_command( - path, - count=count, - warmup=warmup, - payload_sizes=payload_sizes, - transports=transports, - apis=apis, - num_buffers=num_buffers, - quiet=quiet, - ), - } - tree_by_label = {"A": tree_a, "B": tree_b} - - for idx in range(prewarm): - for label in ("A", "B"): - if label == "B" and tree_b == tree_a: - continue - warm_path = tmpdir / f"warm-{label}-{idx}.json" - _run_checked(cmd_by_label[label](warm_path), cwd=tree_by_label[label]) - - paired_runs: list[tuple[dict[str, float], dict[str, float]]] = [] - for round_idx, (first, second) in enumerate(pair_order, start=1): - outputs: dict[str, dict[str, float]] = {} - for label in (first, second): - output_path = tmpdir / f"round-{round_idx:02d}-{label}.json" - _run_checked(cmd_by_label[label](output_path), cwd=tree_by_label[label]) - outputs[label] = load_hotpath_summary(output_path) - - _ensure_json_files_match(outputs["A"], outputs["B"], ref_a, ref_b) - paired_runs.append((outputs["A"], outputs["B"])) - - summary = summarize_ab_results(ref_a, ref_b, rounds, seed, paired_runs) + with contextlib.ExitStack() as stack: + tree_a = ( + dir_a.resolve() + if dir_a is not None + else stack.enter_context(_provision_tree(repo_root, resolved_ref_a, "A", keep_worktrees)) + ) + tree_b = ( + dir_b.resolve() + if dir_b is not None + else stack.enter_context(_provision_tree(repo_root, resolved_ref_b, "B", keep_worktrees)) + ) + + if dir_a is None and tree_a != repo_root: + _mirror_hotpath_module(repo_root, tree_a) + if dir_b is None and tree_b != repo_root and tree_b != tree_a: + _mirror_hotpath_module(repo_root, tree_b) + + if env_mode == "shared": + shared_python = _resolve_python_for_shared(python_a, python_b) + python_path_a = shared_python + python_path_b = shared_python + else: + python_path_a = _resolve_python_for_existing(tree_a, "A", python_a, repo_root) + python_path_b = _resolve_python_for_existing(tree_b, "B", python_b, repo_root) + + env_info_a = _build_env_info( + "A", resolved_ref_a, tree_a, python_path_a, env_mode, env_overrides_a + ) + env_info_b = _build_env_info( + "B", resolved_ref_b, tree_b, python_path_b, env_mode, env_overrides_b + ) + + warnings = [] + if env_mode == "shared": + warnings = _shared_env_warnings(env_info_a, env_info_b) + if warnings and not force_shared_env: + raise ValueError( + "Shared-environment comparison detected project metadata mismatches:\n" + + "\n".join(f"- {warning}" for warning in warnings) + + "\n\nRe-run with --force-shared-env to continue anyway, or prepare " + "side-specific environments and use --env-mode existing." + ) + + with tempfile.TemporaryDirectory(prefix="ezmsg-perf-ab-runs-") as tmpdir_name: + tmpdir = Path(tmpdir_name) + cmd_by_label = { + "A": lambda path: build_hotpath_command( + python_path_a, + path, + count=count, + warmup=warmup, + payload_sizes=payload_sizes, + transports=transports, + apis=apis, + num_buffers=num_buffers, + quiet=quiet, + ), + "B": lambda path: build_hotpath_command( + python_path_b, + path, + count=count, + warmup=warmup, + payload_sizes=payload_sizes, + transports=transports, + apis=apis, + num_buffers=num_buffers, + quiet=quiet, + ), + } + tree_by_label = {"A": tree_a, "B": tree_b} + env_by_label = { + "A": _tree_env(os.environ.copy(), tree_a, python_path_a, env_overrides_a), + "B": _tree_env(os.environ.copy(), tree_b, python_path_b, env_overrides_b), + } + + for idx in range(prewarm): + for label in ("A", "B"): + if label == "B" and tree_b == tree_a and env_by_label["B"] == env_by_label["A"]: + continue + warm_path = tmpdir / f"warm-{label}-{idx}.json" + _run_checked( + cmd_by_label[label](warm_path), + cwd=tree_by_label[label], + env=env_by_label[label], + ) + + paired_runs: list[tuple[dict[str, float], dict[str, float]]] = [] + for round_idx, (first, second) in enumerate(pair_order, start=1): + outputs: dict[str, dict[str, float]] = {} + for label in (first, second): + output_path = tmpdir / f"round-{round_idx:02d}-{label}.json" + _run_checked( + cmd_by_label[label](output_path), + cwd=tree_by_label[label], + env=env_by_label[label], + ) + outputs[label] = load_hotpath_summary(output_path) + + _ensure_json_files_match(outputs["A"], outputs["B"], resolved_ref_a, resolved_ref_b) + paired_runs.append((outputs["A"], outputs["B"])) + + summary = summarize_ab_results( + resolved_ref_a, + resolved_ref_b, + rounds, + seed, + paired_runs, + env_a=env_info_a, + env_b=env_info_b, + warnings=warnings, + ) _print_summary(summary) if json_out is not None: dump_ab_json(summary, json_out) @@ -333,10 +688,43 @@ def perf_ab( def setup_ab_cmdline(subparsers: argparse._SubParsersAction) -> None: p_ab = subparsers.add_parser( "ab", - help="run interleaved A/B hot-path comparisons using git worktrees", + help="run interleaved A/B hot-path comparisons using worktrees or prepared directories", + ) + p_ab.add_argument("--ref-a", default=None, help=f"baseline git ref (default = {DEFAULT_REF_A})") + p_ab.add_argument("--ref-b", default=None, help=f"candidate git ref (default = {DEFAULT_REF_B})") + p_ab.add_argument("--dir-a", type=Path, default=None, help="use an existing directory for side A") + p_ab.add_argument("--dir-b", type=Path, default=None, help="use an existing directory for side B") + p_ab.add_argument("--python-a", default=None, help="explicit Python interpreter for side A") + p_ab.add_argument("--python-b", default=None, help="explicit Python interpreter for side B") + p_ab.add_argument( + "--env-mode", + choices=["shared", "existing"], + default="shared", + help="shared = reuse one interpreter for both sides; existing = use prepared per-tree environments", + ) + p_ab.add_argument( + "--force-shared-env", + action="store_true", + help="continue shared-env comparisons even when pyproject.toml or uv.lock differ", + ) + p_ab.add_argument( + "--env", + action="append", + default=[], + help="environment override for both sides (KEY=VALUE). Repeatable.", + ) + p_ab.add_argument( + "--env-a", + action="append", + default=[], + help="environment override for side A only (KEY=VALUE). Repeatable.", + ) + p_ab.add_argument( + "--env-b", + action="append", + default=[], + help="environment override for side B only (KEY=VALUE). Repeatable.", ) - p_ab.add_argument("--ref-a", default="dev", help="baseline git ref or CURRENT") - p_ab.add_argument("--ref-b", default="CURRENT", help="candidate git ref or CURRENT") p_ab.add_argument( "--rounds", type=int, @@ -405,11 +793,6 @@ def setup_ab_cmdline(subparsers: argparse._SubParsersAction) -> None: action="store_true", help="leave auto-provisioned worktrees on disk for inspection", ) - p_ab.add_argument( - "--sync", - action="store_true", - help="run 'uv sync --group dev' in each provisioned worktree first", - ) p_ab.add_argument( "--quiet", action="store_true", @@ -419,6 +802,15 @@ def setup_ab_cmdline(subparsers: argparse._SubParsersAction) -> None: _handler=lambda ns: perf_ab( ref_a=ns.ref_a, ref_b=ns.ref_b, + dir_a=ns.dir_a, + dir_b=ns.dir_b, + python_a=ns.python_a, + python_b=ns.python_b, + env_mode=ns.env_mode, + force_shared_env=ns.force_shared_env, + env=ns.env, + env_a=ns.env_a, + env_b=ns.env_b, rounds=ns.rounds, count=ns.count, warmup=ns.warmup, @@ -430,7 +822,6 @@ def setup_ab_cmdline(subparsers: argparse._SubParsersAction) -> None: seed=ns.seed, json_out=ns.json_out, keep_worktrees=ns.keep_worktrees, - sync=ns.sync, quiet=ns.quiet, ) ) diff --git a/tests/test_command.py b/tests/test_command.py index 5cf9630d..b75e59d1 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -74,6 +74,38 @@ def test_perf_compare_subparser_accepts_baseline_args(): assert args.no_browser is True +def test_perf_ab_subparser_accepts_manual_env_args(): + parser = build_parser() + + args = parser.parse_args( + [ + "perf", + "ab", + "--dir-a", + "/tmp/a", + "--dir-b", + "/tmp/b", + "--env-mode", + "existing", + "--env", + "FOO=bar", + "--env-a", + "ONLY_A=1", + "--python-b", + "/tmp/b/.venv/bin/python", + ] + ) + + assert args.command == "perf" + assert args.perf_command == "ab" + assert str(args.dir_a) == "/tmp/a" + assert str(args.dir_b) == "/tmp/b" + assert args.env_mode == "existing" + assert args.env == ["FOO=bar"] + assert args.env_a == ["ONLY_A=1"] + assert args.python_b == "/tmp/b/.venv/bin/python" + + def test_graphviz_subparser_rejects_mermaid_only_args(): parser = build_parser() diff --git a/tests/test_perf_ab.py b/tests/test_perf_ab.py index 3fbb935c..b6b0131f 100644 --- a/tests/test_perf_ab.py +++ b/tests/test_perf_ab.py @@ -1,6 +1,8 @@ from ezmsg.util.perf.ab import ( + ABEnvironmentInfo, build_hotpath_command, build_pair_order, + parse_env_assignments, summarize_ab_results, ) @@ -17,6 +19,7 @@ def test_build_pair_order_is_balanced_and_reproducible(): def test_build_hotpath_command_contains_expected_args(tmp_path): cmd = build_hotpath_command( + "/tmp/shared-python", tmp_path / "out.json", count=100, warmup=10, @@ -27,12 +30,19 @@ def test_build_hotpath_command_contains_expected_args(tmp_path): quiet=True, ) - assert cmd[:5] == ["uv", "run", "python", "-m", "ezmsg.util.perf.hotpath"] + assert cmd[:3] == ["/tmp/shared-python", "-m", "ezmsg.util.perf.hotpath"] assert "--count" in cmd assert "--payload-sizes" in cmd assert "--quiet" in cmd +def test_parse_env_assignments_merges_repeatable_values(): + assert parse_env_assignments(["FOO=bar", "BAZ=qux", "FOO=override"]) == { + "FOO": "override", + "BAZ": "qux", + } + + def test_summarize_ab_results_uses_b_vs_a_delta(): paired_runs = [ ( @@ -45,15 +55,54 @@ def test_summarize_ab_results_uses_b_vs_a_delta(): ), ] + env_a = ABEnvironmentInfo( + label="A", + ref="dev", + tree="/tmp/a", + python="/tmp/a/.venv/bin/python", + python_version="3.11.0", + ezmsg_version="1.0.0", + numpy_version="2.0.0", + git_commit="abc", + git_branch="dev", + dirty=False, + env_mode="shared", + pyproject_hash="123", + uv_lock_hash="456", + env_overrides={"FOO": "bar"}, + ) + env_b = ABEnvironmentInfo( + label="B", + ref="CURRENT", + tree="/tmp/b", + python="/tmp/b/.venv/bin/python", + python_version="3.11.0", + ezmsg_version="1.0.0", + numpy_version="2.0.0", + git_commit="def", + git_branch="main", + dirty=True, + env_mode="shared", + pyproject_hash="123", + uv_lock_hash="789", + env_overrides={"FOO": "baz"}, + ) + summary = summarize_ab_results( ref_a="dev", ref_b="CURRENT", rounds=2, seed=0, paired_runs=paired_runs, + env_a=env_a, + env_b=env_b, + warnings=["uv.lock differs"], ) assert len(summary.cases) == 1 + assert summary.env_a == env_a + assert summary.env_b == env_b + assert summary.warnings == ["uv.lock differs"] case = summary.cases[0] assert case.case_id == "async/shm/payload=64/buffers=1" assert case.a_us_per_message_median == 9.0 From c9fbcc3b0982d716aab428b55d90c2bb0d2be6b1 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 6 Apr 2026 16:51:33 -0400 Subject: [PATCH 44/52] fixed tests for windows --- tests/test_command.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_command.py b/tests/test_command.py index b75e59d1..b74df4bd 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -1,4 +1,5 @@ import pytest +from pathlib import Path from ezmsg.core.command import build_parser @@ -98,8 +99,8 @@ def test_perf_ab_subparser_accepts_manual_env_args(): assert args.command == "perf" assert args.perf_command == "ab" - assert str(args.dir_a) == "/tmp/a" - assert str(args.dir_b) == "/tmp/b" + assert args.dir_a == Path("/tmp/a") + assert args.dir_b == Path("/tmp/b") assert args.env_mode == "existing" assert args.env == ["FOO=bar"] assert args.env_a == ["ONLY_A=1"] From 63b02b01146513292a9e410e7e9351e3c53ee957 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 6 Apr 2026 17:34:53 -0400 Subject: [PATCH 45/52] fix optional dependencies --- pyproject.toml | 3 ++ src/ezmsg/util/perf/analysis.py | 58 ++++++++++++++++++++++++--------- src/ezmsg/util/perf/run.py | 36 ++++++++++++-------- 3 files changed, 68 insertions(+), 29 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6d796f65..37154bc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,9 @@ docs = [ axisarray = [ "numpy>=2.2.6", ] +perf = [ + "xarray", +] [project.scripts] ezmsg = "ezmsg.core.command:cmdline" diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index ea3db477..743acbbe 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import dataclasses import html @@ -6,25 +8,43 @@ import webbrowser from pathlib import Path - -from ..messagecodec import MessageDecoder -from .envinfo import TestEnvironmentInfo, format_env_diff -from .impl import Metrics, TestLogEntry, TestParameters +from typing import TYPE_CHECKING, Any import ezmsg.core as ez -try: +if TYPE_CHECKING: + import numpy as np + import pandas as pd import xarray as xr - import pandas as pd # xarray depends on pandas -except ImportError: - ez.logger.error("ezmsg perf analysis requires xarray") - raise + from .envinfo import TestEnvironmentInfo -try: - import numpy as np -except ImportError: - ez.logger.error("ezmsg perf analysis requires numpy") - raise +xr: Any | None = None +pd: Any | None = None +np: Any | None = None + + +def _load_analysis_dependencies() -> tuple[Any, Any, Any]: + global xr, pd, np + + if xr is None or pd is None: + try: + import xarray as _xr + import pandas as _pd # xarray depends on pandas + except ImportError: + ez.logger.error("ezmsg perf analysis requires xarray") + raise + xr = _xr + pd = _pd + + if np is None: + try: + import numpy as _np + except ImportError: + ez.logger.error("ezmsg perf analysis requires numpy") + raise + np = _np + + return xr, pd, np TEST_DESCRIPTION = """ Configurations (config): @@ -112,7 +132,12 @@ class ReportBundle: def load_perf(perf: Path) -> xr.Dataset: - all_results: dict[TestParameters, dict[int, list[Metrics]]] = dict() + xr, _, np = _load_analysis_dependencies() + from ..messagecodec import MessageDecoder + from .envinfo import TestEnvironmentInfo + from .impl import Metrics, TestLogEntry + + all_results: dict[Any, dict[int, list[Any]]] = dict() run_idx = 0 with open(perf, "r") as perf_f: @@ -326,6 +351,9 @@ def _terminal_delta_summary( def build_report_bundle(perf_path: Path, baseline_path: Path | None = None) -> ReportBundle: + _load_analysis_dependencies() + from .envinfo import format_env_diff + candidate = load_perf(perf_path) info: TestEnvironmentInfo = candidate.attrs["info"] candidate_frame = _display_frame(_frame_from_dataset(candidate), relative=False) diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index 1bebc86d..d002ad74 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import sys import json @@ -14,20 +16,10 @@ import ezmsg.core as ez from ezmsg.core.graphserver import GraphServer -from ..messagecodec import MessageEncoder -from .envinfo import TestEnvironmentInfo -from .util import warmup -from .impl import ( - TestParameters, - TestLogEntry, - perform_test, - Communication, - CONFIGS, -) - DEFAULT_MSG_SIZES = [2**4, 2**20] DEFAULT_N_CLIENTS = [1, 16] -DEFAULT_COMMS = [c for c in Communication] +DEFAULT_COMMS = ["local", "shm", "tcp", "shm_spread", "tcp_spread"] +DEFAULT_CONFIGS = ["fanin", "fanout", "relay"] # --- Output Suppression Context Manager --- @@ -103,6 +95,18 @@ def output_paths_for_name(name: str) -> tuple[Path, Path]: return Path(f"perf_{name}.txt"), Path(f"report_{name}.html") +def warmup(*args, **kwargs): + from .util import warmup as _warmup + + return _warmup(*args, **kwargs) + + +def perform_test(**kwargs): + from .impl import perform_test as _perform_test + + return _perform_test(**kwargs) + + def benchmark( max_duration: float, num_msgs: int, @@ -118,6 +122,10 @@ def benchmark( name: str | None = None, open_browser: bool = True, ) -> tuple[Path, Path | None]: + from ..messagecodec import MessageEncoder + from .envinfo import TestEnvironmentInfo + from .impl import Communication, CONFIGS, TestLogEntry, TestParameters + if n_clients is None: n_clients = DEFAULT_N_CLIENTS if any(c < 0 for c in n_clients): @@ -339,7 +347,7 @@ def setup_benchmark_cmdline(subparsers: argparse._SubParsersAction) -> None: type=str, default=None, nargs="*", - help=f"communication strategies to test (default = {[c.value for c in DEFAULT_COMMS]})", + help=f"communication strategies to test (default = {DEFAULT_COMMS})", ) p_run.add_argument( @@ -347,7 +355,7 @@ def setup_benchmark_cmdline(subparsers: argparse._SubParsersAction) -> None: type=str, default=None, nargs="*", - help=f"configurations to test (default = {[c for c in CONFIGS]})", + help=f"configurations to test (default = {DEFAULT_CONFIGS})", ) p_run.add_argument( From c733893ee15880c8c06278e3767e49d6156e6a1d Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 6 Apr 2026 17:57:54 -0400 Subject: [PATCH 46/52] fixed race: _startup only true once system is at steady state --- src/ezmsg/core/backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ezmsg/core/backend.py b/src/ezmsg/core/backend.py index 7a06fbd0..44c8d457 100644 --- a/src/ezmsg/core/backend.py +++ b/src/ezmsg/core/backend.py @@ -667,7 +667,6 @@ def run_blocking(self) -> None: force_single_process=self._force_single_process, wait_for_ready=False ): return - self._started = True self._run_main_process() def _initialize(self, force_single_process: bool, wait_for_ready: bool) -> bool: @@ -767,11 +766,12 @@ def _run_main_process(self) -> None: if self._execution_context is None or self._loop is None: return self._main_process = self._execution_context.processes[0] - self._start_processes(self._execution_context.processes[1:]) interrupts = 0 forced_sigint = False try: + self._start_processes(self._execution_context.processes[1:]) + self._started = True self._main_process.process(self._loop) self._join_spawned_processes() logger.info("All processes exited normally") From 1d1a1bce84c56b71773eb0168a7204044bbc3052 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 7 Apr 2026 13:13:45 -0400 Subject: [PATCH 47/52] initial dashboard integration --- pyproject.toml | 12 +- src/ezmsg/core/command.py | 32 +++- src/ezmsg/core/commands/__init__.py | 2 + src/ezmsg/core/commands/common.py | 5 +- src/ezmsg/core/commands/dashboard.py | 77 +++++++++ src/ezmsg/core/commands/dashboard_cmd.py | 43 +++++ src/ezmsg/core/commands/serve.py | 19 ++- src/ezmsg/core/commands/start.py | 20 ++- tests/test_command.py | 98 ++++++++++- tests/test_dashboard_commands.py | 199 +++++++++++++++++++++++ 10 files changed, 495 insertions(+), 12 deletions(-) create mode 100644 src/ezmsg/core/commands/dashboard.py create mode 100644 src/ezmsg/core/commands/dashboard_cmd.py create mode 100644 tests/test_dashboard_commands.py diff --git a/pyproject.toml b/pyproject.toml index 37154bc7..193ec177 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,9 +52,6 @@ docs = [ axisarray = [ "numpy>=2.2.6", ] -perf = [ - "xarray", -] [project.scripts] ezmsg = "ezmsg.core.command:cmdline" @@ -63,6 +60,12 @@ ezmsg = "ezmsg.core.command:cmdline" axisarray = [ "numpy>=2.2.6", ] +perf = [ + "xarray>=2025.6.1", +] +dashboard = [ + "ezmsg-dashboard; python_version >= '3.11'", +] [tool.pytest.ini_options] addopts = ["--import-mode=importlib"] @@ -76,5 +79,8 @@ build-backend = "hatchling.build" [tool.hatch.metadata] allow-direct-references = true +[tool.uv.sources] +ezmsg-dashboard = { path = "../ezmsg-dashboard", editable = true } + [tool.hatch.build.targets.wheel] packages = ["src/ezmsg"] diff --git a/src/ezmsg/core/command.py b/src/ezmsg/core/command.py index f9c753d3..61aa2f46 100644 --- a/src/ezmsg/core/command.py +++ b/src/ezmsg/core/command.py @@ -10,11 +10,17 @@ from .commands.start import handle_start from .netprotocol import ( Address, + DEFAULT_HOST, GRAPHSERVER_ADDR_ENV, GRAPHSERVER_PORT_DEFAULT, PUBLISHER_START_PORT_ENV, PUBLISHER_START_PORT_DEFAULT, ) +from .commands.dashboard import ( + DASHBOARD_ADDR_ENV, + DASHBOARD_INSTALL_HINT, + DASHBOARD_PORT_DEFAULT, +) def build_parser() -> argparse.ArgumentParser: @@ -30,6 +36,7 @@ def build_parser() -> argparse.ArgumentParser: epilog=f""" You can also change server configuration with environment variables. GraphServer will be hosted on ${GRAPHSERVER_ADDR_ENV} (default port: {GRAPHSERVER_PORT_DEFAULT}). + Dashboard will be hosted on ${DASHBOARD_ADDR_ENV} (default: {DEFAULT_HOST}:{DASHBOARD_PORT_DEFAULT}, or graph port + 1). Publishers will be assigned available ports starting from {PUBLISHER_START_PORT_DEFAULT}. (Change with ${PUBLISHER_START_PORT_ENV}) """, ) @@ -55,7 +62,12 @@ def cmdline(argv: list[str] | None = None) -> None: result = args._handler(args) if inspect.isawaitable(result): - asyncio.run(result) + try: + asyncio.run(result) + except KeyboardInterrupt: + # asyncio.run() re-raises KeyboardInterrupt after cancelling the main + # task on Ctrl+C, even when command cleanup has already completed. + pass async def run_command( @@ -64,8 +76,10 @@ async def run_command( target: str = "live", compact: int | None = None, nobrowser: bool = False, + dashboard: int | bool | None = None, ) -> None: handlers = { + "dashboard": None, "serve": handle_serve, "start": handle_start, "shutdown": handle_shutdown, @@ -74,11 +88,25 @@ async def run_command( } if cmd not in handlers: raise ValueError(f"Unknown ezmsg command '{cmd}'") + if cmd == "dashboard": + try: + from ezmsg.dashboard.server import handle_dashboard + except ImportError as exc: + raise RuntimeError(DASHBOARD_INSTALL_HINT) from exc + handlers["dashboard"] = handle_dashboard args = argparse.Namespace( command=cmd, address=str(graph_address), + graph_address=str(graph_address), target=target, compact=compact, nobrowser=nobrowser, + dashboard=dashboard, + host="127.0.0.1", + port=8000, + open_browser=False, + log_level="info", ) - await handlers[cmd](args) + result = handlers[cmd](args) + if inspect.isawaitable(result): + await result diff --git a/src/ezmsg/core/commands/__init__.py b/src/ezmsg/core/commands/__init__.py index 7177e520..175c3c96 100644 --- a/src/ezmsg/core/commands/__init__.py +++ b/src/ezmsg/core/commands/__init__.py @@ -1,5 +1,6 @@ import argparse +from .dashboard_cmd import setup_dashboard_cmdline from .graphviz import setup_graphviz_cmdline from .mermaid import setup_mermaid_cmdline from .serve import setup_serve_cmdline @@ -8,6 +9,7 @@ def setup_core_cmdline(subparsers: argparse._SubParsersAction) -> None: + setup_dashboard_cmdline(subparsers) setup_serve_cmdline(subparsers) setup_start_cmdline(subparsers) setup_shutdown_cmdline(subparsers) diff --git a/src/ezmsg/core/commands/common.py b/src/ezmsg/core/commands/common.py index 2b8eaf51..8fe85010 100644 --- a/src/ezmsg/core/commands/common.py +++ b/src/ezmsg/core/commands/common.py @@ -1,6 +1,7 @@ import argparse -from ..netprotocol import Address, GRAPHSERVER_PORT_DEFAULT +from ..graphserver import GraphService +from ..netprotocol import Address def add_address_argument(parser: argparse.ArgumentParser) -> None: @@ -21,5 +22,5 @@ def add_compact_argument(parser: argparse.ArgumentParser) -> None: def graph_address_from_args(args: argparse.Namespace) -> Address: if args.address is None: - return Address("127.0.0.1", GRAPHSERVER_PORT_DEFAULT) + return GraphService.default_address() return Address.from_string(args.address) diff --git a/src/ezmsg/core/commands/dashboard.py b/src/ezmsg/core/commands/dashboard.py new file mode 100644 index 00000000..3997dc6e --- /dev/null +++ b/src/ezmsg/core/commands/dashboard.py @@ -0,0 +1,77 @@ +import argparse +import os +from typing import Any + +from ..netprotocol import ( + Address, + DEFAULT_HOST, + GRAPHSERVER_ADDR_ENV, + GRAPHSERVER_PORT_DEFAULT, +) + +DASHBOARD_ADDR_ENV = "EZMSG_DASHBOARD_ADDR" +DASHBOARD_PORT_DEFAULT = GRAPHSERVER_PORT_DEFAULT + 1 +DASHBOARD_INSTALL_HINT = ( + "Dashboard support requires the optional `ezmsg-dashboard` package. " + "Install it with `pip install ezmsg-dashboard`." +) + + +class DashboardDependencyError(RuntimeError): + pass + + +def add_dashboard_argument(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--dashboard", + nargs="?", + const=True, + default=None, + type=int, + metavar="PORT", + help=( + "Serve the optional ezmsg dashboard alongside the graph server. " + "If PORT is omitted, ezmsg uses the configured dashboard address or graph port + 1." + ), + ) + + +def default_graph_address() -> Address: + address_str = os.environ.get( + GRAPHSERVER_ADDR_ENV, f"{DEFAULT_HOST}:{GRAPHSERVER_PORT_DEFAULT}" + ) + return Address.from_string(address_str) + + +def dashboard_address( + graph_address: Address | None = None, dashboard_port: int | None = None +) -> Address: + if DASHBOARD_ADDR_ENV in os.environ: + address = Address.from_string(os.environ[DASHBOARD_ADDR_ENV]) + else: + resolved_graph_address = graph_address or default_graph_address() + address = Address(resolved_graph_address.host, resolved_graph_address.port + 1) + + if dashboard_port is not None: + return Address(address.host, dashboard_port) + return address + + +def require_dashboard_dependency() -> Any: + try: + from ezmsg.dashboard.server import start_dashboard_server + except ImportError as exc: + raise DashboardDependencyError(DASHBOARD_INSTALL_HINT) from exc + return start_dashboard_server + + +def start_dashboard(graph_address: Address, dashboard_port: int | None = None) -> Any: + start_dashboard_server = require_dashboard_dependency() + + address = dashboard_address(graph_address, dashboard_port=dashboard_port) + return start_dashboard_server( + graph_address=graph_address, + host=address.host, + port=address.port, + log_level="warning", + ) diff --git a/src/ezmsg/core/commands/dashboard_cmd.py b/src/ezmsg/core/commands/dashboard_cmd.py new file mode 100644 index 00000000..0ffb5f54 --- /dev/null +++ b/src/ezmsg/core/commands/dashboard_cmd.py @@ -0,0 +1,43 @@ +import argparse +import logging + +from .dashboard import DASHBOARD_INSTALL_HINT + +logger = logging.getLogger("ezmsg") + + +def _warn_dashboard_dependency_missing(_: argparse.Namespace) -> None: + logger.warning(DASHBOARD_INSTALL_HINT) + + +def _setup_dashboard_fallback(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser( + "dashboard", + help="launch the optional ezmsg dashboard server", + description="Launch the optional ezmsg dashboard server.", + ) + parser.add_argument("--graph-address", default=None, help="Address of the ezmsg graph server.") + parser.add_argument("--host", default="127.0.0.1", help="HTTP bind host for the dashboard.") + parser.add_argument("--port", type=int, default=8000, help="HTTP bind port for the dashboard.") + parser.add_argument( + "--open-browser", + action="store_true", + help="Open the dashboard in a browser after startup.", + ) + parser.add_argument( + "--log-level", + default="info", + choices=["critical", "error", "warning", "info", "debug", "trace"], + help="Uvicorn log verbosity.", + ) + parser.set_defaults(_handler=_warn_dashboard_dependency_missing) + + +def setup_dashboard_cmdline(subparsers: argparse._SubParsersAction) -> None: + try: + from ezmsg.dashboard.server import setup_dashboard_cmdline as setup_optional_dashboard + except ImportError: + _setup_dashboard_fallback(subparsers) + return + + setup_optional_dashboard(subparsers) diff --git a/src/ezmsg/core/commands/serve.py b/src/ezmsg/core/commands/serve.py index 075bad18..9ed5b959 100644 --- a/src/ezmsg/core/commands/serve.py +++ b/src/ezmsg/core/commands/serve.py @@ -4,6 +4,11 @@ from ..graphserver import GraphService from .common import add_address_argument, graph_address_from_args +from .dashboard import ( + DashboardDependencyError, + add_dashboard_argument, + start_dashboard, +) logger = logging.getLogger("ezmsg") @@ -14,17 +19,29 @@ async def handle_serve(args: argparse.Namespace) -> None: logger.info(f"GraphServer Address: {graph_address}") graph_server = graph_service.create_server() + dashboard_server = None try: + if args.dashboard is not None: + dashboard_port = args.dashboard if type(args.dashboard) is int else None + dashboard_server = start_dashboard( + graph_service.address, dashboard_port=dashboard_port + ) + logger.info(f"Dashboard Address: {dashboard_server.url}") logger.info("Servers running...") await asyncio.to_thread(graph_server.join) - except KeyboardInterrupt: + except (KeyboardInterrupt, asyncio.CancelledError): logger.info("Interrupt detected; shutting down servers") + except DashboardDependencyError as exc: + logger.warning(str(exc)) finally: + if dashboard_server is not None: + dashboard_server.stop() graph_server.stop() def setup_serve_cmdline(subparsers: argparse._SubParsersAction) -> None: parser = subparsers.add_parser("serve") add_address_argument(parser) + add_dashboard_argument(parser) parser.set_defaults(_handler=handle_serve) diff --git a/src/ezmsg/core/commands/start.py b/src/ezmsg/core/commands/start.py index 8133b0db..eb040d87 100644 --- a/src/ezmsg/core/commands/start.py +++ b/src/ezmsg/core/commands/start.py @@ -7,6 +7,11 @@ from ..graphserver import GraphService from ..netprotocol import close_stream_writer from .common import add_address_argument, graph_address_from_args +from .dashboard import ( + DashboardDependencyError, + add_dashboard_argument, + require_dashboard_dependency, +) logger = logging.getLogger("ezmsg") @@ -14,10 +19,18 @@ async def handle_start(args: argparse.Namespace) -> None: graph_address = graph_address_from_args(args) graph_service = GraphService(graph_address) + cmd = [sys.executable, "-m", "ezmsg.core", "serve", f"--address={graph_address}"] + if args.dashboard is not None: + try: + require_dashboard_dependency() + except DashboardDependencyError as exc: + logger.warning(str(exc)) + return + cmd.append("--dashboard") + if type(args.dashboard) is int: + cmd.append(str(args.dashboard)) - popen = subprocess.Popen( - [sys.executable, "-m", "ezmsg.core", "serve", f"--address={graph_address}"] - ) + popen = subprocess.Popen(cmd) while True: try: @@ -33,4 +46,5 @@ async def handle_start(args: argparse.Namespace) -> None: def setup_start_cmdline(subparsers: argparse._SubParsersAction) -> None: parser = subparsers.add_parser("start") add_address_argument(parser) + add_dashboard_argument(parser) parser.set_defaults(_handler=handle_start) diff --git a/tests/test_command.py b/tests/test_command.py index b74df4bd..cee71c59 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -1,7 +1,8 @@ import pytest +import argparse from pathlib import Path -from ezmsg.core.command import build_parser +from ezmsg.core.command import build_parser, cmdline def test_mermaid_subparser_accepts_mermaid_specific_args(): @@ -121,8 +122,103 @@ def test_serve_subparser_rejects_visualization_args(): parser.parse_args(["serve", "--target", "play"]) +def test_serve_subparser_accepts_dashboard_flag(): + parser = build_parser() + + args = parser.parse_args(["serve", "--dashboard"]) + + assert args.command == "serve" + assert args.dashboard is True + + +def test_start_subparser_accepts_dashboard_flag(): + parser = build_parser() + + args = parser.parse_args(["start", "--dashboard"]) + + assert args.command == "start" + assert args.dashboard is True + + +def test_serve_subparser_accepts_dashboard_port(): + parser = build_parser() + + args = parser.parse_args(["serve", "--dashboard", "28000"]) + + assert args.command == "serve" + assert args.dashboard == 28000 + + +def test_dashboard_subparser_accepts_dashboard_args(): + parser = build_parser() + + args = parser.parse_args( + [ + "dashboard", + "--graph-address", + "127.0.0.1:4000", + "--host", + "0.0.0.0", + "--port", + "28000", + "--open-browser", + "--log-level", + "debug", + ] + ) + + assert args.command == "dashboard" + assert args.graph_address == "127.0.0.1:4000" + assert args.host == "0.0.0.0" + assert args.port == 28000 + assert args.open_browser is True + assert args.log_level == "debug" + + def test_perf_subparser_rejects_core_only_args(): parser = build_parser() with pytest.raises(SystemExit): parser.parse_args(["perf", "benchmark", "--address", "127.0.0.1:4000"]) + + +def test_cmdline_suppresses_keyboard_interrupt_from_asyncio_run(monkeypatch): + class DummyParser: + def parse_args(self, args=None): + return argparse.Namespace(_handler=lambda parsed_args: object()) + + monkeypatch.setattr("ezmsg.core.command.build_parser", lambda: DummyParser()) + monkeypatch.setattr("ezmsg.core.command.inspect.isawaitable", lambda result: True) + + def raise_keyboard_interrupt(result): + raise KeyboardInterrupt + + monkeypatch.setattr("ezmsg.core.command.asyncio.run", raise_keyboard_interrupt) + + cmdline([]) + + +def test_dashboard_subcommand_warns_when_optional_dependency_missing(monkeypatch, caplog): + real_import = __import__ + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + if name == "ezmsg.dashboard.server": + raise ImportError("missing optional dashboard package") + return real_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr("builtins.__import__", fake_import) + monkeypatch.delitem(__import__("sys").modules, "ezmsg.dashboard.server", raising=False) + monkeypatch.delitem(__import__("sys").modules, "ezmsg.core.commands.dashboard_cmd", raising=False) + + from ezmsg.core.commands.dashboard_cmd import setup_dashboard_cmdline + + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="command", required=True) + setup_dashboard_cmdline(subparsers) + + args = parser.parse_args(["dashboard"]) + + with caplog.at_level("WARNING"): + args._handler(args) + + assert "pip install ezmsg-dashboard" in caplog.text diff --git a/tests/test_dashboard_commands.py b/tests/test_dashboard_commands.py new file mode 100644 index 00000000..7b392615 --- /dev/null +++ b/tests/test_dashboard_commands.py @@ -0,0 +1,199 @@ +import argparse +import sys +from types import SimpleNamespace + +import pytest + +from ezmsg.core.commands.dashboard import ( + DASHBOARD_ADDR_ENV, + DashboardDependencyError, + DASHBOARD_INSTALL_HINT, + dashboard_address, + require_dashboard_dependency, + start_dashboard, +) +from ezmsg.core.commands.start import handle_start +from ezmsg.core.commands.serve import handle_serve +from ezmsg.core.commands.common import graph_address_from_args +from ezmsg.core.netprotocol import Address +from ezmsg.core.graphserver import GraphService + + +def test_dashboard_address_defaults_to_graph_port_plus_one(): + graph_address = Address("127.0.0.1", 25978) + + assert dashboard_address(graph_address) == Address("127.0.0.1", 25979) + + +def test_dashboard_address_uses_environment_override(monkeypatch): + monkeypatch.setenv(DASHBOARD_ADDR_ENV, "0.0.0.0:4100") + + assert dashboard_address(Address("127.0.0.1", 25978)) == Address("0.0.0.0", 4100) + + +def test_dashboard_address_uses_explicit_port_with_graph_host(): + assert dashboard_address(Address("127.0.0.1", 30000), dashboard_port=30001) == Address( + "127.0.0.1", 30001 + ) + + +def test_dashboard_address_uses_explicit_port_with_env_host(monkeypatch): + monkeypatch.setenv(DASHBOARD_ADDR_ENV, "0.0.0.0:4100") + + assert dashboard_address(Address("127.0.0.1", 25978), dashboard_port=4101) == Address( + "0.0.0.0", 4101 + ) + + +def test_graph_address_from_args_uses_environment_override(monkeypatch): + monkeypatch.setenv("EZMSG_GRAPHSERVER_ADDR", "0.0.0.0:4101") + + assert graph_address_from_args(argparse.Namespace(address=None)) == Address( + "0.0.0.0", 4101 + ) + + +def test_require_dashboard_dependency_raises_helpful_error_when_package_missing(monkeypatch): + import builtins + + real_import = builtins.__import__ + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + if name == "ezmsg.dashboard.server": + raise ImportError("missing optional dashboard package") + return real_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + with pytest.raises(RuntimeError, match="pip install ezmsg-dashboard"): + require_dashboard_dependency() + + +def test_start_dashboard_raises_helpful_error_when_package_missing(monkeypatch): + monkeypatch.setattr( + "ezmsg.core.commands.dashboard.require_dashboard_dependency", + lambda: (_ for _ in ()).throw(DashboardDependencyError(DASHBOARD_INSTALL_HINT)), + ) + + with pytest.raises(RuntimeError, match="pip install ezmsg-dashboard"): + start_dashboard(Address("127.0.0.1", 25978)) + + +@pytest.mark.asyncio +async def test_handle_start_forwards_dashboard_flag(monkeypatch): + popen_calls: list[list[str]] = [] + + class DummyPopen: + def __init__(self, cmd): + popen_calls.append(cmd) + self.pid = 1234 + + async def fake_open_connection(self): + return object(), SimpleNamespace() + + async def fake_close_stream_writer(writer): + return None + + monkeypatch.setattr("ezmsg.core.commands.start.subprocess.Popen", DummyPopen) + monkeypatch.setattr( + "ezmsg.core.commands.start.GraphService.open_connection", fake_open_connection + ) + monkeypatch.setattr( + "ezmsg.core.commands.start.close_stream_writer", fake_close_stream_writer + ) + + args = argparse.Namespace(address="127.0.0.1:25978", dashboard=True) + await handle_start(args) + + assert popen_calls == [ + [ + sys.executable, + "-m", + "ezmsg.core", + "serve", + "--address=127.0.0.1:25978", + "--dashboard", + ] + ] + + +@pytest.mark.asyncio +async def test_handle_start_forwards_dashboard_port(monkeypatch): + popen_calls: list[list[str]] = [] + + class DummyPopen: + def __init__(self, cmd): + popen_calls.append(cmd) + self.pid = 1234 + + async def fake_open_connection(self): + return object(), SimpleNamespace() + + async def fake_close_stream_writer(writer): + return None + + monkeypatch.setattr("ezmsg.core.commands.start.subprocess.Popen", DummyPopen) + monkeypatch.setattr( + "ezmsg.core.commands.start.GraphService.open_connection", fake_open_connection + ) + monkeypatch.setattr( + "ezmsg.core.commands.start.close_stream_writer", fake_close_stream_writer + ) + + args = argparse.Namespace(address="127.0.0.1:25978", dashboard=28123) + await handle_start(args) + + assert popen_calls == [ + [ + sys.executable, + "-m", + "ezmsg.core", + "serve", + "--address=127.0.0.1:25978", + "--dashboard", + "28123", + ] + ] + + +@pytest.mark.asyncio +async def test_handle_start_warns_when_dashboard_dependency_missing(monkeypatch, caplog): + monkeypatch.setattr( + "ezmsg.core.commands.start.require_dashboard_dependency", + lambda: (_ for _ in ()).throw(DashboardDependencyError(DASHBOARD_INSTALL_HINT)), + ) + + args = argparse.Namespace(address="127.0.0.1:25978", dashboard=True) + + with caplog.at_level("WARNING"): + await handle_start(args) + + assert "pip install ezmsg-dashboard" in caplog.text + + +@pytest.mark.asyncio +async def test_handle_serve_warns_when_dashboard_dependency_missing(monkeypatch, caplog): + class DummyGraphServer: + def join(self): + return None + + def stop(self): + return None + + monkeypatch.setattr( + "ezmsg.core.commands.serve.GraphService.create_server", + lambda self: DummyGraphServer(), + ) + monkeypatch.setattr( + "ezmsg.core.commands.serve.start_dashboard", + lambda graph_address, dashboard_port=None: (_ for _ in ()).throw( + DashboardDependencyError(DASHBOARD_INSTALL_HINT) + ), + ) + + args = argparse.Namespace(address="127.0.0.1:25978", dashboard=True) + + with caplog.at_level("WARNING"): + await handle_serve(args) + + assert "pip install ezmsg-dashboard" in caplog.text From 4af9e06bbfbfc6d7bbe5d9eaf100fa4f34802199 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 7 Apr 2026 15:47:44 -0400 Subject: [PATCH 48/52] small bugfix --- src/ezmsg/util/perf/run.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index d002ad74..bc882759 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -145,9 +145,8 @@ def benchmark( ) try: - communications = ( - DEFAULT_COMMS if comms is None else [Communication(c) for c in comms] - ) + communication_names = DEFAULT_COMMS if comms is None else list(comms) + communications = [Communication(c) for c in communication_names] except ValueError: ez.logger.error( f"Invalid test communications requested. Valid communications: {', '.join([c.value for c in Communication])}" From 7078a38ea8e9b205f0239742f1464a0199096041 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 7 Apr 2026 15:54:08 -0400 Subject: [PATCH 49/52] removed local dashboard link --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 193ec177..48e0ea84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,8 +79,5 @@ build-backend = "hatchling.build" [tool.hatch.metadata] allow-direct-references = true -[tool.uv.sources] -ezmsg-dashboard = { path = "../ezmsg-dashboard", editable = true } - [tool.hatch.build.targets.wheel] packages = ["src/ezmsg"] From ab5b70a05bd308c548b4da1986fb17879666de32 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 7 Apr 2026 15:58:31 -0400 Subject: [PATCH 50/52] version bump --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 48e0ea84..f69689f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ezmsg" -version = "3.8.0" +version = "3.9.0" description = "A simple DAG-based computation model" authors = [ { name = "Griffin Milsap", email = "griffin.milsap@gmail.com" }, From 1343f8e4c1f762727b01567abbe77fe748e3846a Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 7 Apr 2026 16:50:10 -0400 Subject: [PATCH 51/52] funding acknowedgement --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 63d02446..b5f55477 100644 --- a/README.md +++ b/README.md @@ -95,4 +95,4 @@ These publications provide insights into the practical applications and impact o ## Financial Support -`ezmsg` is supported by Johns Hopkins University (JHU), the JHU Applied Physics Laboratory (APL), and by the Wyss Center for Bio and Neuro Engineering. +`ezmsg` is supported by Johns Hopkins University (JHU), the JHU Applied Physics Laboratory (APL), Blackrock Neurotech and by the Wyss Center for Bio and Neuro Engineering. From 20c7b08475e1d2647dba869e164e0d9ce0fadee5 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 7 Apr 2026 16:55:06 -0400 Subject: [PATCH 52/52] fix tests with optional dashboard dep --- tests/test_dashboard_commands.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_dashboard_commands.py b/tests/test_dashboard_commands.py index 7b392615..cd6b7d9d 100644 --- a/tests/test_dashboard_commands.py +++ b/tests/test_dashboard_commands.py @@ -95,6 +95,10 @@ async def fake_close_stream_writer(writer): return None monkeypatch.setattr("ezmsg.core.commands.start.subprocess.Popen", DummyPopen) + monkeypatch.setattr( + "ezmsg.core.commands.start.require_dashboard_dependency", + lambda: object(), + ) monkeypatch.setattr( "ezmsg.core.commands.start.GraphService.open_connection", fake_open_connection ) @@ -133,6 +137,10 @@ async def fake_close_stream_writer(writer): return None monkeypatch.setattr("ezmsg.core.commands.start.subprocess.Popen", DummyPopen) + monkeypatch.setattr( + "ezmsg.core.commands.start.require_dashboard_dependency", + lambda: object(), + ) monkeypatch.setattr( "ezmsg.core.commands.start.GraphService.open_connection", fake_open_connection )