104 changes: 87 additions & 17 deletions nemo_rl/distributed/worker_groups.py
@@ -15,13 +15,12 @@
import os
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Iterable, Optional, Union
from typing import Any, Optional, Union

import ray
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from nemo_rl.distributed.batched_data_dict import SlicedDataDict
from nemo_rl.distributed.named_sharding import NamedSharding
from nemo_rl.distributed.ray_actor_environment_registry import (
get_actor_python_env,
@@ -583,32 +582,75 @@ def run_single_worker_single_data(
Returns:
ray.ObjectRef: A Ray future for the result.
"""
assert len(args) == 0, (
"run_single_worker_single_data will fail with args under certain circumstances. "
"Please use kwargs instead. "
"See https://github.com/NVIDIA-NeMo/RL/issues/582 for more details."
)

worker = self.workers[worker_idx]
method = getattr(worker, method_name)
return method.remote(*args, **kwargs)

def run_all_workers_multiple_data(
self,
method_name: str,
data: list[Any],
*args,
run_rank_0_only_axes: list[str] | None = None,
common_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) -> list[ray.ObjectRef]:
"""Run a method on all workers in parallel with different data.

Args:
method_name: Name of the method to call on each worker
data: List of data to pass to workers/groups
*args: List of arguments to pass to workers/groups
e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]]
run_rank_0_only_axes: List of named axes for which only rank 0 should run the method.
common_kwargs: Additional keyword arguments to pass to all workers
common_kwargs: Keyword arguments to pass to all workers
**kwargs: Keyword arguments to pass to workers/groups
e.g. {"key1": [value_for_worker_1, value_for_worker_2], "key2": [value_for_worker_1, value_for_worker_2]}

Returns:
list[ray.ObjectRef]: A list of ray futures
"""
assert len(args) == 0, (
"run_all_workers_multiple_data will fail with args under certain circumstances. "
"Please use kwargs instead. "
"See https://github.com/NVIDIA-NeMo/RL/issues/582 for more details."
)

# Check at least one arg or kwarg is provided
assert len(args) > 0 or len(kwargs) > 0, (
"At least one args (positional arguments) or kwargs (keyword arguments) must be provided in run_all_workers_multiple_data. "
"Otherwise, please use run_all_workers_single_data."
)

# Check all args and kwargs have the same length
args_count = [len(arg) for arg in args]
assert all(count == args_count[0] for count in args_count), (
"All args must have the same length"
)
args_count = args_count[0] if len(args_count) > 0 else 0

kwargs_count = [len(value) for value in kwargs.values()]
assert all(count == kwargs_count[0] for count in kwargs_count), (
"All kwargs must have the same length"
)
kwargs_count = kwargs_count[0] if len(kwargs_count) > 0 else 0

if args_count > 0 and kwargs_count > 0:
assert args_count == kwargs_count, (
"The number of args and kwargs must be the same in run_all_workers_multiple_data. "
f"args length = {args_count}, kwargs length = {kwargs_count}"
)
data_count = max(args_count, kwargs_count)

# Check the data length is equal to the number of workers
if run_rank_0_only_axes is None:
assert len(data) == len(self.workers), (
assert data_count == len(self.workers), (
"data length should be equal to the number of workers: "
f"data length = {len(data)}, number of workers = {len(self.workers)}"
f"data length = {data_count}, number of workers = {len(self.workers)}"
)

futures = []
@@ -633,12 +675,16 @@ def run_all_workers_multiple_data(

if should_run:
method = getattr(worker, method_name)
futures.append(method.remote(data=data[data_idx], **common_kwargs))
worker_args = [arg[data_idx] for arg in args]
worker_kwargs = {key: value[data_idx] for key, value in kwargs.items()}
futures.append(
method.remote(*worker_args, **worker_kwargs, **common_kwargs)
)
data_idx += 1

assert data_idx == len(data), (
assert data_idx == data_count, (
"data length should be equal to the number of workers started: "
f"data length = {len(data)}, number of workers started = {data_idx}"
f"data length = {data_count}, number of workers started = {data_idx}"
)

return futures
@@ -660,6 +706,12 @@ def run_all_workers_single_data(
Returns:
list[ray.ObjectRef]: A list of ray futures
"""
assert len(args) == 0, (
"run_all_workers_single_data will fail with args under certain circumstances. "
"Please use kwargs instead. "
"See https://github.com/NVIDIA-NeMo/RL/issues/582 for more details."
)

futures = []

if run_rank_0_only_axes is None:
@@ -686,12 +738,13 @@ def run_all_workers_sharded_data(
def run_all_workers_sharded_data(
self,
method_name: str,
data: Iterable[SlicedDataDict], # arbitrary nested iterables of SlicedDataDicts
*args,
in_sharded_axes: list[str] | None = None,
replicate_on_axes: list[str] | None = None,
output_is_replicated: list[str] | None = None,
make_dummy_calls_to_free_axes: bool = False,
common_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) -> MultiWorkerFuture:
"""Run a method on all workers in parallel with sharded data.

@@ -701,17 +754,27 @@ def run_all_workers_sharded_data(

Args:
method_name: Name of the method to call on each worker
data: Iterable of SlicedDataDicts to pass to workers/groups
*args: List of arguments to pass to workers/groups
e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]]
in_sharded_axes: List of axes that are sharded
replicate_on_axes: List of axes that are to be replicated
output_is_replicated: List of axes along which the output is replicated (and we should just return the first result).
We also just return from rank 0 of free axes.
make_dummy_calls_to_free_axes: Whether to make dummy calls (with None) to workers that
aren't rank 0 on 'free axes' (axes not in in_sharded_axes or replicate_on_axes).
common_kwargs: Additional keyword arguments to pass to all workers
common_kwargs: Keyword arguments to pass to all workers
**kwargs: Keyword arguments to pass to workers/groups
e.g. {"key1": [value_for_worker_1, value_for_worker_2], "key2": [value_for_worker_1, value_for_worker_2]}

Returns:
MultiWorkerFuture: Object containing futures and their associated worker information
"""
assert len(args) == 0, (
"run_all_workers_sharded_data will fail with args under certain circumstances. "
"Please use kwargs instead. "
"See https://github.com/NVIDIA-NeMo/RL/issues/582 for more details."
)

if self.sharding_annotations is None:
raise ValueError(
"Sharding annotations must be provided to use sharded data distribution"
@@ -771,24 +834,31 @@

if should_receive_data:
# Find the appropriate data slice for this worker
worker_data = data
worker_args = args
worker_kwargs = kwargs
for axis in in_sharded_axes:
if axis in worker_coords:
# Select the appropriate slice for this axis
worker_data = worker_data[worker_coords[axis]]
worker_args = [arg[worker_coords[axis]] for arg in worker_args]
worker_kwargs = {
key: value[worker_coords[axis]]
for key, value in worker_kwargs.items()
}

# Call the method on the worker with its data slice
future = getattr(worker, method_name).remote(
data=worker_data, **common_kwargs
*worker_args, **worker_kwargs, **common_kwargs
)
futures.append(future)
called_workers.append(worker_idx)
else:
# If this worker doesn't need data:
if make_dummy_calls_to_free_axes:
# If make_dummy_calls_to_free_axes is True, just call the method with None
worker_args = [None] * len(args)
worker_kwargs = {key: None for key in kwargs.keys()}
future = getattr(worker, method_name).remote(
data=None, **common_kwargs
*worker_args, **worker_kwargs, **common_kwargs
)
futures.append(future)
called_workers.append(worker_idx)
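The net effect of the worker_groups.py changes: the dedicated positional `data` list is gone, and every per-worker argument now travels as a keyword whose value holds one entry per worker, with `common_kwargs` forwarded unchanged to every call (positional `*args` remain in the signatures but are asserted empty, per issue #582). A minimal, self-contained sketch of those dispatch semantics — this is a plain-Python model for illustration, not the real Ray-backed worker group, and `ScaleWorker`, `compute`, `data`, and `scale` are made-up names:

```python
# Self-contained model of the new dispatch semantics (NOT the real worker
# group, which schedules Ray actors): worker i receives kwargs[key][i] for
# every per-worker kwarg, plus all common_kwargs.
def dispatch_multiple_data(workers, method_name, common_kwargs=None, **kwargs):
    common_kwargs = common_kwargs or {}
    # Every per-worker kwarg must have one entry per worker, as in the PR.
    lengths = {len(values) for values in kwargs.values()}
    assert lengths == {len(workers)}, "each kwarg needs one value per worker"
    results = []
    for i, worker in enumerate(workers):
        per_worker = {key: values[i] for key, values in kwargs.items()}
        results.append(getattr(worker, method_name)(**per_worker, **common_kwargs))
    return results


class ScaleWorker:
    # Illustrative worker; `data` and `scale` are made-up parameters.
    def compute(self, data, scale, verbose=False):
        if verbose:
            print(f"worker got data={data}, scale={scale}")
        return data * scale


# Old convention (removed by this PR): a single positional list, with `data`
# the only per-worker argument. New convention: any number of per-worker
# kwargs, each a list with one entry per worker.
workers = [ScaleWorker(), ScaleWorker()]
print(dispatch_multiple_data(
    workers, "compute",
    data=[1, 2],                      # worker 0 gets 1, worker 1 gets 2
    scale=[10, 20],                   # a second per-worker kwarg, now possible
    common_kwargs={"verbose": True},  # forwarded unchanged to every call
))  # -> [10, 40]
```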
42 changes: 23 additions & 19 deletions nemo_rl/models/generation/vllm.py
@@ -348,24 +348,26 @@ def _patch_vllm_init_workers_ray():
else:
self.llm = vllm.LLM(**llm_kwargs)

def init_collective(self, data: int, ip: str, port: int, world_size: int) -> None:
def init_collective(
self, rank_prefix: int, ip: str, port: int, world_size: int
) -> None:
self.llm.collective_rpc(
"init_collective",
args=(
data,
rank_prefix,
ip,
port,
world_size,
),
)

async def init_collective_async(
self, data: int, ip: str, port: int, world_size: int
self, rank_prefix: int, ip: str, port: int, world_size: int
) -> None:
await self.llm.collective_rpc(
"init_collective",
args=(
data,
rank_prefix,
ip,
port,
world_size,
@@ -903,11 +905,11 @@ async def report_device_id_async(self) -> list[str]:

return cast(list[str], list_of_worker_results)

def update_weights_from_ipc_handles(self, data: dict[str, Any]) -> bool:
def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
"""Update weights from IPC handles by delegating to the vLLM Worker implementation.

Args:
data (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles.
ipc_handles (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles.

Returns:
bool: True if weights were successfully updated, False otherwise.
@@ -923,7 +925,7 @@ def update_weights_from_ipc_handles(self, data: dict[str, Any]) -> bool:
)

result_or_coro = self.llm.collective_rpc(
"update_weights_from_ipc_handles", args=(data,)
"update_weights_from_ipc_handles", args=(ipc_handles,)
)
worker_result = result_or_coro[0]

@@ -940,11 +942,13 @@ def update_weights_from_ipc_handles(self, data: dict[str, Any]) -> bool:
traceback.print_exc()
return False

async def update_weights_from_ipc_handles_async(self, data: dict[str, Any]) -> bool:
async def update_weights_from_ipc_handles_async(
self, ipc_handles: dict[str, Any]
) -> bool:
"""Async version of update_weights_from_ipc_handles.

Args:
data (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles.
ipc_handles (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles.

Returns:
bool: True if weights were successfully updated, False otherwise.
@@ -960,7 +964,7 @@ async def update_weights_from_ipc_handles_async(self, data: dict[str, Any]) -> b
)

result_or_coro = await self.llm.collective_rpc(
"update_weights_from_ipc_handles", args=(data,)
"update_weights_from_ipc_handles", args=(ipc_handles,)
)

if asyncio.iscoroutine(result_or_coro):
@@ -983,7 +987,7 @@ async def update_weights_from_ipc_handles_async(self, data: dict[str, Any]) -> b
traceback.print_exc()
return False

def update_weights_from_collective(self, data: dict[str, Any]) -> bool:
def update_weights_from_collective(self, info: dict[str, Any]) -> bool:
"""Update the model weights from collective communication."""
try:
assert self.llm is not None, (
@@ -996,7 +1000,7 @@ def update_weights_from_collective(self, data: dict[str, Any]) -> bool:
)

result_or_coro = self.llm.collective_rpc(
"update_weights_from_collective", args=(data,)
"update_weights_from_collective", args=(info,)
)
worker_result = result_or_coro[0]

@@ -1013,7 +1017,7 @@ def update_weights_from_collective(self, data: dict[str, Any]) -> bool:
traceback.print_exc()
return False

async def update_weights_from_collective_async(self, data: dict[str, Any]) -> bool:
async def update_weights_from_collective_async(self, info: dict[str, Any]) -> bool:
"""Async version of update_weights_from_collective."""
try:
assert self.llm is not None, (
@@ -1026,7 +1030,7 @@ async def update_weights_from_collective_async(self, data: dict[str, Any]) -> bo
)

result_or_coro = await self.llm.collective_rpc(
"update_weights_from_collective", args=(data,)
"update_weights_from_collective", args=(info,)
)

if asyncio.iscoroutine(result_or_coro):
@@ -1403,7 +1407,7 @@ def init_collective(
# Send world_size and rank for init collective to all workers
futures = self.worker_group.run_all_workers_multiple_data(
method_name,
data=rank_prefix_list,
rank_prefix=rank_prefix_list,
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
common_kwargs={"ip": ip, "port": port, "world_size": world_size},
)
@@ -1429,7 +1433,7 @@ def generate(
)
future_bundle = self.worker_group.run_all_workers_sharded_data(
"generate",
sharded_data,
data=sharded_data,
in_sharded_axes=["data_parallel"],
replicate_on_axes=None, # just run on tp rank 0
output_is_replicated=None,
@@ -1474,7 +1478,7 @@ def generate_text(
)
future_bundle = self.worker_group.run_all_workers_sharded_data(
"generate_text",
sharded_data,
data=sharded_data,
in_sharded_axes=["data_parallel"],
replicate_on_axes=None, # just run on tp rank 0
output_is_replicated=None,
@@ -1708,7 +1712,7 @@ def update_weights(self, ipc_handles: dict[str, Any]) -> bool:
# Directly pass ipc_handles to the method
futures = self.worker_group.run_all_workers_multiple_data(
method_name,
ipc_handles_list,
ipc_handles=ipc_handles_list,
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
)
# Wait for all futures to complete
@@ -1735,7 +1739,7 @@ def update_weights_from_collective(
# Use run_all_workers_single_data to send data to all workers
futures = self.worker_group.run_all_workers_single_data(
method_name,
data=info,
info=info,
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
)

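In vllm.py, the generic `data` parameter on each worker method is renamed to what it actually carries (`rank_prefix`, `ipc_handles`, `info`), and the call sites pass those same names as keywords. That pairing matters because kwargs-based dispatch forwards keywords verbatim: the call-site keyword must match the worker method's parameter name. A small stand-in demonstrating the constraint — `FakeVllmWorker` is illustrative, not a class from this PR:

```python
# Stand-in worker mirroring the renamed signature from the diff; the dispatch
# path forwards keywords verbatim, so call-site names must match parameters.
class FakeVllmWorker:
    def init_collective(self, rank_prefix: int, ip: str, port: int, world_size: int):
        return f"rank_prefix={rank_prefix} joining {ip}:{port} (world_size={world_size})"


worker = FakeVllmWorker()

# New call shape, matching `rank_prefix=rank_prefix_list` in the diff:
print(worker.init_collective(rank_prefix=0, ip="127.0.0.1", port=6379, world_size=8))

# Old call shape (`data=...`) no longer fits the renamed parameter:
try:
    worker.init_collective(data=0, ip="127.0.0.1", port=6379, world_size=8)
except TypeError as err:
    print(err)  # ...got an unexpected keyword argument 'data'
```

The same constraint accounts for the other call-site renames in this file: `update_weights` now passes `ipc_handles=ipc_handles_list` and `update_weights_from_collective` passes `info=info`.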