diff --git a/dimos/protocol/pubsub/shmpubsub.py b/dimos/protocol/pubsub/shmpubsub.py index 10f8763496..d0cfac01ff 100644 --- a/dimos/protocol/pubsub/shmpubsub.py +++ b/dimos/protocol/pubsub/shmpubsub.py @@ -26,7 +26,7 @@ import time from collections import defaultdict from dataclasses import dataclass -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple import numpy as np @@ -82,12 +82,13 @@ class _TopicState: "cp", "last_local_payload", "suppress_counts", + "message_counter", ) def __init__(self, channel, capacity: int, cp_mod): self.channel = channel self.capacity = int(capacity) - self.shape = (self.capacity + 4,) # +4 for uint32 length header + self.shape = (self.capacity + 12,) # +12 for header: length(4) + pid(4) + counter(4) self.dtype = np.uint8 self.subs: list[Callable[[bytes, str], None]] = [] self.stop = threading.Event() @@ -96,7 +97,8 @@ def __init__(self, channel, capacity: int, cp_mod): # TODO: implement an initializer variable for is_cuda once CUDA IPC is in self.cp = cp_mod self.last_local_payload: Optional[bytes] = None - self.suppress_counts = defaultdict(int) + self.suppress_counts: Dict[Tuple[int, int], int] = defaultdict(int) + self.message_counter = 0 # ----- init / lifecycle ------------------------------------------------- @@ -158,9 +160,13 @@ def publish(self, topic: str, message: bytes) -> None: logger.error(f"Payload too large: {L} > capacity {st.capacity}") raise ValueError(f"Payload too large: {L} > capacity {st.capacity}") - # Mark this payload to suppress its single echo (handles back-to-back publishes) - payload_hash = hashlib.md5(payload_bytes).digest() - st.suppress_counts[payload_hash] += 1 + # Create a unique identifier using PID and incrementing counter + pid = os.getpid() + st.message_counter += 1 + message_id = (pid, st.message_counter) + + # Mark this message to suppress its echo + st.suppress_counts[message_id] += 1 # Synchronous local delivery first (zero extra copies) for cb in list(st.subs): @@ -170,11 +176,14 @@ def publish(self, topic: str, message: bytes) -> None: logger.warn(f"Payload couldn't be pushed to topic: {topic}") pass - # Build host frame [len:4] + payload and publish + # Build host frame [len:4] + [pid:4] + [counter:4] + payload and publish + # We embed the message ID in the frame for echo suppression host = np.zeros(st.shape, dtype=st.dtype) - host[:4] = np.frombuffer(struct.pack(" dict: """Change payload capacity (bytes) for a topic; returns new descriptor.""" st = self._ensure_topic(topic) new_cap = int(capacity) - new_shape = (new_cap + 4,) + new_shape = (new_cap + 12,) # +12 for header: length(4) + pid(4) + counter(4) desc = st.channel.reconfigure(new_shape, np.uint8) st.capacity = new_cap st.shape = new_shape @@ -254,7 +263,7 @@ def _names_for_topic(topic: str, capacity: int) -> tuple[str, str]: return f"psm_{h}_data", f"psm_{h}_ctrl" data_name, ctrl_name = _names_for_topic(topic, cap) - ch = CpuShmChannel((cap + 4,), np.uint8, data_name=data_name, ctrl_name=ctrl_name) + ch = CpuShmChannel((cap + 12,), np.uint8, data_name=data_name, ctrl_name=ctrl_name) st = SharedMemoryPubSubBase._TopicState(ch, cap, None) self._topics[topic] = st return st @@ -270,20 +279,32 @@ def _fanout_loop(self, topic: str, st: _TopicState) -> None: host = np.array(view, copy=True) try: - L = struct.unpack(" st.capacity: + # Read header: length(4) + pid(4) + counter(4) + header = struct.unpack(" st.capacity + 8: continue - payload = host[4 : 4 + L].tobytes() + # Extract PID and counter + pid = header[1] + counter = header[2] + message_id = (pid, counter) + + # Extract actual payload (after removing the 8 bytes for pid+counter) + payload_len = L - 8 + if payload_len > 0: + payload = host[12 : 12 + payload_len].tobytes() + else: + continue # Drop exactly the number of local echoes we created - payload_hash = hashlib.md5(payload).digest() - cnt = st.suppress_counts.get(payload_hash, 0) + cnt = st.suppress_counts.get(message_id, 0) if cnt > 0: if cnt == 1: - del st.suppress_counts[payload_hash] + del st.suppress_counts[message_id] else: - st.suppress_counts[payload_hash] = cnt - 1 + st.suppress_counts[message_id] = cnt - 1 continue # suppressed except Exception: