dimos/core/__init__.py (94 additions, 8 deletions)
@@ -25,6 +25,33 @@
__all__ = ["TF", "LCMTF", "PubSubTF", "TFSpec", "TFConfig"]


class CudaCleanupPlugin:
    """Dask worker plugin that cleans up CUDA resources on worker shutdown."""

    def setup(self, worker):
        """Called when the worker starts."""
        pass

    def teardown(self, worker):
        """Clean up CUDA resources when the worker shuts down."""
        try:
            import sys

            # Only touch CuPy if this worker actually imported it.
            if "cupy" in sys.modules:
                import cupy as cp

                # Free the device and pinned-host memory pools, synchronize
                # the default stream so in-flight work releases its blocks,
                # then free both pools again to reclaim anything the sync
                # released.
                mempool = cp.get_default_memory_pool()
                pinned_mempool = cp.get_default_pinned_memory_pool()
                mempool.free_all_blocks()
                pinned_mempool.free_all_blocks()
                cp.cuda.Stream.null.synchronize()
                mempool.free_all_blocks()
                pinned_mempool.free_all_blocks()
        except Exception:
            pass

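# Usage note (illustrative, not part of this change): besides being passed to
# LocalCluster below, a WorkerPlugin like this can also be registered on a
# live client:
#
#     from distributed import Client
#     client = Client()
#     client.register_worker_plugin(CudaCleanupPlugin())
#
# Dask then calls setup() on every current and future worker, and teardown()
# as each worker shuts down.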

def patch_actor(actor, cls): ...


@@ -183,17 +210,46 @@ def check_worker_memory():
)

def close_all():
    # Prevent multiple calls to close_all
    if hasattr(dask_client, "_closed") and dask_client._closed:
        return
    dask_client._closed = True

    import time

    # Stop all SharedMemory transports before closing Dask.
    # This prevents the "leaked shared_memory objects" warning and hangs.
    try:
        from dimos.protocol.pubsub import shmpubsub
        import gc

        # There is no central registry of live transports, so find them by
        # scanning the garbage collector's tracked objects.
        for obj in gc.get_objects():
            if isinstance(obj, (shmpubsub.SharedMemory, shmpubsub.PickleSharedMemory)):
                try:
                    obj.stop()
                except Exception:
                    pass
    except Exception:
        pass

    # Get the event loop before shutting down
    loop = dask_client.loop

    # Close cluster and client
    # Clear the actor registry
    ActorRegistry.clear()
    local_cluster.close()
    dask_client.close()

    # Stop the Tornado IOLoop to clean up IO loop and Profile threads
    # Close cluster and client with a reasonable timeout
    # The CudaCleanupPlugin will handle CUDA cleanup on each worker
    try:
        local_cluster.close(timeout=5)
    except Exception:
        pass

    try:
        dask_client.close(timeout=5)
    except Exception:
        pass

    if loop and hasattr(loop, "add_callback") and hasattr(loop, "stop"):
        try:
            # add_callback is the thread-safe way to ask a Tornado IOLoop
            # running in another thread to stop.
            loop.add_callback(loop.stop)
@@ -209,9 +265,6 @@ def close_all():
        except Exception:
            pass

    # Give threads a moment to clean up
    time.sleep(0.1)

dask_client.deploy = deploy
dask_client.check_worker_memory = check_worker_memory
dask_client.stop = lambda: dask_client.close()
@@ -226,6 +279,9 @@ def start(n: Optional[int] = None, memory_limit: str = "auto") -> Client:
        n: Number of workers (defaults to CPU count)
        memory_limit: Memory limit per worker (e.g., '4GB', '2GiB', or 'auto' for Dask's default)
    """
    import signal
    import atexit

    console = Console()
    if not n:
        n = mp.cpu_count()
@@ -236,10 +292,40 @@ def start(n: Optional[int] = None, memory_limit: str = "auto") -> Client:
        n_workers=n,
        threads_per_worker=4,
        memory_limit=memory_limit,
        plugins=[CudaCleanupPlugin()],  # Register CUDA cleanup plugin
    )
    client = Client(cluster)

    console.print(
        f"[green]Initialized dimos local cluster with [bright_blue]{n} workers, memory limit: {memory_limit}"
    )
    return patchdask(client, cluster)

    patched_client = patchdask(client, cluster)
    patched_client._closed = False  # Initialize the close_all guard flag
    patched_client._shutting_down = False  # Track whether a shutdown is in progress

    # Signal handler with graceful-then-forced exit handling
    def signal_handler(sig, frame):
        # If a shutdown is already in progress, force exit immediately
        if patched_client._shutting_down:
            import os

            console.print("[red]Force exit!")
            os._exit(1)

        patched_client._shutting_down = True
        console.print(f"[yellow]Shutting down (signal {sig})...")

        try:
            patched_client.close_all()
        except Exception:
            pass

        import sys

        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
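    # Behavior sketch (inferred from the handler above): the first SIGINT or
    # SIGTERM runs close_all() and exits cleanly; a second signal arriving
    # while shutdown is still in progress hits the _shutting_down guard and
    # force-exits via os._exit(1), so a double Ctrl+C kills a hung shutdown.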

    return patched_client
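# Usage sketch (hypothetical session; assumes this module is importable as
# dimos.core):
#
#     from dimos.core import start
#
#     client = start(n=4, memory_limit="2GiB")
#     ...  # deploy actors / submit work
#     client.close_all()  # or rely on the SIGINT/SIGTERM handlers above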