Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 22 additions & 27 deletions dimos/protocol/pubsub/shmpubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import struct
import threading
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple
Expand Down Expand Up @@ -82,13 +83,12 @@ class _TopicState:
"cp",
"last_local_payload",
"suppress_counts",
"message_counter",
)

def __init__(self, channel, capacity: int, cp_mod):
self.channel = channel
self.capacity = int(capacity)
self.shape = (self.capacity + 12,) # +12 for header: length(4) + pid(4) + counter(4)
self.shape = (self.capacity + 20,) # +20 for header: length(4) + uuid(16)
self.dtype = np.uint8
self.subs: list[Callable[[bytes, str], None]] = []
self.stop = threading.Event()
Expand All @@ -97,8 +97,7 @@ def __init__(self, channel, capacity: int, cp_mod):
# TODO: implement an initializer variable for is_cuda once CUDA IPC is in
self.cp = cp_mod
self.last_local_payload: Optional[bytes] = None
self.suppress_counts: Dict[Tuple[int, int], int] = defaultdict(int)
self.message_counter = 0
self.suppress_counts: Dict[bytes, int] = defaultdict(int) # UUID bytes as key

# ----- init / lifecycle -------------------------------------------------

Expand Down Expand Up @@ -160,10 +159,8 @@ def publish(self, topic: str, message: bytes) -> None:
logger.error(f"Payload too large: {L} > capacity {st.capacity}")
raise ValueError(f"Payload too large: {L} > capacity {st.capacity}")

# Create a unique identifier using PID and incrementing counter
pid = os.getpid()
st.message_counter += 1
message_id = (pid, st.message_counter)
# Create a unique identifier using UUID4
message_id = uuid.uuid4().bytes # 16 bytes

# Mark this message to suppress its echo
st.suppress_counts[message_id] += 1
Expand All @@ -176,14 +173,15 @@ def publish(self, topic: str, message: bytes) -> None:
logger.warn(f"Payload couldn't be pushed to topic: {topic}")
pass

# Build host frame [len:4] + [pid:4] + [counter:4] + payload and publish
# We embed the message ID in the frame for echo suppression
# Build host frame [len:4] + [uuid:16] + payload and publish.
# The length field counts the 16 UUID bytes plus the payload bytes.
# The UUID is embedded in the frame so the echo of this message can be suppressed.
host = np.zeros(st.shape, dtype=st.dtype)
# Pack: length(4) + pid(4) + counter(4) + payload
header = struct.pack("<III", L + 8, pid, st.message_counter) # L+8 for pid+counter
host[:12] = np.frombuffer(header, dtype=np.uint8)
# Pack: length(4) + uuid(16) + payload
header = struct.pack("<I", L + 16) # L+16 for uuid
host[:4] = np.frombuffer(header, dtype=np.uint8)
host[4:20] = np.frombuffer(message_id, dtype=np.uint8)
if L:
host[12 : 12 + L] = np.frombuffer(memoryview(payload_bytes), dtype=np.uint8)
host[20 : 20 + L] = np.frombuffer(memoryview(payload_bytes), dtype=np.uint8)

st.channel.publish(host)

Expand Down Expand Up @@ -240,7 +238,7 @@ def reconfigure(self, topic: str, *, capacity: int) -> dict:
"""Change payload capacity (bytes) for a topic; returns new descriptor."""
st = self._ensure_topic(topic)
new_cap = int(capacity)
new_shape = (new_cap + 12,) # +12 for header: length(4) + pid(4) + counter(4)
new_shape = (new_cap + 20,) # +20 for header: length(4) + uuid(16)
desc = st.channel.reconfigure(new_shape, np.uint8)
st.capacity = new_cap
st.shape = new_shape
Expand All @@ -263,7 +261,7 @@ def _names_for_topic(topic: str, capacity: int) -> tuple[str, str]:
return f"psm_{h}_data", f"psm_{h}_ctrl"

data_name, ctrl_name = _names_for_topic(topic, cap)
ch = CpuShmChannel((cap + 12,), np.uint8, data_name=data_name, ctrl_name=ctrl_name)
ch = CpuShmChannel((cap + 20,), np.uint8, data_name=data_name, ctrl_name=ctrl_name)
st = SharedMemoryPubSubBase._TopicState(ch, cap, None)
self._topics[topic] = st
return st
Expand All @@ -279,22 +277,19 @@ def _fanout_loop(self, topic: str, st: _TopicState) -> None:
host = np.array(view, copy=True)

try:
# Read header: length(4) + pid(4) + counter(4)
header = struct.unpack("<III", host[:12].tobytes())
L = header[0]
# Read the 4-byte length field; the 16-byte UUID follows it in the header
L = struct.unpack("<I", host[:4].tobytes())[0]

if L < 8 or L > st.capacity + 8:
if L < 16 or L > st.capacity + 16:
continue

# Extract PID and counter
pid = header[1]
counter = header[2]
message_id = (pid, counter)
# Extract UUID
message_id = host[4:20].tobytes()

# Extract actual payload (after removing the 8 bytes for pid+counter)
payload_len = L - 8
# Extract actual payload (after removing the 16 bytes for uuid)
payload_len = L - 16
if payload_len > 0:
payload = host[12 : 12 + payload_len].tobytes()
payload = host[20 : 20 + payload_len].tobytes()
else:
continue

Expand Down