Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
cdfaaed
Reapply "Image upgrades! Impls for CUDA + numpy, along with an abstra…
paul-nechifor Oct 8, 2025
2924852
fix
paul-nechifor Oct 8, 2025
2a017dc
CI code cleanup
paul-nechifor Oct 8, 2025
15d1991
add crop back
paul-nechifor Oct 8, 2025
e886217
fix encoding
paul-nechifor Oct 9, 2025
7bc6f9e
Fix detection 3d test for floating point non determinism
spomichter Oct 9, 2025
3ba5695
Fix CI lack of timer precision for test reactive
spomichter Oct 9, 2025
e73e82e
Added cleaner Dask exit handling for cuda and shared memory object cl…
spomichter Oct 10, 2025
f91d6e2
Fix lcm_encode in new Image and added BGR default
spomichter Oct 10, 2025
820fa59
Moved cudaimage unit tests to correct location
spomichter Oct 10, 2025
348d49c
Fix very broken to_cpu method in Image
spomichter Oct 10, 2025
80d6f4f
Major rewrite of image backend tests, separated cuda/cpu for testing …
spomichter Oct 10, 2025
68a4985
Minor changes
spomichter Oct 10, 2025
084edff
Fixes
spomichter Oct 10, 2025
90fcc58
Simplify to_cpu Image operation
spomichter Oct 10, 2025
fe021b2
Merge branch 'temp-dev' into fix-iage
spomichter Oct 10, 2025
ab83dad
CI code cleanup
spomichter Oct 10, 2025
7a136d2
Merge pull request #671 from dimensionalOS/fix-iage
spomichter Oct 10, 2025
849adf0
Fix Image __eq__
spomichter Oct 11, 2025
c0c99d8
Removed tests added in merge by accident
spomichter Oct 11, 2025
c83ccf8
Merge branch 'dev' into temp-dev
spomichter Oct 11, 2025
08d3e16
CI code cleanup
spomichter Oct 11, 2025
7b22223
Fix race condition in CI for dask daemon threads
spomichter Oct 12, 2025
ba70c4d
fix module tests
paul-nechifor Oct 12, 2025
e216d24
Re fix floating point flaky test
spomichter Oct 12, 2025
eee5ecf
Fix nvimgcodec tests from reloading Image and failing pytests when ru…
spomichter Oct 12, 2025
6bf2180
Fix potential timestamp issue if set to 0.0
spomichter Oct 12, 2025
bf95182
Added back agent image message to agent encode
spomichter Oct 12, 2025
d7a7836
Added Timestamped back to new Image
spomichter Oct 12, 2025
5aaae9a
Fixed sharpness implemented incorrectly as a method rather than prope…
spomichter Oct 12, 2025
a5443d1
Add co-author attribution
spomichter Oct 12, 2025
2bac165
Fix flaky wavefront test
spomichter Oct 12, 2025
c1f01be
Test 0.1 thread delay for dask
spomichter Oct 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 101 additions & 7 deletions dimos/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,36 @@
]


class CudaCleanupPlugin:
    """Dask worker plugin that releases CuPy memory pools at worker shutdown.

    Registered on the cluster so Dask calls ``teardown`` in each worker
    process as it exits, preventing CUDA allocations from outliving the
    worker.
    """

    def setup(self, worker):
        """No-op; nothing to initialize when the worker starts."""
        pass

    def teardown(self, worker):
        """Best-effort release of CuPy memory pools on worker shutdown.

        Acts only when cupy has actually been imported in this process;
        any failure is swallowed because the worker is exiting anyway.
        """
        import sys

        if "cupy" not in sys.modules:
            return
        try:
            import cupy as cp

            pools = (
                cp.get_default_memory_pool(),
                cp.get_default_pinned_memory_pool(),
            )
            for pool in pools:
                pool.free_all_blocks()
            # Synchronize so in-flight kernels release their allocations,
            # then sweep the pools a second time.
            cp.cuda.Stream.null.synchronize()
            for pool in pools:
                pool.free_all_blocks()
        except Exception:
            pass


def patch_actor(actor, cls): ...


class RPCClient:
def __init__(self, actor_instance, actor_class):
self.rpc = LCMRPC()
Expand Down Expand Up @@ -204,23 +234,52 @@ def check_worker_memory():
)

def close_all():
# Prevents multiple calls to close_all
if hasattr(dask_client, "_closed") and dask_client._closed:
return
dask_client._closed = True

import time

# Close cluster and client
# Stop all SharedMemory transports before closing Dask
# This prevents the "leaked shared_memory objects" warning and hangs
try:
from dimos.protocol.pubsub import shmpubsub
import gc

for obj in gc.get_objects():
if isinstance(obj, (shmpubsub.SharedMemory, shmpubsub.PickleSharedMemory)):
try:
obj.stop()
except Exception:
pass
except Exception:
pass

# Get the event loop before shutting down
loop = dask_client.loop

# Clear the actor registry
ActorRegistry.clear()

# Close client first to signal workers to shut down gracefully
# Close cluster and client with reasonable timeout
# The CudaCleanupPlugin will handle CUDA cleanup on each worker
try:
dask_client.close(timeout=2)
local_cluster.close(timeout=5)
except Exception:
pass

# Then close the cluster
try:
local_cluster.close(timeout=2)
dask_client.close(timeout=5)
except Exception:
pass

if loop and hasattr(loop, "add_callback") and hasattr(loop, "stop"):
try:
loop.add_callback(loop.stop)
except Exception:
pass

# Shutdown the Dask offload thread pool
try:
from distributed.utils import _offload_executor
Expand All @@ -230,7 +289,10 @@ def close_all():
except Exception:
pass

# Give threads a moment to clean up
# Give threads time to clean up
# Dask's IO loop and Profile threads are daemon threads
# that will be cleaned up when the process exits
# This is needed, solves race condition in CI thread check
time.sleep(0.1)

dask_client.deploy = deploy
Expand All @@ -247,6 +309,9 @@ def start(n: Optional[int] = None, memory_limit: str = "auto") -> Client:
n: Number of workers (defaults to CPU count)
memory_limit: Memory limit per worker (e.g., '4GB', '2GiB', or 'auto' for Dask's default)
"""
import signal
import atexit

console = Console()
if not n:
n = mp.cpu_count()
Expand All @@ -257,10 +322,39 @@ def start(n: Optional[int] = None, memory_limit: str = "auto") -> Client:
n_workers=n,
threads_per_worker=4,
memory_limit=memory_limit,
plugins=[CudaCleanupPlugin()], # Register CUDA cleanup plugin
)
client = Client(cluster)

console.print(
f"[green]Initialized dimos local cluster with [bright_blue]{n} workers, memory limit: {memory_limit}"
)
return patchdask(client, cluster)

patched_client = patchdask(client, cluster)
patched_client._shutting_down = False

# Signal handler with proper exit handling
def signal_handler(sig, frame):
    """Handle SIGINT/SIGTERM: graceful shutdown on first signal, hard exit on repeat."""
    # A second signal while shutdown is already in progress means the
    # graceful path is stuck — bail out of the process immediately.
    if patched_client._shutting_down:
        import os

        console.print("[red]Force exit!")
        os._exit(1)

    patched_client._shutting_down = True
    console.print(f"[yellow]Shutting down (signal {sig})...")

    # Best-effort graceful teardown of the Dask client/cluster; failures
    # here must not prevent the process from exiting.
    try:
        patched_client.close_all()
    except Exception:
        pass

    import sys

    sys.exit(0)

signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

return patched_client
92 changes: 79 additions & 13 deletions dimos/core/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ def get_import_aliases(tree: ast.AST) -> Dict[str, str]:


def is_module_subclass(
base_classes: List[str], aliases: Dict[str, str], class_hierarchy: Dict[str, List[str]] = None
base_classes: List[str],
aliases: Dict[str, str],
class_hierarchy: Dict[str, List[str]] = None,
current_module_path: str = None,
) -> bool:
"""Check if any base class is or resolves to dimos.core.Module or its variants (recursively)."""
target_classes = {
Expand All @@ -96,7 +99,34 @@ def is_module_subclass(
"dimos.core.module.DaskModule",
}

def check_base(base: str, visited: Set[str] = None) -> bool:
def find_qualified_name(base: str, context_module: str = None) -> str:
    """Resolve *base* to a dotted name usable as a ``class_hierarchy`` key.

    Resolution order: exact hierarchy key, then import aliases, then
    ``<context_module>.<base>`` for simple (un-dotted) names defined in
    the same module; falls back to returning *base* unchanged.
    """
    if not class_hierarchy:
        return base

    # Exact match: already fully qualified or present in the hierarchy.
    if base in class_hierarchy:
        return base

    # Resolve through import aliases. Whether or not the resolved name
    # is a hierarchy key, it is the best candidate we have, so return it
    # either way. (The original checked membership and then returned the
    # same value on both branches — that dead branch is removed here.)
    if base in aliases:
        return aliases[base]

    # For a simple name with a known defining module, prefer a class
    # declared in that same module (local classes shadow imports).
    if context_module and "." not in base:
        same_module_qualified = f"{context_module}.{base}"
        if same_module_qualified in class_hierarchy:
            return same_module_qualified

    # Nothing matched — hand back the name unchanged.
    return base

def check_base(base: str, visited: Set[str] = None, context_module: str = None) -> bool:
if visited is None:
visited = set()

Expand All @@ -118,22 +148,27 @@ def check_base(base: str, visited: Set[str] = None) -> bool:
base = resolved

# If we have a class hierarchy, recursively check parent classes
if class_hierarchy and base in class_hierarchy:
for parent_base in class_hierarchy[base]:
if check_base(parent_base, visited):
return True
if class_hierarchy:
# Resolve the base class name to a qualified name
qualified_name = find_qualified_name(base, context_module)

if qualified_name in class_hierarchy:
# Check all parent classes
for parent_base in class_hierarchy[qualified_name]:
if check_base(parent_base, visited, None): # Parent lookups don't use context
return True

return False

for base in base_classes:
if check_base(base):
if check_base(base, context_module=current_module_path):
return True

return False


def scan_file(
filepath: Path, class_hierarchy: Dict[str, List[str]] = None
filepath: Path, class_hierarchy: Dict[str, List[str]] = None, root_path: Path = None
) -> List[Tuple[str, str, bool, bool, Set[str]]]:
"""
Scan a Python file for Module subclasses.
Expand All @@ -153,9 +188,21 @@ def scan_file(
visitor = ModuleVisitor(str(filepath))
visitor.visit(tree)

# Get module path for this file to properly resolve base classes
current_module_path = None
if root_path:
try:
rel_path = filepath.relative_to(root_path.parent)
module_parts = list(rel_path.parts[:-1])
if rel_path.stem != "__init__":
module_parts.append(rel_path.stem)
current_module_path = ".".join(module_parts)
except ValueError:
pass

results = []
for class_name, base_classes, methods in visitor.classes:
if is_module_subclass(base_classes, aliases, class_hierarchy):
if is_module_subclass(base_classes, aliases, class_hierarchy, current_module_path):
has_start = "start" in methods
has_stop = "stop" in methods
forbidden_found = methods & forbidden_method_names
Expand All @@ -172,7 +219,7 @@ def build_class_hierarchy(root_path: Path) -> Dict[str, List[str]]:
"""Build a complete class hierarchy by scanning all Python files."""
hierarchy = {}

for filepath in root_path.rglob("*.py"):
for filepath in sorted(root_path.rglob("*.py")):
# Skip __pycache__ and other irrelevant directories
if "__pycache__" in filepath.parts or ".venv" in filepath.parts:
continue
Expand All @@ -185,13 +232,32 @@ def build_class_hierarchy(root_path: Path) -> Dict[str, List[str]]:
visitor = ModuleVisitor(str(filepath))
visitor.visit(tree)

# Convert filepath to module path (e.g., dimos/core/module.py -> dimos.core.module)
try:
rel_path = filepath.relative_to(root_path.parent)
except ValueError:
# If we can't get relative path, skip this file
continue

# Convert path to module notation
module_parts = list(rel_path.parts[:-1]) # Exclude filename
if rel_path.stem != "__init__":
module_parts.append(rel_path.stem) # Add filename without .py
module_name = ".".join(module_parts)

for class_name, base_classes, _ in visitor.classes:
hierarchy[class_name] = base_classes
# Use fully qualified name as key to avoid conflicts
qualified_name = f"{module_name}.{class_name}" if module_name else class_name
hierarchy[qualified_name] = base_classes

except (SyntaxError, UnicodeDecodeError):
# Skip files that can't be parsed
continue

from pprint import pprint

pprint(hierarchy)

return hierarchy


Expand All @@ -203,12 +269,12 @@ def scan_directory(root_path: Path) -> List[Tuple[str, str, bool, bool, Set[str]
# Then scan for Module subclasses using the complete hierarchy
results = []

for filepath in root_path.rglob("*.py"):
for filepath in sorted(root_path.rglob("*.py")):
# Skip __pycache__ and other irrelevant directories
if "__pycache__" in filepath.parts or ".venv" in filepath.parts:
continue

file_results = scan_file(filepath, class_hierarchy)
file_results = scan_file(filepath, class_hierarchy, root_path)
results.extend(file_results)

return results
Expand Down
Loading