Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions dimos/protocol/pubsub/shmpubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, Tuple

import numpy as np

Expand Down Expand Up @@ -82,12 +82,13 @@ class _TopicState:
"cp",
"last_local_payload",
"suppress_counts",
"message_counter",
)

def __init__(self, channel, capacity: int, cp_mod):
self.channel = channel
self.capacity = int(capacity)
self.shape = (self.capacity + 4,) # +4 for uint32 length header
self.shape = (self.capacity + 12,) # +12 for header: length(4) + pid(4) + counter(4)
self.dtype = np.uint8
self.subs: list[Callable[[bytes, str], None]] = []
self.stop = threading.Event()
Expand All @@ -96,7 +97,8 @@ def __init__(self, channel, capacity: int, cp_mod):
# TODO: implement an initializer variable for is_cuda once CUDA IPC is in
self.cp = cp_mod
self.last_local_payload: Optional[bytes] = None
self.suppress_counts = defaultdict(int)
self.suppress_counts: Dict[Tuple[int, int], int] = defaultdict(int)
self.message_counter = 0

# ----- init / lifecycle -------------------------------------------------

Expand Down Expand Up @@ -158,9 +160,13 @@ def publish(self, topic: str, message: bytes) -> None:
logger.error(f"Payload too large: {L} > capacity {st.capacity}")
raise ValueError(f"Payload too large: {L} > capacity {st.capacity}")

# Mark this payload to suppress its single echo (handles back-to-back publishes)
payload_hash = hashlib.md5(payload_bytes).digest()
st.suppress_counts[payload_hash] += 1
# Create a unique identifier using PID and incrementing counter
pid = os.getpid()
st.message_counter += 1
message_id = (pid, st.message_counter)

# Mark this message to suppress its echo
st.suppress_counts[message_id] += 1

# Synchronous local delivery first (zero extra copies)
for cb in list(st.subs):
Expand All @@ -170,11 +176,14 @@ def publish(self, topic: str, message: bytes) -> None:
logger.warn(f"Payload couldn't be pushed to topic: {topic}")
pass

# Build host frame [len:4] + payload and publish
# Build host frame [len:4] + [pid:4] + [counter:4] + payload and publish
# We embed the message ID in the frame for echo suppression
host = np.zeros(st.shape, dtype=st.dtype)
host[:4] = np.frombuffer(struct.pack("<I", L), dtype=np.uint8)
# Pack: length(4) + pid(4) + counter(4) + payload
header = struct.pack("<III", L + 8, pid, st.message_counter) # L+8 for pid+counter
host[:12] = np.frombuffer(header, dtype=np.uint8)
if L:
host[4 : 4 + L] = np.frombuffer(memoryview(payload_bytes), dtype=np.uint8)
host[12 : 12 + L] = np.frombuffer(memoryview(payload_bytes), dtype=np.uint8)

st.channel.publish(host)

Expand Down Expand Up @@ -231,7 +240,7 @@ def reconfigure(self, topic: str, *, capacity: int) -> dict:
"""Change payload capacity (bytes) for a topic; returns new descriptor."""
st = self._ensure_topic(topic)
new_cap = int(capacity)
new_shape = (new_cap + 4,)
new_shape = (new_cap + 12,) # +12 for header: length(4) + pid(4) + counter(4)
desc = st.channel.reconfigure(new_shape, np.uint8)
st.capacity = new_cap
st.shape = new_shape
Expand All @@ -254,7 +263,7 @@ def _names_for_topic(topic: str, capacity: int) -> tuple[str, str]:
return f"psm_{h}_data", f"psm_{h}_ctrl"

data_name, ctrl_name = _names_for_topic(topic, cap)
ch = CpuShmChannel((cap + 4,), np.uint8, data_name=data_name, ctrl_name=ctrl_name)
ch = CpuShmChannel((cap + 12,), np.uint8, data_name=data_name, ctrl_name=ctrl_name)
st = SharedMemoryPubSubBase._TopicState(ch, cap, None)
self._topics[topic] = st
return st
Expand All @@ -270,20 +279,32 @@ def _fanout_loop(self, topic: str, st: _TopicState) -> None:
host = np.array(view, copy=True)

try:
L = struct.unpack("<I", host[:4].tobytes())[0]
if L == 0 or L < 0 or L > st.capacity:
# Read header: length(4) + pid(4) + counter(4)
header = struct.unpack("<III", host[:12].tobytes())
L = header[0]

if L < 8 or L > st.capacity + 8:
continue

payload = host[4 : 4 + L].tobytes()
# Extract PID and counter
pid = header[1]
counter = header[2]
message_id = (pid, counter)

# Extract actual payload (after removing the 8 bytes for pid+counter)
payload_len = L - 8
if payload_len > 0:
payload = host[12 : 12 + payload_len].tobytes()
else:
continue

# Drop exactly the number of local echoes we created
payload_hash = hashlib.md5(payload).digest()
cnt = st.suppress_counts.get(payload_hash, 0)
cnt = st.suppress_counts.get(message_id, 0)
if cnt > 0:
if cnt == 1:
del st.suppress_counts[payload_hash]
del st.suppress_counts[message_id]
else:
st.suppress_counts[payload_hash] = cnt - 1
st.suppress_counts[message_id] = cnt - 1
continue # suppressed

except Exception:
Expand Down