diff --git a/dimos/protocol/pubsub/shmpubsub.py b/dimos/protocol/pubsub/shmpubsub.py index d0cfac01ff..8bcf87828c 100644 --- a/dimos/protocol/pubsub/shmpubsub.py +++ b/dimos/protocol/pubsub/shmpubsub.py @@ -24,6 +24,7 @@ import struct import threading import time +import uuid from collections import defaultdict from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Tuple @@ -82,13 +83,12 @@ class _TopicState: "cp", "last_local_payload", "suppress_counts", - "message_counter", ) def __init__(self, channel, capacity: int, cp_mod): self.channel = channel self.capacity = int(capacity) - self.shape = (self.capacity + 12,) # +12 for header: length(4) + pid(4) + counter(4) + self.shape = (self.capacity + 20,) # +20 for header: length(4) + uuid(16) self.dtype = np.uint8 self.subs: list[Callable[[bytes, str], None]] = [] self.stop = threading.Event() @@ -97,8 +97,7 @@ def __init__(self, channel, capacity: int, cp_mod): # TODO: implement an initializer variable for is_cuda once CUDA IPC is in self.cp = cp_mod self.last_local_payload: Optional[bytes] = None - self.suppress_counts: Dict[Tuple[int, int], int] = defaultdict(int) - self.message_counter = 0 + self.suppress_counts: Dict[bytes, int] = defaultdict(int) # UUID bytes as key # ----- init / lifecycle ------------------------------------------------- @@ -160,10 +159,8 @@ def publish(self, topic: str, message: bytes) -> None: logger.error(f"Payload too large: {L} > capacity {st.capacity}") raise ValueError(f"Payload too large: {L} > capacity {st.capacity}") - # Create a unique identifier using PID and incrementing counter - pid = os.getpid() - st.message_counter += 1 - message_id = (pid, st.message_counter) + # Create a unique identifier using UUID4 + message_id = uuid.uuid4().bytes # 16 bytes # Mark this message to suppress its echo st.suppress_counts[message_id] += 1 @@ -176,14 +173,15 @@ def publish(self, topic: str, message: bytes) -> None: logger.warn(f"Payload couldn't be pushed to topic: {topic}") pass - # Build host frame [len:4] + [pid:4] + [counter:4] + payload and publish - # We embed the message ID in the frame for echo suppression + # Build host frame [len:4] + [uuid:16] + payload and publish + # We embed the message UUID in the frame for echo suppression host = np.zeros(st.shape, dtype=st.dtype) - # Pack: length(4) + pid(4) + counter(4) + payload - header = struct.pack(" dict: """Change payload capacity (bytes) for a topic; returns new descriptor.""" st = self._ensure_topic(topic) new_cap = int(capacity) - new_shape = (new_cap + 12,) # +12 for header: length(4) + pid(4) + counter(4) + new_shape = (new_cap + 20,) # +20 for header: length(4) + uuid(16) desc = st.channel.reconfigure(new_shape, np.uint8) st.capacity = new_cap st.shape = new_shape @@ -263,7 +261,7 @@ def _names_for_topic(topic: str, capacity: int) -> tuple[str, str]: return f"psm_{h}_data", f"psm_{h}_ctrl" data_name, ctrl_name = _names_for_topic(topic, cap) - ch = CpuShmChannel((cap + 12,), np.uint8, data_name=data_name, ctrl_name=ctrl_name) + ch = CpuShmChannel((cap + 20,), np.uint8, data_name=data_name, ctrl_name=ctrl_name) st = SharedMemoryPubSubBase._TopicState(ch, cap, None) self._topics[topic] = st return st @@ -279,22 +277,19 @@ def _fanout_loop(self, topic: str, st: _TopicState) -> None: host = np.array(view, copy=True) try: - # Read header: length(4) + pid(4) + counter(4) - header = struct.unpack(" st.capacity + 8: + if L < 16 or L > st.capacity + 16: continue - # Extract PID and counter - pid = header[1] - counter = header[2] - message_id = (pid, counter) + # Extract UUID + message_id = host[4:20].tobytes() - # Extract actual payload (after removing the 8 bytes for pid+counter) - payload_len = L - 8 + # Extract actual payload (after removing the 16 bytes for uuid) + payload_len = L - 16 if payload_len > 0: - payload = host[12 : 12 + payload_len].tobytes() + payload = host[20 : 20 + payload_len].tobytes() else: continue