diff --git a/scripts/perf_ab.py b/scripts/perf_ab.py new file mode 100644 index 00000000..5e7a8b8d --- /dev/null +++ b/scripts/perf_ab.py @@ -0,0 +1,5 @@ +from ezmsg.util.perf.ab import main + + +if __name__ == "__main__": + main() diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 8ef22a72..f3ab935d 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -299,6 +299,7 @@ async def setup_state(): buf_size=stream.buf_size, start_paused=True, force_tcp=stream.force_tcp, + allow_local=stream.allow_local, ), loop=loop, ).result() diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index f9c42952..86cf320e 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -37,6 +37,36 @@ BACKPRESSURE_WARNING = "EZMSG_DISABLE_BACKPRESSURE_WARNING" not in os.environ BACKPRESSURE_REFRACTORY = 5.0 # sec +ALLOW_LOCAL_ENV = "EZMSG_ALLOW_LOCAL" +FORCE_TCP_ENV = "EZMSG_FORCE_TCP" + + +def _process_allow_local_default() -> bool: + value = os.environ.get(ALLOW_LOCAL_ENV, "") + if value == "": + return True + return value.lower() in ("1", "true", "yes", "on") + + +def _process_force_tcp_default() -> bool: + value = os.environ.get(FORCE_TCP_ENV, "") + if value == "": + return False + return value.lower() in ("1", "true", "yes", "on") + + +def _resolve_force_tcp(force_tcp: bool | None) -> bool: + if force_tcp is None: + return _process_force_tcp_default() + return force_tcp + + +def _resolve_allow_local(force_tcp: bool, allow_local: bool | None) -> bool: + resolved = _process_allow_local_default() if allow_local is None else allow_local + if force_tcp and resolved: + logger.info("force_tcp=True disables local delivery for this publisher") + return False + return resolved # Publisher needs a bit more information about connected channels @@ -75,6 +105,7 @@ class Publisher: _msg_id: int _shm: SHMContext _force_tcp: bool + _allow_local: bool _last_backpressure_event: float _graph_address: 
AddressType | None @@ -99,7 +130,8 @@ async def create( buf_size: int = DEFAULT_SHM_SIZE, num_buffers: int = 32, start_paused: bool = False, - force_tcp: bool = False, + force_tcp: bool | None = None, + allow_local: bool | None = None, ) -> "Publisher": """ Create a new Publisher instance and register it with the graph server. @@ -116,6 +148,16 @@ async def create( :type port: int | None :param buf_size: Size of shared memory buffers. :type buf_size: int + :param force_tcp: Whether to force TCP transport instead of shared memory. + If None, inherit the process default from ``EZMSG_FORCE_TCP`` which + defaults to disabled. + :type force_tcp: bool | None + :param allow_local: Whether to allow the in-process fast path when available. + If None, inherit the process default from ``EZMSG_ALLOW_LOCAL`` which + defaults to enabled. Set to False to bypass local delivery and + characterize same-process SHM or TCP. When ``force_tcp=True``, local + delivery is disabled regardless of this value. + :type allow_local: bool | None :param kwargs: Additional keyword arguments for Publisher constructor. :return: Initialized and registered Publisher instance. :rtype: Publisher @@ -127,6 +169,8 @@ async def create( writer.write(Command.PUBLISH.value) writer.write(encode_str(topic)) + resolved_force_tcp = _resolve_force_tcp(force_tcp) + pub_id = UUID(await read_str(reader)) pub = cls( id=pub_id, @@ -135,7 +179,8 @@ async def create( graph_address=graph_address, num_buffers=num_buffers, start_paused=start_paused, - force_tcp=force_tcp, + force_tcp=resolved_force_tcp, + allow_local=allow_local, _guard=cls._SENTINEL, ) @@ -189,7 +234,8 @@ def __init__( graph_address: AddressType | None = None, num_buffers: int = 32, start_paused: bool = False, - force_tcp: bool = False, + force_tcp: bool | None = None, + allow_local: bool | None = None, _guard = None ) -> None: """ @@ -207,7 +253,12 @@ def __init__( :param start_paused: Whether to start in paused state. 
:type start_paused: bool :param force_tcp: Whether to force TCP transport instead of shared memory. - :type force_tcp: bool + If None, inherit the process default from ``EZMSG_FORCE_TCP``. + :type force_tcp: bool | None + :param allow_local: Whether to allow the direct in-process fast path when available. + If None, inherit the process default from ``EZMSG_ALLOW_LOCAL``. + When ``force_tcp=True``, local delivery is disabled regardless of this value. + :type allow_local: bool | None """ if _guard is not self._SENTINEL: raise TypeError( @@ -227,7 +278,8 @@ def __init__( self._running.set() self._num_buffers = num_buffers self._backpressure = Backpressure(num_buffers) - self._force_tcp = force_tcp + self._force_tcp = _resolve_force_tcp(force_tcp) + self._allow_local = _resolve_allow_local(self._force_tcp, allow_local) self._last_backpressure_event = -1 self._graph_address = graph_address @@ -436,12 +488,10 @@ async def broadcast(self, obj: Any) -> None: self._last_backpressure_event = time.time() await self._backpressure.wait(buf_idx) - # Get local channel and put variable there for local tx - self._local_channel.put_local(self._msg_id, obj) + if self._should_use_local_fast_path(): + self._local_channel.put_local(self._msg_id, obj) - if self._force_tcp or any( - ch.pid != self.pid or not ch.shm_ok for ch in self._channels.values() - ): + if any(not self._can_deliver_locally(ch) for ch in self._channels.values()): with MessageMarshal.serialize(self._msg_id, obj) as ( total_size, header, @@ -449,9 +499,7 @@ async def broadcast(self, obj: Any) -> None: ): total_size_bytes = uint64_to_bytes(total_size) - if not self._force_tcp and any( - ch.pid != self.pid and ch.shm_ok for ch in self._channels.values() - ): + if any(self._can_deliver_via_shm(ch) for ch in self._channels.values()): if self._shm.buf_size < total_size: new_shm = await GraphService(self._graph_address).create_shm( self._num_buffers, total_size * 2 @@ -475,14 +523,10 @@ async def broadcast(self, obj: Any) -> 
None: for channel in self._channels.values(): msg: bytes = b"" - if self.pid == channel.pid and channel.shm_ok: + if self._can_deliver_locally(channel): continue # Local transmission handled by channel.put - elif ( - (not self._force_tcp) - and self.pid != channel.pid - and channel.shm_ok - ): + elif self._can_deliver_via_shm(channel): msg = ( Command.TX_SHM.value + msg_id_bytes @@ -509,3 +553,16 @@ async def broadcast(self, obj: Any) -> None: ) self._msg_id += 1 + + def _should_use_local_fast_path(self) -> bool: + return any(self._can_deliver_locally(ch) for ch in self._channels.values()) + + def _can_deliver_locally(self, channel: PubChannelInfo) -> bool: + return self._allow_local and self.pid == channel.pid and channel.shm_ok + + def _can_deliver_via_shm(self, channel: PubChannelInfo) -> bool: + return ( + (not self._force_tcp) + and channel.shm_ok + and not self._can_deliver_locally(channel) + ) diff --git a/src/ezmsg/core/stream.py b/src/ezmsg/core/stream.py index c719c92a..79089d86 100644 --- a/src/ezmsg/core/stream.py +++ b/src/ezmsg/core/stream.py @@ -124,15 +124,20 @@ class OutputStream(Stream): :type num_buffers: int :param buf_size: Size of each message buffer in bytes :type buf_size: int - :param force_tcp: Whether to force TCP transport instead of shared memory - :type force_tcp: bool + :param force_tcp: Whether to force TCP transport instead of shared memory. + If None, inherit the process default from ``EZMSG_FORCE_TCP``. + :type force_tcp: bool | None + :param allow_local: Whether to allow the in-process fast path when available. + If None, inherit the process default from ``EZMSG_ALLOW_LOCAL``. 
+ :type allow_local: bool | None """ host: str | None port: int | None num_buffers: int buf_size: int - force_tcp: bool + force_tcp: bool | None + allow_local: bool | None def __init__( self, @@ -141,7 +146,8 @@ def __init__( port: int | None = None, num_buffers: int = 32, buf_size: int = DEFAULT_SHM_SIZE, - force_tcp: bool = False, + force_tcp: bool | None = None, + allow_local: bool | None = None, ) -> None: super().__init__(msg_type) self.host = host @@ -149,7 +155,11 @@ def __init__( self.num_buffers = num_buffers self.buf_size = buf_size self.force_tcp = force_tcp + self.allow_local = allow_local def __repr__(self) -> str: preamble = f"Output{super().__repr__()}" - return f"{preamble}({self.num_buffers=}, {self.force_tcp=})" + return ( + f"{preamble}({self.num_buffers=}, {self.force_tcp=}, " + f"{self.allow_local=})" + ) diff --git a/src/ezmsg/util/perf/ab.py b/src/ezmsg/util/perf/ab.py new file mode 100644 index 00000000..5a2da621 --- /dev/null +++ b/src/ezmsg/util/perf/ab.py @@ -0,0 +1,446 @@ +from __future__ import annotations + +import argparse +import contextlib +import json +import os +import random +import shutil +import subprocess +import sys +import tempfile + +from dataclasses import asdict, dataclass +from pathlib import Path + + +DEFAULT_PAIR_SEED = 0 + + +@dataclass(frozen=True) +class ABCaseSummary: + case_id: str + a_us_per_message_median: float + b_us_per_message_median: float + delta_pct_median: float + delta_pct_mean: float + pair_count: int + b_faster_pairs: int + + +@dataclass(frozen=True) +class ABRunSummary: + ref_a: str + ref_b: str + rounds: int + seed: int + cases: list[ABCaseSummary] + + +def build_pair_order(rounds: int, seed: int) -> list[tuple[str, str]]: + base = [("A", "B"), ("B", "A")] * ((rounds + 1) // 2) + order = base[:rounds] + random.Random(seed).shuffle(order) + return order + + +def _hotpath_json_arg(path: Path) -> list[str]: + return ["--json-out", str(path)] + + +def build_hotpath_command( + output_path: Path, + count: 
int, + warmup: int, + payload_sizes: list[int], + transports: list[str], + apis: list[str], + num_buffers: int, + quiet: bool, +) -> list[str]: + cmd = [ + "uv", + "run", + "python", + "-m", + "ezmsg.util.perf.hotpath", + "--count", + str(count), + "--warmup", + str(warmup), + "--samples", + "1", + "--num-buffers", + str(num_buffers), + "--payload-sizes", + *[str(payload_size) for payload_size in payload_sizes], + "--transports", + *transports, + "--apis", + *apis, + *_hotpath_json_arg(output_path), + ] + if quiet: + cmd.append("--quiet") + return cmd + + +def load_hotpath_summary(path: Path) -> dict[str, float]: + payload = json.loads(path.read_text()) + return { + entry["case_id"]: float(entry["summary"]["us_per_message_median"]) + for entry in payload["results"] + } + + +def summarize_ab_results( + ref_a: str, + ref_b: str, + rounds: int, + seed: int, + paired_runs: list[tuple[dict[str, float], dict[str, float]]], +) -> ABRunSummary: + case_ids = sorted(paired_runs[0][0].keys()) + cases: list[ABCaseSummary] = [] + + for case_id in case_ids: + a_values = [pair[0][case_id] for pair in paired_runs] + b_values = [pair[1][case_id] for pair in paired_runs] + deltas = [((b / a) - 1.0) * 100.0 for a, b in zip(a_values, b_values)] + cases.append( + ABCaseSummary( + case_id=case_id, + a_us_per_message_median=_median(a_values), + b_us_per_message_median=_median(b_values), + delta_pct_median=_median(deltas), + delta_pct_mean=sum(deltas) / len(deltas), + pair_count=len(deltas), + b_faster_pairs=sum(1 for a, b in zip(a_values, b_values) if b < a), + ) + ) + + return ABRunSummary( + ref_a=ref_a, + ref_b=ref_b, + rounds=rounds, + seed=seed, + cases=cases, + ) + + +def _median(values: list[float]) -> float: + ordered = sorted(values) + mid = len(ordered) // 2 + if len(ordered) % 2: + return ordered[mid] + return (ordered[mid - 1] + ordered[mid]) / 2.0 + + +def _run_checked(cmd: list[str], cwd: Path) -> None: + env = os.environ.copy() + env.pop("VIRTUAL_ENV", None) + completed = 
subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, env=env) + if completed.returncode == 0: + return + + raise RuntimeError( + f"Command failed in {cwd}:\n" + f"$ {' '.join(cmd)}\n\n" + f"stdout:\n{completed.stdout}\n" + f"stderr:\n{completed.stderr}" + ) + + +def _is_current_ref(ref: str) -> bool: + return ref.upper() == "CURRENT" + + +@contextlib.contextmanager +def _provision_tree( + repo_root: Path, + ref: str, + label: str, + keep: bool, +) -> Path: + if _is_current_ref(ref): + yield repo_root + return + + parent = Path(tempfile.mkdtemp(prefix=f"ezmsg-perf-{label.lower()}-")) + tree_path = parent / "tree" + _run_checked( + ["git", "worktree", "add", "--detach", str(tree_path), ref], + cwd=repo_root, + ) + + try: + yield tree_path + finally: + if keep: + return + try: + _run_checked(["git", "worktree", "remove", "--force", str(tree_path)], cwd=repo_root) + finally: + shutil.rmtree(parent, ignore_errors=True) + + +def _maybe_sync(tree: Path) -> None: + _run_checked(["uv", "sync", "--group", "dev"], cwd=tree) + + +def _mirror_hotpath_module(source_root: Path, target_tree: Path) -> None: + source = source_root / "src" / "ezmsg" / "util" / "perf" / "hotpath.py" + target = target_tree / "src" / "ezmsg" / "util" / "perf" / "hotpath.py" + shutil.copy2(source, target) + + +def _ensure_json_files_match( + left: dict[str, float], + right: dict[str, float], + label_left: str, + label_right: str, +) -> None: + if left.keys() == right.keys(): + return + + raise RuntimeError( + f"Benchmark cases differ between {label_left} and {label_right}: " + f"{sorted(left.keys())} != {sorted(right.keys())}" + ) + + +def _print_summary(summary: ABRunSummary) -> None: + print( + f"Interleaved hot-path comparison: A={summary.ref_a}, " + f"B={summary.ref_b}, rounds={summary.rounds}, seed={summary.seed}" + ) + for case in summary.cases: + sign = "regression" if case.delta_pct_median > 0 else "improvement" + print( + f"{case.case_id:<36} " + f"A={case.a_us_per_message_median:>10.2f} 
us/msg " + f"B={case.b_us_per_message_median:>10.2f} us/msg " + f"delta={case.delta_pct_median:>7.2f}% ({sign}) " + f"wins={case.b_faster_pairs}/{case.pair_count}" + ) + + +def dump_ab_json(summary: ABRunSummary, path: Path) -> None: + payload = { + "suite": "hotpath-ab", + "ref_a": summary.ref_a, + "ref_b": summary.ref_b, + "rounds": summary.rounds, + "seed": summary.seed, + "cases": [asdict(case) for case in summary.cases], + } + path.write_text(json.dumps(payload, indent=2) + "\n") + + +def perf_ab( + ref_a: str, + ref_b: str, + rounds: int, + count: int, + warmup: int, + prewarm: int, + payload_sizes: list[int], + transports: list[str], + apis: list[str], + num_buffers: int, + seed: int, + json_out: Path | None, + keep_worktrees: bool, + sync: bool, + quiet: bool, +) -> None: + if rounds <= 0: + raise ValueError("rounds must be > 0") + + repo_root = Path( + subprocess.run( + ["git", "rev-parse", "--show-toplevel"], + check=True, + capture_output=True, + text=True, + ).stdout.strip() + ) + pair_order = build_pair_order(rounds, seed) + + with _provision_tree(repo_root, ref_a, "A", keep_worktrees) as tree_a: + with _provision_tree(repo_root, ref_b, "B", keep_worktrees) as tree_b: + if tree_a != repo_root: + _mirror_hotpath_module(repo_root, tree_a) + if tree_b != repo_root: + _mirror_hotpath_module(repo_root, tree_b) + + if sync: + _maybe_sync(tree_a) + if tree_b != tree_a: + _maybe_sync(tree_b) + + with tempfile.TemporaryDirectory(prefix="ezmsg-perf-ab-runs-") as tmpdir_name: + tmpdir = Path(tmpdir_name) + cmd_by_label = { + "A": lambda path: build_hotpath_command( + path, + count=count, + warmup=warmup, + payload_sizes=payload_sizes, + transports=transports, + apis=apis, + num_buffers=num_buffers, + quiet=quiet, + ), + "B": lambda path: build_hotpath_command( + path, + count=count, + warmup=warmup, + payload_sizes=payload_sizes, + transports=transports, + apis=apis, + num_buffers=num_buffers, + quiet=quiet, + ), + } + tree_by_label = {"A": tree_a, "B": tree_b} + 
+ for idx in range(prewarm): + for label in ("A", "B"): + if label == "B" and tree_b == tree_a: + continue + warm_path = tmpdir / f"warm-{label}-{idx}.json" + _run_checked(cmd_by_label[label](warm_path), cwd=tree_by_label[label]) + + paired_runs: list[tuple[dict[str, float], dict[str, float]]] = [] + for round_idx, (first, second) in enumerate(pair_order, start=1): + outputs: dict[str, dict[str, float]] = {} + for label in (first, second): + output_path = tmpdir / f"round-{round_idx:02d}-{label}.json" + _run_checked(cmd_by_label[label](output_path), cwd=tree_by_label[label]) + outputs[label] = load_hotpath_summary(output_path) + + _ensure_json_files_match(outputs["A"], outputs["B"], ref_a, ref_b) + paired_runs.append((outputs["A"], outputs["B"])) + + summary = summarize_ab_results(ref_a, ref_b, rounds, seed, paired_runs) + _print_summary(summary) + if json_out is not None: + dump_ab_json(summary, json_out) + print(f"Wrote JSON results to {json_out}") + + +def setup_ab_cmdline(subparsers: argparse._SubParsersAction) -> None: + p_ab = subparsers.add_parser( + "ab", + help="run interleaved A/B hot-path comparisons using git worktrees", + ) + p_ab.add_argument("--ref-a", default="dev", help="baseline git ref or CURRENT") + p_ab.add_argument("--ref-b", default="CURRENT", help="candidate git ref or CURRENT") + p_ab.add_argument( + "--rounds", + type=int, + default=6, + help="number of A/B pairs to run (default = 6)", + ) + p_ab.add_argument( + "--count", + type=int, + default=2_000, + help="messages per hot-path sample (default = 2000)", + ) + p_ab.add_argument( + "--warmup", + type=int, + default=200, + help="warmup messages per hot-path sample (default = 200)", + ) + p_ab.add_argument( + "--prewarm", + type=int, + default=1, + help="unmeasured warmup invocations per side (default = 1)", + ) + p_ab.add_argument( + "--payload-sizes", + nargs="*", + type=int, + default=[64, 4096], + help="payload sizes in bytes (default = [64, 4096])", + ) + p_ab.add_argument( + 
"--transports", + nargs="*", + choices=["local", "shm", "tcp"], + default=["local", "shm", "tcp"], + help="transports to compare (default = ['local', 'shm', 'tcp'])", + ) + p_ab.add_argument( + "--apis", + nargs="*", + choices=["async", "sync"], + default=["async"], + help="apis to compare (default = ['async'])", + ) + p_ab.add_argument( + "--num-buffers", + type=int, + default=1, + help="publisher buffers (default = 1)", + ) + p_ab.add_argument( + "--seed", + type=int, + default=DEFAULT_PAIR_SEED, + help="pair-order shuffle seed (default = 0)", + ) + p_ab.add_argument( + "--json-out", + type=Path, + default=None, + help="optional JSON output path", + ) + p_ab.add_argument( + "--keep-worktrees", + action="store_true", + help="leave auto-provisioned worktrees on disk for inspection", + ) + p_ab.add_argument( + "--sync", + action="store_true", + help="run 'uv sync --group dev' in each provisioned worktree first", + ) + p_ab.add_argument( + "--quiet", + action="store_true", + help="suppress ezmsg runtime logs in child benchmark runs", + ) + p_ab.set_defaults( + _handler=lambda ns: perf_ab( + ref_a=ns.ref_a, + ref_b=ns.ref_b, + rounds=ns.rounds, + count=ns.count, + warmup=ns.warmup, + prewarm=ns.prewarm, + payload_sizes=ns.payload_sizes, + transports=ns.transports, + apis=ns.apis, + num_buffers=ns.num_buffers, + seed=ns.seed, + json_out=ns.json_out, + keep_worktrees=ns.keep_worktrees, + sync=ns.sync, + quiet=ns.quiet, + ) + ) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Run interleaved ezmsg hot-path A/B comparisons." 
+ ) + subparsers = parser.add_subparsers(dest="command", required=True) + setup_ab_cmdline(subparsers) + ns = parser.parse_args(["ab", *sys.argv[1:]]) + ns._handler(ns) diff --git a/src/ezmsg/util/perf/command.py b/src/ezmsg/util/perf/command.py index 21fed7eb..9dab1f8e 100644 --- a/src/ezmsg/util/perf/command.py +++ b/src/ezmsg/util/perf/command.py @@ -1,6 +1,8 @@ import argparse +from .ab import setup_ab_cmdline from .analysis import setup_summary_cmdline +from .hotpath import setup_hotpath_cmdline from .run import setup_run_cmdline @@ -9,6 +11,8 @@ def command() -> None: subparsers = parser.add_subparsers(dest="command", required=True) setup_run_cmdline(subparsers) + setup_hotpath_cmdline(subparsers) + setup_ab_cmdline(subparsers) setup_summary_cmdline(subparsers) ns = parser.parse_args() diff --git a/src/ezmsg/util/perf/hotpath.py b/src/ezmsg/util/perf/hotpath.py new file mode 100644 index 00000000..74585981 --- /dev/null +++ b/src/ezmsg/util/perf/hotpath.py @@ -0,0 +1,437 @@ +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import inspect +import json +import logging +import random +import statistics +import sys +import time + +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Literal +from uuid import uuid4 + +import ezmsg.core as ez + +from ezmsg.core.graphserver import GraphServer + +from .util import coef_var, median_of_means, stable_perf + +ApiName = Literal["async", "sync"] +TransportName = Literal["local", "shm", "tcp"] + +DEFAULT_APIS: tuple[ApiName, ...] = ("async",) +DEFAULT_TRANSPORTS: tuple[TransportName, ...] 
= ("local", "shm", "tcp") +DEFAULT_PAYLOAD_SIZES = (64, 4096) + + +def _supports_same_process_transport_selection() -> bool: + return "allow_local" in inspect.signature(ez.Publisher.create).parameters + + +def _validate_transport_support(case: "HotPathCase") -> None: + if case.transport == "local": + return + if _supports_same_process_transport_selection(): + return + raise RuntimeError( + "This ref does not support bypassing the same-process local fast path, " + f"so '{case.case_id}' cannot be benchmarked here. Compare against a ref " + "that includes the allow_local transport-selection support, or limit the " + "run to '--transports local'." + ) + + +def _publisher_transport_kwargs(case: "HotPathCase", host: str) -> dict[str, object]: + kwargs: dict[str, object] = { + "host": host, + "num_buffers": case.num_buffers, + "force_tcp": (case.transport == "tcp"), + } + if _supports_same_process_transport_selection(): + kwargs["allow_local"] = case.transport == "local" + return kwargs + + +@dataclass(frozen=True) +class HotPathCase: + api: ApiName + transport: TransportName + payload_size: int + num_buffers: int + + @property + def case_id(self) -> str: + return ( + f"{self.api}/{self.transport}/payload={self.payload_size}" + f"/buffers={self.num_buffers}" + ) + + +@dataclass(frozen=True) +class HotPathSummary: + samples: int + seconds_median: float + seconds_mean: float + seconds_min: float + seconds_max: float + seconds_median_of_means: float + seconds_cv: float + us_per_message_median: float + us_per_message_mean: float + messages_per_second_median: float + messages_per_second_mean: float + + +@dataclass(frozen=True) +class HotPathCaseResult: + case: HotPathCase + count: int + warmup: int + samples_seconds: list[float] + summary: HotPathSummary + + +@dataclass(frozen=True) +class HotPathSuiteResult: + seed: int + count: int + warmup: int + results: list[HotPathCaseResult] + + +@contextlib.contextmanager +def _quiet_ezmsg_logs(enabled: bool): + if not enabled: + yield 
+ return + + old_level = ez.logger.level + ez.logger.setLevel(logging.WARNING) + try: + yield + finally: + ez.logger.setLevel(old_level) + + +def build_cases( + apis: list[str], + transports: list[str], + payload_sizes: list[int], + num_buffers: int, +) -> list[HotPathCase]: + cases = [ + HotPathCase( + api=api, # type: ignore[arg-type] + transport=transport, # type: ignore[arg-type] + payload_size=payload_size, + num_buffers=num_buffers, + ) + for api in apis + for transport in transports + for payload_size in payload_sizes + if api in ("async", "sync") and transport in ("local", "shm", "tcp") + ] + return sorted(cases, key=lambda case: case.case_id) + + +def summarize_samples(samples: list[float], count: int) -> HotPathSummary: + us_per_message = [(sample / count) * 1e6 for sample in samples] + rates = [count / sample for sample in samples] + return HotPathSummary( + samples=len(samples), + seconds_median=statistics.median(samples), + seconds_mean=statistics.fmean(samples), + seconds_min=min(samples), + seconds_max=max(samples), + seconds_median_of_means=median_of_means(samples), + seconds_cv=coef_var(samples), + us_per_message_median=statistics.median(us_per_message), + us_per_message_mean=statistics.fmean(us_per_message), + messages_per_second_median=statistics.median(rates), + messages_per_second_mean=statistics.fmean(rates), + ) + + +async def _async_roundtrip( + case: HotPathCase, + count: int, + warmup: int, + graph_address: tuple[str, int], +) -> float: + _validate_transport_support(case) + topic = f"/EZMSG/PERF/HOTPATH/{uuid4().hex}" + payload = bytes(case.payload_size) + + async with ez.GraphContext(graph_address, auto_start=False) as ctx: + pub = await ctx.publisher(topic, **_publisher_transport_kwargs(case, graph_address[0])) + sub = await ctx.subscriber(topic) + + for _ in range(warmup): + await pub.broadcast(payload) + await sub.recv() + + start = time.perf_counter() + for _ in range(count): + await pub.broadcast(payload) + await sub.recv() + return 
time.perf_counter() - start + + +def _sync_roundtrip( + case: HotPathCase, + count: int, + warmup: int, + graph_address: tuple[str, int], +) -> float: + _validate_transport_support(case) + topic = f"/EZMSG/PERF/HOTPATH/{uuid4().hex}" + payload = bytes(case.payload_size) + + with ez.sync.init(graph_address, auto_start=False) as ctx: + pub = ctx.create_publisher(topic, **_publisher_transport_kwargs(case, graph_address[0])) + sub = ctx.create_subscription(topic) + + for _ in range(warmup): + pub.publish(payload) + sub.recv() + + start = time.perf_counter() + for _ in range(count): + pub.publish(payload) + sub.recv() + return time.perf_counter() - start + + +def run_hotpath_case( + case: HotPathCase, + count: int, + warmup: int, + samples: int, + graph_address: tuple[str, int], +) -> HotPathCaseResult: + results: list[float] = [] + + for _ in range(samples): + with stable_perf(): + if case.api == "async": + elapsed = asyncio.run( + _async_roundtrip(case, count, warmup, graph_address) + ) + else: + elapsed = _sync_roundtrip(case, count, warmup, graph_address) + results.append(elapsed) + + return HotPathCaseResult( + case=case, + count=count, + warmup=warmup, + samples_seconds=results, + summary=summarize_samples(results, count), + ) + + +def _format_case_result(result: HotPathCaseResult) -> str: + summary = result.summary + return ( + f"{result.case.case_id:<36} " + f"{summary.us_per_message_median:>10.2f} us/msg " + f"{summary.messages_per_second_median:>12,.0f} msg/s " + f"cv={summary.seconds_cv:>5.3f}" + ) + + +def run_hotpath_suite( + count: int, + warmup: int, + samples: int, + apis: list[str], + transports: list[str], + payload_sizes: list[int], + num_buffers: int, + seed: int, + quiet: bool, +) -> HotPathSuiteResult: + rng = random.Random(seed) + cases = build_cases(apis, transports, payload_sizes, num_buffers) + rng.shuffle(cases) + + graph_server = GraphServer() + graph_server.start(("127.0.0.1", 0)) + try: + with _quiet_ezmsg_logs(quiet): + results = [ + 
run_hotpath_case( + case, + count=count, + warmup=warmup, + samples=samples, + graph_address=graph_server.address, + ) + for case in cases + ] + finally: + graph_server.stop() + + results.sort(key=lambda result: result.case.case_id) + return HotPathSuiteResult(seed=seed, count=count, warmup=warmup, results=results) + + +def dump_suite_json(result: HotPathSuiteResult, path: Path) -> None: + payload = { + "suite": "hotpath", + "seed": result.seed, + "count": result.count, + "warmup": result.warmup, + "results": [ + { + "case": asdict(case_result.case), + "case_id": case_result.case.case_id, + "count": case_result.count, + "warmup": case_result.warmup, + "samples_seconds": case_result.samples_seconds, + "summary": asdict(case_result.summary), + } + for case_result in result.results + ], + } + path.write_text(json.dumps(payload, indent=2) + "\n") + + +def perf_hotpath( + count: int, + warmup: int, + samples: int, + apis: list[str], + transports: list[str], + payload_sizes: list[int], + num_buffers: int, + seed: int, + json_out: Path | None, + quiet: bool, +) -> None: + result = run_hotpath_suite( + count=count, + warmup=warmup, + samples=samples, + apis=apis, + transports=transports, + payload_sizes=payload_sizes, + num_buffers=num_buffers, + seed=seed, + quiet=quiet, + ) + + print("Hot-path roundtrip benchmark") + print( + f"count={count}, warmup={warmup}, samples={samples}, " + f"payload_sizes={payload_sizes}, transports={transports}, apis={apis}" + ) + for case_result in result.results: + print(_format_case_result(case_result)) + + if json_out is not None: + dump_suite_json(result, json_out) + print(f"Wrote JSON results to {json_out}") + + +def setup_hotpath_cmdline(subparsers: argparse._SubParsersAction) -> None: + p_hotpath = subparsers.add_parser( + "hotpath", + help="run fast, focused hot-path roundtrip benchmarks", + ) + p_hotpath.add_argument( + "--count", + type=int, + default=5_000, + help="messages per sample (default = 5000)", + ) + p_hotpath.add_argument( 
+ "--warmup", + type=int, + default=500, + help="warmup messages per sample (default = 500)", + ) + p_hotpath.add_argument( + "--samples", + type=int, + default=5, + help="timed samples per case (default = 5)", + ) + p_hotpath.add_argument( + "--apis", + nargs="*", + choices=DEFAULT_APIS + ("sync",), + default=list(DEFAULT_APIS), + help=f"apis to benchmark (default = {list(DEFAULT_APIS)})", + ) + p_hotpath.add_argument( + "--transports", + nargs="*", + choices=DEFAULT_TRANSPORTS, + default=list(DEFAULT_TRANSPORTS), + help=f"transports to benchmark (default = {list(DEFAULT_TRANSPORTS)})", + ) + p_hotpath.add_argument( + "--payload-sizes", + nargs="*", + type=int, + default=list(DEFAULT_PAYLOAD_SIZES), + help=f"payload sizes in bytes (default = {list(DEFAULT_PAYLOAD_SIZES)})", + ) + p_hotpath.add_argument( + "--num-buffers", + type=int, + default=1, + help="publisher buffers (default = 1)", + ) + p_hotpath.add_argument( + "--seed", + type=int, + default=0, + help="shuffle seed for case order (default = 0)", + ) + p_hotpath.add_argument( + "--json-out", + type=Path, + default=None, + help="optional JSON output path", + ) + p_hotpath.add_argument( + "--quiet", + action="store_true", + help="suppress ezmsg runtime logs during the benchmark", + ) + p_hotpath.set_defaults( + _handler=lambda ns: perf_hotpath( + count=ns.count, + warmup=ns.warmup, + samples=ns.samples, + apis=ns.apis, + transports=ns.transports, + payload_sizes=ns.payload_sizes, + num_buffers=ns.num_buffers, + seed=ns.seed, + json_out=ns.json_out, + quiet=ns.quiet, + ) + ) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Run ezmsg hot-path roundtrip benchmarks." 
+ ) + subparsers = parser.add_subparsers(dest="command", required=True) + setup_hotpath_cmdline(subparsers) + ns = parser.parse_args(["hotpath", *sys.argv[1:]]) + ns._handler(ns) + + +if __name__ == "__main__": + main() diff --git a/tests/test_perf_ab.py b/tests/test_perf_ab.py new file mode 100644 index 00000000..3fbb935c --- /dev/null +++ b/tests/test_perf_ab.py @@ -0,0 +1,62 @@ +from ezmsg.util.perf.ab import ( + build_hotpath_command, + build_pair_order, + summarize_ab_results, +) + + +def test_build_pair_order_is_balanced_and_reproducible(): + first = build_pair_order(6, seed=123) + second = build_pair_order(6, seed=123) + + assert first == second + assert len(first) == 6 + assert first.count(("A", "B")) == 3 + assert first.count(("B", "A")) == 3 + + +def test_build_hotpath_command_contains_expected_args(tmp_path): + cmd = build_hotpath_command( + tmp_path / "out.json", + count=100, + warmup=10, + payload_sizes=[64, 256], + transports=["local", "shm"], + apis=["async", "sync"], + num_buffers=2, + quiet=True, + ) + + assert cmd[:5] == ["uv", "run", "python", "-m", "ezmsg.util.perf.hotpath"] + assert "--count" in cmd + assert "--payload-sizes" in cmd + assert "--quiet" in cmd + + +def test_summarize_ab_results_uses_b_vs_a_delta(): + paired_runs = [ + ( + {"async/shm/payload=64/buffers=1": 10.0}, + {"async/shm/payload=64/buffers=1": 12.0}, + ), + ( + {"async/shm/payload=64/buffers=1": 8.0}, + {"async/shm/payload=64/buffers=1": 9.0}, + ), + ] + + summary = summarize_ab_results( + ref_a="dev", + ref_b="CURRENT", + rounds=2, + seed=0, + paired_runs=paired_runs, + ) + + assert len(summary.cases) == 1 + case = summary.cases[0] + assert case.case_id == "async/shm/payload=64/buffers=1" + assert case.a_us_per_message_median == 9.0 + assert case.b_us_per_message_median == 10.5 + assert case.delta_pct_median > 0 + assert case.b_faster_pairs == 0 diff --git a/tests/test_perf_hotpath.py b/tests/test_perf_hotpath.py new file mode 100644 index 00000000..9e0fcc6f --- /dev/null 
# NOTE(review): reconstructed from unified-diff hunks adding
# tests/test_perf_hotpath.py and tests/test_pubclient.py. The two modules
# interleave across this chunk of the patch, so they are emitted together;
# section markers below show the original file boundaries.

# --- tests/test_perf_hotpath.py -------------------------------------------
"""Tests for ezmsg.util.perf.hotpath case construction and execution."""

import pytest

from ezmsg.util.perf.hotpath import HotPathCase, build_cases, run_hotpath_case

from ezmsg.core.graphserver import GraphServer


def test_build_cases_are_sorted_by_case_id():
    """build_cases returns cases sorted by case_id and includes shm/local variants."""
    cases = build_cases(
        apis=["sync", "async"],
        transports=["local", "tcp", "shm"],
        payload_sizes=[1024, 64],
        num_buffers=1,
    )
    assert [case.case_id for case in cases] == sorted(case.case_id for case in cases)
    assert "async/shm/payload=64/buffers=1" in {case.case_id for case in cases}
    assert "async/local/payload=64/buffers=1" in {case.case_id for case in cases}


def test_run_hotpath_case_smoke():
    """End-to-end smoke test: a tiny sync/TCP case against a live GraphServer."""
    server = GraphServer()
    try:
        server.start(("127.0.0.1", 0))
    except PermissionError:
        # Sandboxed CI environments may forbid binding even loopback sockets.
        pytest.skip("Local socket binding is unavailable in this environment")
    try:
        result = run_hotpath_case(
            HotPathCase(
                api="sync",
                transport="tcp",
                payload_size=64,
                num_buffers=1,
            ),
            count=8,
            warmup=2,
            samples=2,
            graph_address=server.address,
        )
    finally:
        server.stop()

    assert result.case.case_id == "sync/tcp/payload=64/buffers=1"
    assert len(result.samples_seconds) == 2
    assert all(sample > 0 for sample in result.samples_seconds)
    assert result.summary.us_per_message_median > 0


# --- tests/test_pubclient.py ----------------------------------------------
# Tests for Publisher's force_tcp / allow_local resolution and the transport
# actually chosen by broadcast() for a same-process subscriber channel.

from contextlib import contextmanager
from uuid import uuid4

from ezmsg.core.netprotocol import Address, Command
from ezmsg.core.pubclient import (
    ALLOW_LOCAL_ENV,
    FORCE_TCP_ENV,
    PubChannelInfo,
    Publisher,
)


class DummyLocalChannel:
    """Records (msg_id, obj) pairs passed to put_local instead of delivering them."""

    def __init__(self) -> None:
        self.calls: list[tuple[int, object]] = []

    def put_local(self, msg_id: int, obj: object) -> None:
        self.calls.append((msg_id, obj))


class DummyWriter:
    """Captures written byte chunks and provides a no-op async drain()."""

    def __init__(self) -> None:
        self.buffer: list[bytes] = []

    def write(self, data: bytes) -> None:
        self.buffer.append(data)

    async def drain(self) -> None:
        return None


class DummyShm:
    """In-memory SHM stand-in: fixed-size bytearrays behind a buffer() context manager."""

    def __init__(self, num_buffers: int, buf_size: int = 65536) -> None:
        self.name = "dummy-shm"
        self.buf_size = buf_size
        self._buffers = [bytearray(buf_size) for _ in range(num_buffers)]

    @contextmanager
    def buffer(self, idx: int, readonly: bool = False):
        # readonly is accepted for interface parity but has no effect here.
        del readonly
        yield memoryview(self._buffers[idx])


def _make_publisher(
    *,
    force_tcp: bool | None,
    allow_local: bool | None,
    channel_pid: int,
    shm_ok: bool = True,
) -> tuple[Publisher, DummyLocalChannel, DummyWriter]:
    """Build a running Publisher wired to dummy SHM/channel/writer collaborators.

    Bypasses graph-server registration by constructing Publisher directly with
    its _SENTINEL guard, marks it running, and attaches a single channel whose
    pid/shm_ok control which transport broadcast() selects.
    """
    pub = Publisher(
        id=uuid4(),
        topic="/TEST",
        shm=DummyShm(num_buffers=2),
        graph_address=Address("127.0.0.1", 25978),
        num_buffers=2,
        force_tcp=force_tcp,
        allow_local=allow_local,
        _guard=Publisher._SENTINEL,
    )
    pub._running.set()

    local_channel = DummyLocalChannel()
    writer = DummyWriter()
    channel = PubChannelInfo(
        id=uuid4(),
        writer=writer,
        pub_id=pub.id,
        pid=channel_pid,
        shm_ok=shm_ok,
    )
    pub._channels[channel.id] = channel
    pub._local_channel = local_channel  # type: ignore[assignment]
    return pub, local_channel, writer


@pytest.mark.asyncio
async def test_broadcast_same_process_prefers_local_fast_path():
    """With allow_local=True and matching pids, delivery uses put_local only."""
    pub, local_channel, writer = _make_publisher(
        force_tcp=False,
        allow_local=True,
        channel_pid=0,
    )
    pub.pid = 0

    await pub.broadcast(b"payload")

    assert local_channel.calls == [(0, b"payload")]
    assert writer.buffer == []


@pytest.mark.asyncio
async def test_broadcast_same_process_can_force_shm_path():
    """allow_local=False bypasses the in-process path; SHM notification is written."""
    pub, local_channel, writer = _make_publisher(
        force_tcp=False,
        allow_local=False,
        channel_pid=0,
    )
    pub.pid = 0

    await pub.broadcast(b"payload")

    assert local_channel.calls == []
    assert writer.buffer
    assert writer.buffer[0].startswith(Command.TX_SHM.value)


@pytest.mark.asyncio
async def test_broadcast_same_process_can_force_tcp_path():
    """force_tcp=True routes even same-process delivery over the TCP command."""
    pub, local_channel, writer = _make_publisher(
        force_tcp=True,
        allow_local=False,
        channel_pid=0,
    )
    pub.pid = 0

    await pub.broadcast(b"payload")

    assert local_channel.calls == []
    assert writer.buffer
    assert writer.buffer[0].startswith(Command.TX_TCP.value)


def test_force_tcp_disables_allow_local_from_env(monkeypatch, caplog):
    """force_tcp overrides an env-enabled local path, and logs that it did."""
    monkeypatch.setenv(ALLOW_LOCAL_ENV, "1")
    with caplog.at_level("INFO"):
        pub, _, _ = _make_publisher(
            force_tcp=True,
            allow_local=None,
            channel_pid=0,
        )

    assert pub._allow_local is False
    assert "force_tcp=True disables local delivery" in caplog.text


def test_force_tcp_disables_explicit_allow_local(caplog):
    """force_tcp overrides even an explicit allow_local=True argument."""
    with caplog.at_level("INFO"):
        pub, _, _ = _make_publisher(
            force_tcp=True,
            allow_local=True,
            channel_pid=0,
        )

    assert pub._allow_local is False
    assert "force_tcp=True disables local delivery" in caplog.text


def test_force_tcp_uses_env_default_when_none(monkeypatch):
    """force_tcp=None inherits the EZMSG_FORCE_TCP process default."""
    monkeypatch.setenv(FORCE_TCP_ENV, "1")
    pub, _, _ = _make_publisher(
        force_tcp=None,
        allow_local=False,
        channel_pid=0,
    )

    assert pub._force_tcp is True


def test_explicit_force_tcp_false_overrides_env(monkeypatch):
    """An explicit force_tcp=False argument beats the env default."""
    monkeypatch.setenv(FORCE_TCP_ENV, "1")
    pub, _, _ = _make_publisher(
        force_tcp=False,
        allow_local=False,
        channel_pid=0,
    )

    assert pub._force_tcp is False


@pytest.mark.asyncio
async def test_broadcast_same_process_uses_env_default_when_allow_local_is_none(monkeypatch):
    """allow_local=None inherits EZMSG_ALLOW_LOCAL=0, so SHM is used instead."""
    monkeypatch.setenv(ALLOW_LOCAL_ENV, "0")
    pub, local_channel, writer = _make_publisher(
        force_tcp=False,
        allow_local=None,
        channel_pid=0,
    )
    pub.pid = 0

    await pub.broadcast(b"payload")

    assert local_channel.calls == []
    assert writer.buffer
    assert writer.buffer[0].startswith(Command.TX_SHM.value)


@pytest.mark.asyncio
async def test_broadcast_same_process_explicit_allow_local_overrides_env(monkeypatch):
    """An explicit allow_local=True argument beats EZMSG_ALLOW_LOCAL=0."""
    monkeypatch.setenv(ALLOW_LOCAL_ENV, "0")
    pub, local_channel, writer = _make_publisher(
        force_tcp=False,
        allow_local=True,
        channel_pid=0,
    )
    pub.pid = 0

    await pub.broadcast(b"payload")

    assert local_channel.calls == [(0, b"payload")]
    assert writer.buffer == []