Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion nemo_reinforcer/distributed/virtual_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,30 @@ def _init_placement_groups(self, strategy: str):
if self._node_placement_groups is not None:
return self._node_placement_groups

# Check available resources in the Ray cluster
cluster_resources = ray.cluster_resources()
total_available_gpus = int(cluster_resources.get("GPU", 0))
total_available_cpus = int(cluster_resources.get("CPU", 0))

# Calculate required resources
total_requested_gpus = (
sum(self._bundle_ct_per_node_list) if self.use_gpus else 0
)
total_requested_cpus = (
sum(self._bundle_ct_per_node_list) * self.max_colocated_worker_groups
)

# Validate resources
if self.use_gpus and total_requested_gpus > total_available_gpus:
raise ValueError(
f"Not enough GPUs available. Requested {total_requested_gpus} GPUs, but only {total_available_gpus} are available in the cluster."
)

if total_requested_cpus > total_available_cpus:
raise ValueError(
f"Not enough CPUs available. Requested {total_requested_cpus} CPUs, but only {total_available_cpus} are available in the cluster."
)

num_cpus_per_bundle = self.max_colocated_worker_groups
# num_gpus_per_bundle == 1 indicates that there is 1 GPU per process
num_gpus_per_bundle = 1 if self.use_gpus else 0
Expand All @@ -192,7 +216,24 @@ def _init_placement_groups(self, strategy: str):
for i, bundles in enumerate(resources)
]

ray.get([pg.ready() for pg in self._node_placement_groups])
# Add timeout to prevent hanging indefinitely
try:
ray.get(
[pg.ready() for pg in self._node_placement_groups], timeout=180
) # 3-minute timeout
except (TimeoutError, ray.exceptions.GetTimeoutError):
# Clean up any created placement groups
for pg in self._node_placement_groups:
try:
remove_placement_group(pg)
except Exception:
pass
self._node_placement_groups = None
raise TimeoutError(
"Timed out waiting for placement groups to be ready. The cluster may not have enough resources "
"to satisfy the requested configuration, or the resources may be busy with other tasks."
)

return self._node_placement_groups

def get_placement_groups(self):
Expand Down
Loading