Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions dimos/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
from rich.console import Console

import dimos.core.colors as colors
from dimos.core.core import In, Out, RemoteOut, rpc
from dimos.core.core import rpc
from dimos.core.module import Module, ModuleBase
from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport
from dimos.core.transport import LCMTransport, ZenohTransport, pLCMTransport
from dimos.protocol.rpc.lcmrpc import LCMRPC
from dimos.protocol.rpc.spec import RPC
from dimos.protocol.tf import LCMTF, TF, PubSubTF, TFConfig, TFSpec

__all__ = ["TF", "LCMTF", "PubSubTF", "TFSpec", "TFConfig"]


def patch_actor(actor, cls): ...
Expand Down Expand Up @@ -87,7 +91,7 @@ def start(n: Optional[int] = None) -> Client:
n = mp.cpu_count()
with console.status(
f"[green]Initializing dimos local cluster with [bright_blue]{n} workers", spinner="arc"
) as status:
):
cluster = LocalCluster(
n_workers=n,
threads_per_worker=4,
Expand Down
239 changes: 1 addition & 238 deletions dimos/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,259 +15,22 @@

from __future__ import annotations

import enum
import inspect
import traceback
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Optional,
Protocol,
TypeVar,
get_args,
get_origin,
get_type_hints,
)

from dask.distributed import Actor

import dimos.core.colors as colors
from dimos.core.o3dpickle import register_picklers

# injects pickling system into o3d
register_picklers()
T = TypeVar("T")


class Transport(Protocol[T]):
# used by local Output
def broadcast(self, selfstream: Out[T], value: T): ...

# used by local Input
def subscribe(self, selfstream: In[T], callback: Callable[[T], any]) -> None: ...


class DaskTransport(Transport[T]):
subscribers: List[Callable[[T], None]]
_started: bool = False

def __init__(self):
self.subscribers = []

def __str__(self) -> str:
return colors.yellow("DaskTransport")

def __reduce__(self):
return (DaskTransport, ())

def broadcast(self, selfstream: RemoteIn[T], msg: T) -> None:
for subscriber in self.subscribers:
# there is some sort of a bug here with losing worker loop
# print(subscriber.owner, subscriber.owner._worker, subscriber.owner._client)
# subscriber.owner._try_bind_worker_client()
# print(subscriber.owner, subscriber.owner._worker, subscriber.owner._client)

subscriber.owner.dask_receive_msg(subscriber.name, msg).result()

def dask_receive_msg(self, msg) -> None:
for subscriber in self.subscribers:
try:
subscriber(msg)
except Exception as e:
print(
colors.red("Error in DaskTransport subscriber callback:"),
e,
traceback.format_exc(),
)

# for outputs
def dask_register_subscriber(self, remoteInput: RemoteIn[T]) -> None:
self.subscribers.append(remoteInput)

# for inputs
def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None:
if not self._started:
selfstream.connection.owner.dask_register_subscriber(
selfstream.connection.name, selfstream
).result()
self._started = True
self.subscribers.append(callback)


class State(enum.Enum):
UNBOUND = "unbound" # descriptor defined but not bound
READY = "ready" # bound to owner but not yet connected
CONNECTED = "connected" # input bound to an output
FLOWING = "flowing" # runtime: data observed


class Stream(Generic[T]):
_transport: Optional[Transport]

def __init__(
self,
type: type[T],
name: str,
owner: Optional[Any] = None,
transport: Optional[Transport] = None,
):
self.name = name
self.owner = owner
self.type = type
if transport:
self._transport = transport
if not hasattr(self, "_transport"):
self._transport = None

@property
def type_name(self) -> str:
return getattr(self.type, "__name__", repr(self.type))

def _color_fn(self) -> Callable[[str], str]:
if self.state == State.UNBOUND:
return colors.orange
if self.state == State.READY:
return colors.blue
if self.state == State.CONNECTED:
return colors.green
return lambda s: s

def __str__(self) -> str: # noqa: D401
return (
self.__class__.__name__
+ " "
+ self._color_fn()(f"{self.name}[{self.type_name}]")
+ " @ "
+ (
colors.orange(self.owner)
if isinstance(self.owner, Actor)
else colors.green(self.owner)
)
+ ("" if not self._transport else " via " + str(self._transport))
)


class Out(Stream[T]):
_transport: Transport

def __init__(self, *argv, **kwargs):
super().__init__(*argv, **kwargs)
if not hasattr(self, "_transport") or self._transport is None:
self._transport = DaskTransport()

@property
def transport(self) -> Transport[T]:
return self._transport

@property
def state(self) -> State: # noqa: D401
return State.UNBOUND if self.owner is None else State.READY

def __reduce__(self): # noqa: D401
if self.owner is None or not hasattr(self.owner, "ref"):
raise ValueError("Cannot serialise Out without an owner ref")
return (
RemoteOut,
(
self.type,
self.name,
self.owner.ref,
self._transport,
),
)

def publish(self, msg):
self._transport.broadcast(self, msg)


class RemoteStream(Stream[T]):
@property
def state(self) -> State: # noqa: D401
return State.UNBOUND if self.owner is None else State.READY

# this won't work but nvm
@property
def transport(self) -> Transport[T]:
return self._transport

@transport.setter
def transport(self, value: Transport[T]) -> None:
self.owner.set_transport(self.name, value).result()
self._transport = value


class RemoteOut(RemoteStream[T]):
def connect(self, other: RemoteIn[T]):
return other.connect(self)

def observable(self):
"""Create an Observable stream from this remote output."""
from reactivex import create

def subscribe(observer, scheduler=None):
def on_msg(msg):
observer.on_next(msg)

self._transport.subscribe(self, on_msg)
return lambda: None

return create(subscribe)


class In(Stream[T]):
connection: Optional[RemoteOut[T]] = None
_transport: Transport

def __str__(self):
mystr = super().__str__()

if not self.connection:
return mystr

return (mystr + " ◀─").ljust(60, "─") + f" {self.connection}"

def __reduce__(self): # noqa: D401
if self.owner is None or not hasattr(self.owner, "ref"):
raise ValueError("Cannot serialise Out without an owner ref")
return (RemoteIn, (self.type, self.name, self.owner.ref, self._transport))

@property
def transport(self) -> Transport[T]:
if not self._transport:
self._transport = self.connection.transport
return self._transport

@property
def state(self) -> State: # noqa: D401
return State.UNBOUND if self.owner is None else State.READY

def subscribe(self, cb):
self.transport.subscribe(self, cb)


class RemoteIn(RemoteStream[T]):
def connect(self, other: RemoteOut[T]) -> None:
return self.owner.connect_stream(self.name, other).result()

# this won't work but that's ok
@property
def transport(self) -> Transport[T]:
return self._transport

def publish(self, msg):
self.transport.broadcast(self, msg)

@transport.setter
def transport(self, value: Transport[T]) -> None:
self.owner.set_transport(self.name, value).result()
self._transport = value


def rpc(fn: Callable[..., Any]) -> Callable[..., Any]:
fn.__rpc__ = True # type: ignore[attr-defined]
return fn


daskTransport = DaskTransport() # singleton instance for use in Out/RemoteOut
4 changes: 3 additions & 1 deletion dimos/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from dask.distributed import Actor, get_worker

from dimos.core import colors
from dimos.core.core import In, Out, RemoteIn, RemoteOut, T, Transport
from dimos.core.core import T, rpc
from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport
from dimos.protocol.rpc.lcmrpc import LCMRPC


Expand Down Expand Up @@ -65,6 +66,7 @@ def rpcs(cls) -> dict[str, Callable]:
and hasattr(getattr(cls, name), "__rpc__")
}

@rpc
def io(self) -> str:
def _box(name: str) -> str:
return [
Expand Down
Loading
Loading