Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions dimos/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,71 @@ def deploy(

return RPCClient(actor, actor_class)

def check_worker_memory():
"""Check memory usage of all workers."""
info = dask_client.scheduler_info()
console = Console()
total_workers = len(info.get("workers", {}))
total_memory_used = 0
total_memory_limit = 0

for worker_addr, worker_info in info.get("workers", {}).items():
metrics = worker_info.get("metrics", {})
memory_used = metrics.get("memory", 0)
memory_limit = worker_info.get("memory_limit", 0)

cpu_percent = metrics.get("cpu", 0)
managed_bytes = metrics.get("managed_bytes", 0)
spilled = metrics.get("spilled_bytes", {}).get("memory", 0)
worker_status = worker_info.get("status", "unknown")
worker_id = worker_info.get("id", "?")

memory_used_gb = memory_used / 1e9
memory_limit_gb = memory_limit / 1e9
managed_gb = managed_bytes / 1e9
spilled_gb = spilled / 1e9

total_memory_used += memory_used
total_memory_limit += memory_limit

percentage = (memory_used_gb / memory_limit_gb * 100) if memory_limit_gb > 0 else 0

if worker_status == "paused":
status = "[red]PAUSED"
elif percentage >= 95:
status = "[red]CRITICAL"
elif percentage >= 80:
status = "[yellow]WARNING"
else:
status = "[green]OK"

console.print(
f"Worker-{worker_id} {worker_addr}: "
f"{memory_used_gb:.2f}/{memory_limit_gb:.2f}GB ({percentage:.1f}%) "
f"CPU:{cpu_percent:.0f}% Managed:{managed_gb:.2f}GB "
f"{status}"
)

if total_workers > 0:
total_used_gb = total_memory_used / 1e9
total_limit_gb = total_memory_limit / 1e9
total_percentage = (total_used_gb / total_limit_gb * 100) if total_limit_gb > 0 else 0
console.print(
f"[bold]Total: {total_used_gb:.2f}/{total_limit_gb:.2f}GB ({total_percentage:.1f}%) across {total_workers} workers[/bold]"
)

dask_client.deploy = deploy
dask_client.check_worker_memory = check_worker_memory
return dask_client


def start(n: Optional[int] = None) -> Client:
def start(n: Optional[int] = None, memory_limit: str = "auto") -> Client:
"""Start a Dask LocalCluster with specified workers and memory limits.

Args:
n: Number of workers (defaults to CPU count)
memory_limit: Memory limit per worker (e.g., '4GB', '2GiB', or 'auto' for Dask's default)
"""
console = Console()
if not n:
n = mp.cpu_count()
Expand All @@ -96,10 +156,13 @@ def start(n: Optional[int] = None) -> Client:
cluster = LocalCluster(
n_workers=n,
threads_per_worker=4,
memory_limit=memory_limit,
)
client = Client(cluster)

console.print(f"[green]Initialized dimos local cluster with [bright_blue]{n} workers")
console.print(
f"[green]Initialized dimos local cluster with [bright_blue]{n} workers, memory limit: {memory_limit}"
)
return patchdask(client)


Expand Down