diff --git a/dimos/protocol/pubsub/benchmark/type.py b/dimos/protocol/pubsub/benchmark/type.py index 79101df9c5..9ffcf8942e 100644 --- a/dimos/protocol/pubsub/benchmark/type.py +++ b/dimos/protocol/pubsub/benchmark/type.py @@ -41,24 +41,31 @@ def __len__(self) -> int: TestData = Sequence[Case[Any, Any]] -def _format_size(size_bytes: int) -> str: - """Format byte size to human-readable string.""" - if size_bytes >= 1048576: - return f"{size_bytes / 1048576:.1f} MB" - if size_bytes >= 1024: - return f"{size_bytes / 1024:.1f} KB" - return f"{size_bytes} B" - - -def _format_throughput(bytes_per_sec: float) -> str: - """Format throughput to human-readable string.""" - if bytes_per_sec >= 1e9: - return f"{bytes_per_sec / 1e9:.2f} GB/s" - if bytes_per_sec >= 1e6: - return f"{bytes_per_sec / 1e6:.2f} MB/s" - if bytes_per_sec >= 1e3: - return f"{bytes_per_sec / 1e3:.2f} KB/s" - return f"{bytes_per_sec:.2f} B/s" +def _format_mib(value: float) -> str: + """Format bytes as MiB with intelligent rounding. + + >= 10 MiB: integer (e.g., "42") + 1-10 MiB: 1 decimal (e.g., "2.5") + < 1 MiB: 2 decimals (e.g., "0.07") + """ + mib = value / (1024**2) + if mib >= 10: + return f"{mib:.0f}" + if mib >= 1: + return f"{mib:.1f}" + return f"{mib:.2f}" + + +def _format_iec(value: float, concise: bool = False, decimals: int = 2) -> str: + """Format bytes with IEC units (Ki/Mi/Gi = 1024^1/2/3)""" + k = 1024.0 + units = ["B", "K", "M", "G", "T"] if concise else ["B", "KiB", "MiB", "GiB", "TiB"] + + for unit in units[:-1]: + if abs(value) < k: + return f"{value:.{decimals}f}{unit}" if concise else f"{value:.{decimals}f} {unit}" + value /= k + return f"{value:.{decimals}f}{units[-1]}" if concise else f"{value:.{decimals}f} {units[-1]}" @dataclass @@ -117,7 +124,7 @@ def print_summary(self) -> None: table.add_column("Sent", justify="right") table.add_column("Recv", justify="right") table.add_column("Msgs/s", justify="right", style="green") - table.add_column("Throughput", justify="right", style="green") + table.add_column("MiB/s", justify="right", style="green") table.add_column("Latency", justify="right") table.add_column("Loss", justify="right") @@ -126,11 +133,11 @@ def print_summary(self) -> None: recv_style = "yellow" if r.receive_time > 0.1 else "dim" table.add_row( r.transport, - _format_size(r.msg_size_bytes), + _format_iec(r.msg_size_bytes, decimals=0), f"{r.msgs_sent:,}", f"{r.msgs_received:,}", f"{r.throughput_msgs:,.0f}", - _format_throughput(r.throughput_bytes), + _format_mib(r.throughput_bytes), f"[{recv_style}]{r.receive_time * 1000:.0f}ms[/{recv_style}]", f"[{loss_style}]{r.loss_pct:.1f}%[/{loss_style}]", ) @@ -149,13 +156,6 @@ def _print_heatmap( if not self.results: return - def size_id(size: int) -> str: - if size >= 1048576: - return f"{size // 1048576}MB" - if size >= 1024: - return f"{size // 1024}KB" - return f"{size}B" - transports = sorted(set(r.transport for r in self.results)) sizes = sorted(set(r.msg_size_bytes for r in self.results)) @@ -211,7 +211,7 @@ def val_to_color(v: float) -> int: return gradient[int(t * (len(gradient) - 1))] reset = "\033[0m" - size_labels = [size_id(s) for s in sizes] + size_labels = [_format_iec(s, concise=True, decimals=0) for s in sizes] col_w = max(8, max(len(s) for s in size_labels) + 1) transport_w = max(len(t) for t in transports) + 1 @@ -245,15 +245,9 @@ def print_bandwidth_heatmap(self) -> None: """Print bandwidth heatmap.""" def fmt(v: float) -> str: - if v >= 1e9: - return f"{v / 1e9:.1f}G" - if v >= 1e6: - return f"{v / 1e6:.0f}M" - if v >= 1e3: - return f"{v / 1e3:.0f}K" - return f"{v:.0f}" - - self._print_heatmap("Bandwidth", lambda r: r.throughput_bytes, fmt) + return _format_iec(v, concise=True, decimals=1) + + self._print_heatmap("Bandwidth (IEC)", lambda r: r.throughput_bytes, fmt) def print_latency_heatmap(self) -> None: """Print latency heatmap (time waiting for messages after publishing)."""