Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion dimos/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
from dimos.core.core import rpc
from dimos.core.module import Module, ModuleBase
from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport
from dimos.core.transport import LCMTransport, ZenohTransport, pLCMTransport
from dimos.core.transport import (
LCMTransport,
ZenohTransport,
pLCMTransport,
SHMTransport,
pSHMTransport,
)
from dimos.protocol.rpc.lcmrpc import LCMRPC
from dimos.protocol.rpc.spec import RPCSpec
from dimos.protocol.tf import LCMTF, TF, PubSubTF, TFConfig, TFSpec
Expand Down
49 changes: 49 additions & 0 deletions dimos/core/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from dimos.core.stream import In, RemoteIn, Transport
from dimos.protocol.pubsub.lcmpubsub import LCM, PickleLCM
from dimos.protocol.pubsub.lcmpubsub import Topic as LCMTopic
from dimos.protocol.pubsub.shmpubsub import SharedMemory, PickleSharedMemory

T = TypeVar("T")

Expand Down Expand Up @@ -106,6 +107,54 @@ def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) ->
return self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg))


class pSHMTransport(PubSubTransport[T]):
    """Pub/sub transport that publishes and subscribes through ``PickleSharedMemory``.

    The underlying ``PickleSharedMemory`` instance is started lazily, on the
    first ``broadcast`` or ``subscribe`` call.

    NOTE(review): a reviewer found this class name cryptic and suggested
    ``PickledSharedMemTransport``. The current name is kept because
    ``dimos.core`` re-exports it.
    """

    _started: bool = False

    def __init__(self, topic: str, **kwargs):
        """Create the transport for ``topic``.

        Args:
            topic: Topic name handed to the base class and used for every
                publish/subscribe call.
            **kwargs: Forwarded unchanged to ``PickleSharedMemory``.
        """
        super().__init__(topic)
        # Remember the configuration so that a pickled copy can be rebuilt
        # with the same settings (see __reduce__).
        self._shm_kwargs = dict(kwargs)
        self.shm = PickleSharedMemory(**kwargs)

    def __reduce__(self):
        """Pickle as "construct again from topic and constructor kwargs".

        Previously only the topic was kept, so a copy sent to another process
        was silently rebuilt with default ``PickleSharedMemory`` settings.
        The kwargs must themselves be picklable.
        """
        shm_kwargs = getattr(self, "_shm_kwargs", None)
        if not shm_kwargs:
            # No extra configuration: identical to the historical behavior.
            return (pSHMTransport, (self.topic,))
        from functools import partial

        return (partial(pSHMTransport, **shm_kwargs), (self.topic,))

    def _ensure_started(self) -> None:
        """Start the shared-memory backend once, on first use."""
        if not self._started:
            self.shm.start()
            self._started = True

    def broadcast(self, _, msg):
        """Publish ``msg`` on this transport's topic (first argument is ignored)."""
        self._ensure_started()
        self.shm.publish(self.topic, msg)

    def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> None:
        """Register ``callback`` to be called with each message on the topic.

        ``selfstream`` is accepted but not used here. The value returned by
        ``self.shm.subscribe`` is passed straight back to the caller.
        """
        self._ensure_started()
        return self.shm.subscribe(self.topic, lambda msg, topic: callback(msg))


class SHMTransport(PubSubTransport[T]):
    """Pub/sub transport that publishes and subscribes through ``SharedMemory``.

    ``SharedMemory`` here is the pub/sub class imported from
    ``dimos.protocol.pubsub.shmpubsub``. It is started lazily, on the first
    ``broadcast`` or ``subscribe`` call.
    """

    _started: bool = False

    def __init__(self, topic: str, **kwargs):
        """Create the transport for ``topic``.

        Args:
            topic: Topic name handed to the base class and used for every
                publish/subscribe call.
            **kwargs: Forwarded unchanged to ``SharedMemory``.
        """
        super().__init__(topic)
        # Remember the configuration so that a pickled copy can be rebuilt
        # with the same settings (see __reduce__).
        self._shm_kwargs = dict(kwargs)
        self.shm = SharedMemory(**kwargs)

    def __reduce__(self):
        """Pickle as "construct again from topic and constructor kwargs".

        Previously only the topic was kept, so a copy sent to another process
        was silently rebuilt with default ``SharedMemory`` settings. The
        kwargs must themselves be picklable.
        """
        shm_kwargs = getattr(self, "_shm_kwargs", None)
        if not shm_kwargs:
            # No extra configuration: identical to the historical behavior.
            return (SHMTransport, (self.topic,))
        from functools import partial

        return (partial(SHMTransport, **shm_kwargs), (self.topic,))

    def _ensure_started(self) -> None:
        """Start the shared-memory backend once, on first use."""
        if not self._started:
            self.shm.start()
            self._started = True

    def broadcast(self, _, msg):
        """Publish ``msg`` on this transport's topic (first argument is ignored)."""
        self._ensure_started()
        self.shm.publish(self.topic, msg)

    def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> None:
        """Register ``callback`` to be called with each message on the topic.

        ``selfstream`` is accepted but not used here. The value returned by
        ``self.shm.subscribe`` is passed straight back to the caller.
        """
        self._ensure_started()
        return self.shm.subscribe(self.topic, lambda msg, topic: callback(msg))


class DaskTransport(Transport[T]):
subscribers: List[Callable[[T], None]]
_started: bool = False
Expand Down
304 changes: 304 additions & 0 deletions dimos/protocol/pubsub/shm/ipc_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ipc_factory.py
# Python 3.9+
import base64
import time
from abc import ABC, abstractmethod
import os
from typing import Optional, Tuple

import numpy as np
from multiprocessing.shared_memory import SharedMemory
from multiprocessing.managers import SharedMemoryManager

_UNLINK_ON_GC = os.getenv("DIMOS_IPC_UNLINK_ON_GC", "0").lower() not in ("0", "false", "no")


def _open_shm_with_retry(name: str) -> SharedMemory:
    """Attach to an existing shared-memory segment, retrying while it is missing.

    The retry policy is read from the environment on every call:

    * ``DIMOS_IPC_ATTACH_RETRIES`` - number of attach attempts (default 40).
    * ``DIMOS_IPC_ATTACH_BACKOFF_MS`` - first delay in milliseconds (default 5).
    * ``DIMOS_IPC_ATTACH_BACKOFF_CAP_MS`` - upper bound for one delay (default 200).

    The delay doubles after every failed attempt until it reaches the cap.

    Args:
        name: Name of the segment to attach to.

    Returns:
        The attached ``SharedMemory`` handle.

    Raises:
        FileNotFoundError: If the segment is still missing after all attempts.
    """
    tries = int(os.getenv("DIMOS_IPC_ATTACH_RETRIES", "40"))
    base_ms = float(os.getenv("DIMOS_IPC_ATTACH_BACKOFF_MS", "5"))
    cap_ms = float(os.getenv("DIMOS_IPC_ATTACH_BACKOFF_CAP_MS", "200"))
    last = None
    # Track the delay incrementally instead of computing base_ms * 2**i:
    # converting 2**i to float raises OverflowError once i reaches 1024.
    delay_ms = base_ms
    for attempt in range(tries):
        try:
            return SharedMemory(name=name)
        except FileNotFoundError as e:
            last = e
        # Waiting after the final attempt only delays the error.
        if attempt + 1 < tries:
            time.sleep(max(0.0, min(delay_ms, cap_ms)) / 1000.0)
            if delay_ms < cap_ms:
                delay_ms *= 2.0
    raise FileNotFoundError(f"SHM not found after {tries} retries: {name}") from last


def _sanitize_shm_name(name: str) -> str:
    """Strip leading ``/`` characters from a shared-memory name.

    Python's ``SharedMemory`` expects names such as ``psm_abc`` without a
    leading slash. Non-string values are returned untouched.
    """
    if not isinstance(name, str):
        return name
    return name.lstrip("/")


# ---------------------------
# 1) Abstract interface
# ---------------------------


class FrameChannel(ABC):
    """Interface for a single-slot, "freshest frame wins" IPC channel.

    Implementations are double-buffered and keep a tiny control block so that
    readers do not observe torn frames. A channel can describe itself with a
    JSON-safe descriptor, from which ``attach()`` rebuilds it in another
    process.
    """

    @property
    @abstractmethod
    def device(self) -> str:
        """Device the frames live on: ``"cpu"`` or ``"cuda"``."""

    @property
    @abstractmethod
    def shape(self) -> tuple:
        """Shape of a single frame."""

    @property
    @abstractmethod
    def dtype(self) -> np.dtype:
        """Element type of a frame."""

    @abstractmethod
    def publish(self, frame) -> None:
        """Write ``frame`` into the inactive buffer, then flip the visible index.

        The control block is written last.
        """

    @abstractmethod
    def read(self, last_seq: int = -1, require_new: bool = True):
        """Return ``(seq: int, ts_ns: int, view-or-None)``."""

    @abstractmethod
    def descriptor(self) -> dict:
        """Return a tiny JSON-safe descriptor (names/handles/shape/dtype/device)."""

    @classmethod
    @abstractmethod
    def attach(cls, desc: dict) -> "FrameChannel":
        """Attach to the channel described by ``desc`` from another process."""

    @abstractmethod
    def close(self) -> None:
        """Detach resources (the owner also unlinks the manager if applicable)."""


from multiprocessing.shared_memory import SharedMemory
import weakref, os


def _safe_unlink(name):
    """Best-effort removal of the shared-memory segment called ``name``.

    Never raises: a missing segment or any other failure is ignored on
    purpose, because this runs from cleanup paths (``close()`` and
    ``weakref.finalize`` callbacks) where an exception would be useless.

    Args:
        name: Name of the segment to unlink.
    """
    try:
        shm = SharedMemory(name=name)
    except Exception:
        # Already gone (FileNotFoundError) or not attachable: nothing to do.
        return
    try:
        shm.unlink()
    except Exception:
        pass
    finally:
        # Release the handle opened above; the original left it open, leaking
        # a file descriptor and a mapping on every call.
        try:
            shm.close()
        except Exception:
            pass


# ---------------------------
# 2) CPU shared-memory backend
# ---------------------------


class CpuShmChannel(FrameChannel):
    """``FrameChannel`` backed by two CPU shared-memory segments.

    * Data segment: ``2 * nbytes`` bytes holding two frame slots.
    * Control segment: 24 bytes viewed as three ``int64`` values -
      ``[0]`` sequence counter, ``[1]`` timestamp in ns, ``[2]`` index of the
      slot readers should use.
    """

    def __init__(self, shape, dtype=np.uint8, *, data_name=None, ctrl_name=None):
        """Create the channel, or open it if both named segments already exist.

        Args:
            shape: Shape of one frame.
            dtype: Element type of one frame.
            data_name: Optional fixed name for the data segment.
            ctrl_name: Optional fixed name for the control segment.

        If either name is omitted, both segments are created with generated
        names and this instance owns them. Otherwise each segment is created
        if absent and opened if present; the instance counts as owner only
        when it created both.
        """
        self._shape = tuple(shape)
        self._dtype = np.dtype(dtype)
        self._nbytes = int(self._dtype.itemsize * np.prod(self._shape))

        def _create_or_open(name, size):
            try:
                shm = SharedMemory(create=True, size=size, name=name)
                owner = True
            except FileExistsError:
                shm = SharedMemory(name=name)  # attach existing
                owner = False
            return shm, owner

        if data_name is None or ctrl_name is None:
            # fallback: random names (old behavior)
            self._shm_data = SharedMemory(create=True, size=2 * self._nbytes)
            self._shm_ctrl = SharedMemory(create=True, size=24)
            self._is_owner = True
        else:
            self._shm_data, own_d = _create_or_open(data_name, 2 * self._nbytes)
            self._shm_ctrl, own_c = _create_or_open(ctrl_name, 24)
            self._is_owner = own_d and own_c

        self._ctrl = np.ndarray((3,), dtype=np.int64, buffer=self._shm_ctrl.buf)
        if self._is_owner:
            self._ctrl[:] = 0  # initialize only once

        # only owners set unlink finalizers (beware cross-process timing)
        self._finalizer_data = (
            weakref.finalize(self, _safe_unlink, self._shm_data.name)
            if (_UNLINK_ON_GC and self._is_owner)
            else None
        )
        self._finalizer_ctrl = (
            weakref.finalize(self, _safe_unlink, self._shm_ctrl.name)
            if (_UNLINK_ON_GC and self._is_owner)
            else None
        )

    @property
    def device(self):
        """Always ``"cpu"`` for this backend."""
        return "cpu"

    @property
    def shape(self):
        """Shape of one frame, as a tuple."""
        return self._shape

    @property
    def dtype(self):
        """Element type of one frame, as a ``np.dtype``."""
        return self._dtype

    def descriptor(self):
        """Return the dict that ``attach()`` needs to open this channel."""
        return {
            "kind": "cpu",
            "shape": self._shape,
            "dtype": self._dtype.str,
            "nbytes": self._nbytes,
            "data_name": self._shm_data.name,
            "ctrl_name": self._shm_ctrl.name,
        }

    def publish(self, frame):
        """Copy ``frame`` into the slot readers are not using, then expose it.

        ``frame`` must be an ``np.ndarray`` with exactly this channel's shape
        and dtype; otherwise an ``AssertionError`` is raised.
        """
        assert isinstance(frame, np.ndarray)
        assert frame.shape == self._shape and frame.dtype == self._dtype
        active = int(self._ctrl[2])
        inactive = 1 - active
        view = np.ndarray(
            self._shape,
            dtype=self._dtype,
            buffer=self._shm_data.buf,
            offset=inactive * self._nbytes,
        )
        np.copyto(view, frame, casting="no")
        ts = np.int64(time.time_ns())
        # Publish order: ts -> idx -> seq
        self._ctrl[1] = ts
        self._ctrl[2] = inactive
        self._ctrl[0] += 1

    def read(self, last_seq: int = -1, require_new=True):
        """Return ``(seq, ts_ns, view)`` for the currently visible frame.

        ``view`` is an ``np.ndarray`` over the shared buffer, not a copy. It
        is ``None`` when ``require_new`` is true and the sequence number still
        equals ``last_seq``. If the sequence counter changes during each of
        three attempts, ``(last_seq, 0, None)`` is returned.
        """
        for _ in range(3):
            seq1 = int(self._ctrl[0])
            idx = int(self._ctrl[2])
            ts = int(self._ctrl[1])
            view = np.ndarray(
                self._shape, dtype=self._dtype, buffer=self._shm_data.buf, offset=idx * self._nbytes
            )
            # Accept the snapshot only if no publish completed while reading.
            if seq1 == int(self._ctrl[0]):
                if require_new and seq1 == last_seq:
                    return seq1, ts, None
                return seq1, ts, view
        return last_seq, 0, None

    @classmethod
    def attach(cls, desc):
        """Open an existing channel from a ``descriptor()`` dict.

        Raises:
            FileNotFoundError: If the data or control segment cannot be found.
        """
        obj = object.__new__(cls)
        obj._shape = tuple(desc["shape"])
        obj._dtype = np.dtype(desc["dtype"])
        obj._nbytes = int(desc["nbytes"])
        data_name = desc["data_name"]
        ctrl_name = desc["ctrl_name"]
        try:
            obj._shm_data = _open_shm_with_retry(data_name)
            obj._shm_ctrl = _open_shm_with_retry(ctrl_name)
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"CPU IPC attach failed: control/data SHM not found "
                f"(ctrl='{ctrl_name}', data='{data_name}'). "
                f"Ensure the writer is running on the same host and the channel is alive."
            ) from e
        obj._ctrl = np.ndarray((3,), dtype=np.int64, buffer=obj._shm_ctrl.buf)
        # attachments don't own/unlink
        obj._is_owner = False
        obj._finalizer_data = obj._finalizer_ctrl = None
        return obj

    def close(self):
        """Close both segment handles; the owner also unlinks the segments.

        Safe to call more than once. Never raises for a handle that cannot be
        closed yet or a segment that is already gone.
        """
        owner = getattr(self, "_is_owner", False)
        # Drop our own numpy view of the control block first. While it is
        # alive, SharedMemory.close() raises BufferError, which previously
        # escaped close() and left the data segment neither closed nor
        # unlinked.
        self._ctrl = None
        for attr in ("_shm_ctrl", "_shm_data"):
            shm = getattr(self, attr, None)
            if shm is None:
                continue
            try:
                shm.close()
            except (BufferError, OSError):
                # BufferError: a caller still holds a view returned by
                # read(); the mapping is released once that view is dropped.
                pass
            if owner:
                try:
                    shm.unlink()
                except OSError:
                    # Already unlinked (includes FileNotFoundError).
                    pass
        if owner:
            # The segments are gone; do not let a GC finalizer later unlink a
            # new segment that happens to reuse one of these names.
            for fin in (
                getattr(self, "_finalizer_data", None),
                getattr(self, "_finalizer_ctrl", None),
            ):
                if fin is not None:
                    fin.detach()


# ---------------------------
# 3) Factories
# ---------------------------


class CPU_IPC_Factory:
    """Factory for CPU shared-memory frame channels (create new or attach)."""

    @staticmethod
    def create(shape, dtype=np.uint8) -> CpuShmChannel:
        """Build a new ``CpuShmChannel`` for frames of ``shape`` and ``dtype``."""
        channel = CpuShmChannel(shape, dtype=dtype)
        return channel

    @staticmethod
    def attach(desc: dict) -> CpuShmChannel:
        """Attach to the channel described by ``desc``; its kind must be ``"cpu"``."""
        kind = desc.get("kind")
        assert kind == "cpu", "Descriptor kind mismatch"
        return CpuShmChannel.attach(desc)


# ---------------------------
# 4) Runtime selector
# ---------------------------


def make_frame_channel(
    shape, dtype=np.uint8, prefer: str = "auto", device: int = 0
) -> FrameChannel:
    """Create a frame channel for frames of ``shape`` and ``dtype``.

    Intended to pick CUDA IPC when available (or requested) and CPU shared
    memory otherwise. At present only the CPU backend exists, so ``prefer``
    and ``device`` are accepted but not consulted.
    """
    # TODO: Implement the CUDA version of creating this factory
    factory = CPU_IPC_Factory
    return factory.create(shape, dtype=dtype)
Loading
Loading