diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 9bb1a3dc68..81b1ad4cee 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -82,11 +82,71 @@ def deploy( return RPCClient(actor, actor_class) + def check_worker_memory(): + """Check memory usage of all workers.""" + info = dask_client.scheduler_info() + console = Console() + total_workers = len(info.get("workers", {})) + total_memory_used = 0 + total_memory_limit = 0 + + for worker_addr, worker_info in info.get("workers", {}).items(): + metrics = worker_info.get("metrics", {}) + memory_used = metrics.get("memory", 0) + memory_limit = worker_info.get("memory_limit", 0) + + cpu_percent = metrics.get("cpu", 0) + managed_bytes = metrics.get("managed_bytes", 0) + spilled = metrics.get("spilled_bytes", {}).get("memory", 0) + worker_status = worker_info.get("status", "unknown") + worker_id = worker_info.get("id", "?") + + memory_used_gb = memory_used / 1e9 + memory_limit_gb = memory_limit / 1e9 + managed_gb = managed_bytes / 1e9 + spilled_gb = spilled / 1e9 + + total_memory_used += memory_used + total_memory_limit += memory_limit + + percentage = (memory_used_gb / memory_limit_gb * 100) if memory_limit_gb > 0 else 0 + + if worker_status == "paused": + status = "[red]PAUSED" + elif percentage >= 95: + status = "[red]CRITICAL" + elif percentage >= 80: + status = "[yellow]WARNING" + else: + status = "[green]OK" + + console.print( + f"Worker-{worker_id} {worker_addr}: " + f"{memory_used_gb:.2f}/{memory_limit_gb:.2f}GB ({percentage:.1f}%) " + f"CPU:{cpu_percent:.0f}% Managed:{managed_gb:.2f}GB " + f"{status}" + ) + + if total_workers > 0: + total_used_gb = total_memory_used / 1e9 + total_limit_gb = total_memory_limit / 1e9 + total_percentage = (total_used_gb / total_limit_gb * 100) if total_limit_gb > 0 else 0 + console.print( + f"[bold]Total: {total_used_gb:.2f}/{total_limit_gb:.2f}GB ({total_percentage:.1f}%) across {total_workers} workers[/bold]" + ) + dask_client.deploy = deploy + dask_client.check_worker_memory = check_worker_memory return dask_client -def start(n: Optional[int] = None) -> Client: +def start(n: Optional[int] = None, memory_limit: str = "auto") -> Client: + """Start a Dask LocalCluster with specified workers and memory limits. + + Args: + n: Number of workers (defaults to CPU count) + memory_limit: Memory limit per worker (e.g., '4GB', '2GiB', or 'auto' for Dask's default) + """ console = Console() if not n: n = mp.cpu_count() @@ -96,10 +156,13 @@ def start(n: Optional[int] = None) -> Client: cluster = LocalCluster( n_workers=n, threads_per_worker=4, + memory_limit=memory_limit, ) client = Client(cluster) - console.print(f"[green]Initialized dimos local cluster with [bright_blue]{n} workers") + console.print( + f"[green]Initialized dimos local cluster with [bright_blue]{n} workers, memory limit: {memory_limit}" + ) return patchdask(client)