From 3163a2539ff2d65a7ed030c2cbdcc733c22b2259 Mon Sep 17 00:00:00 2001
From: Yuki Huang
Date: Tue, 24 Jun 2025 06:36:04 +0000
Subject: [PATCH 1/5] support args/kwargs in run_all_workers_multiple_data and
 run_all_workers_sharded_data

Signed-off-by: Yuki Huang
---
 nemo_rl/distributed/worker_groups.py | 69 ++++++++++++++++++++++------
 1 file changed, 54 insertions(+), 15 deletions(-)

diff --git a/nemo_rl/distributed/worker_groups.py b/nemo_rl/distributed/worker_groups.py
index a283e6b18c..187e99783f 100644
--- a/nemo_rl/distributed/worker_groups.py
+++ b/nemo_rl/distributed/worker_groups.py
@@ -15,13 +15,12 @@
 import os
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import Any, Iterable, Optional, Union
+from typing import Any, Optional, Union
 
 import ray
 from ray.util.placement_group import PlacementGroup
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
-from nemo_rl.distributed.batched_data_dict import SlicedDataDict
 from nemo_rl.distributed.named_sharding import NamedSharding
 from nemo_rl.distributed.ray_actor_environment_registry import (
     get_actor_python_env,
@@ -590,25 +589,53 @@ def run_single_worker_single_data(
     def run_all_workers_multiple_data(
         self,
         method_name: str,
-        data: list[Any],
+        *args,
         run_rank_0_only_axes: list[str] | None = None,
         common_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs,
     ) -> list[ray.ObjectRef]:
         """Run a method on all workers in parallel with different data.
 
         Args:
             method_name: Name of the method to call on each worker
-            data: List of data to pass to workers/groups
+            *args, **kwargs: List of arguments/keyword arguments to pass to workers/groups
             run_rank_0_only_axes: List of named axes for which only rank 0 should run the method.
             common_kwargs: Additional keyword arguments to pass to all workers
 
         Returns:
             list[ray.ObjectRef]: A list of ray futures
         """
+        # Check at least one arg or kwarg is provided
+        assert len(args) > 0 or len(kwargs) > 0, (
+            "At least one args (positional arguments) or kwargs (keyword arguments) must be provided in run_all_workers_multiple_data. "
+            "Otherwise, please use run_all_workers_single_data."
+        )
+
+        # Check all args and kwargs have the same length
+        args_count = [len(arg) for arg in args]
+        assert all(count == args_count[0] for count in args_count), (
+            "All args must have the same length"
+        )
+        args_count = args_count[0] if len(args_count) > 0 else 0
+
+        kwargs_count = [len(value) for value in kwargs.values()]
+        assert all(count == kwargs_count[0] for count in kwargs_count), (
+            "All kwargs must have the same length"
+        )
+        kwargs_count = kwargs_count[0] if len(kwargs_count) > 0 else 0
+
+        if args_count > 0 and kwargs_count > 0:
+            assert args_count == kwargs_count, (
+                "The number of args and kwargs must be the same in run_all_workers_multiple_data. "
" + f"args length = {args_count}, kwargs length = {kwargs_count}" + ) + data_count = max(args_count, kwargs_count) + + # Check the data length is equal to the number of workers if run_rank_0_only_axes is None: - assert len(data) == len(self.workers), ( + assert data_count == len(self.workers), ( "data length should be equal to the number of workers: " - f"data length = {len(data)}, number of workers = {len(self.workers)}" + f"data length = {data_count}, number of workers = {len(self.workers)}" ) futures = [] @@ -633,12 +660,16 @@ def run_all_workers_multiple_data( if should_run: method = getattr(worker, method_name) - futures.append(method.remote(data=data[data_idx], **common_kwargs)) + worker_args = [arg[data_idx] for arg in args] + worker_kwargs = {key: value[data_idx] for key, value in kwargs.items()} + futures.append( + method.remote(*worker_args, **worker_kwargs, **common_kwargs) + ) data_idx += 1 - assert data_idx == len(data), ( + assert data_idx == data_count, ( "data length should be equal to the number of workers started: " - f"data length = {len(data)}, number of workers started = {data_idx}" + f"data length = {data_count}, number of workers started = {data_idx}" ) return futures @@ -686,12 +717,13 @@ def run_all_workers_single_data( def run_all_workers_sharded_data( self, method_name: str, - data: Iterable[SlicedDataDict], # arbitrary nested iterables of SlicedDataDicts + *args, in_sharded_axes: list[str] | None = None, replicate_on_axes: list[str] | None = None, output_is_replicated: list[str] | None = None, make_dummy_calls_to_free_axes: bool = False, common_kwargs: Optional[dict[str, Any]] = None, + **kwargs, ) -> MultiWorkerFuture: """Run a method on all workers in parallel with sharded data. @@ -701,7 +733,7 @@ def run_all_workers_sharded_data( Args: method_name: Name of the method to call on each worker - data: Iterable of SlicedDataDicts to pass to workers/groups + *args, **kwargs: List of arguments/keyword arguments to pass to workers/groups in_sharded_axes: List of axes that are sharded replicate_on_axes: List of axes that are to be replicated output_is_replicated: List of axes along which the output is replicated (and we should just return the first result). 
From 9dd9034e38f25e1bcbb80e0075e3d38165ba7d2c Mon Sep 17 00:00:00 2001
From: Yuki Huang
Date: Tue, 24 Jun 2025 06:49:52 +0000
Subject: [PATCH 2/5] add and revert args name

Signed-off-by: Yuki Huang
---
 nemo_rl/models/generation/vllm.py  | 38 +++++++++++++++++-------------
 nemo_rl/models/policy/lm_policy.py |  8 +++----
 2 files changed, 25 insertions(+), 21 deletions(-)

diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py
index f0cd5eb50b..b158f53e24 100644
--- a/nemo_rl/models/generation/vllm.py
+++ b/nemo_rl/models/generation/vllm.py
@@ -348,11 +348,13 @@ def _patch_vllm_init_workers_ray():
         else:
             self.llm = vllm.LLM(**llm_kwargs)
 
-    def init_collective(self, data: int, ip: str, port: int, world_size: int) -> None:
+    def init_collective(
+        self, rank_prefix: int, ip: str, port: int, world_size: int
+    ) -> None:
         self.llm.collective_rpc(
             "init_collective",
             args=(
-                data,
+                rank_prefix,
                 ip,
                 port,
                 world_size,
             ),
         )
 
     async def init_collective_async(
-        self, data: int, ip: str, port: int, world_size: int
+        self, rank_prefix: int, ip: str, port: int, world_size: int
     ) -> None:
         await self.llm.collective_rpc(
             "init_collective",
             args=(
-                data,
+                rank_prefix,
                 ip,
                 port,
                 world_size,
             ),
         )
@@ -903,11 +905,11 @@ async def report_device_id_async(self) -> list[str]:
 
         return cast(list[str], list_of_worker_results)
 
-    def update_weights_from_ipc_handles(self, data: dict[str, Any]) -> bool:
+    def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
         """Update weights from IPC handles by delegating to the vLLM Worker implementation.
 
         Args:
-            data (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles.
+            ipc_handles (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles.
 
         Returns:
             bool: True if weights were successfully updated, False otherwise.
@@ -923,7 +925,7 @@ def update_weights_from_ipc_handles(self, data: dict[str, Any]) -> bool:
             )
 
             result_or_coro = self.llm.collective_rpc(
-                "update_weights_from_ipc_handles", args=(data,)
+                "update_weights_from_ipc_handles", args=(ipc_handles,)
             )
 
             worker_result = result_or_coro[0]
@@ -940,11 +942,13 @@ def update_weights_from_ipc_handles(self, data: dict[str, Any]) -> bool:
             traceback.print_exc()
             return False
 
-    async def update_weights_from_ipc_handles_async(self, data: dict[str, Any]) -> bool:
+    async def update_weights_from_ipc_handles_async(
+        self, ipc_handles: dict[str, Any]
+    ) -> bool:
         """Async version of update_weights_from_ipc_handles.
 
         Args:
-            data (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles.
+            ipc_handles (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles.
 
         Returns:
             bool: True if weights were successfully updated, False otherwise.
@@ -960,7 +964,7 @@ async def update_weights_from_ipc_handles_async(self, data: dict[str, Any]) -> b
             )
 
             result_or_coro = await self.llm.collective_rpc(
-                "update_weights_from_ipc_handles", args=(data,)
+                "update_weights_from_ipc_handles", args=(ipc_handles,)
             )
 
             if asyncio.iscoroutine(result_or_coro):
@@ -983,7 +987,7 @@ async def update_weights_from_ipc_handles_async(self, data: dict[str, Any]) -> b
             traceback.print_exc()
             return False
 
-    def update_weights_from_collective(self, data: dict[str, Any]) -> bool:
+    def update_weights_from_collective(self, info: dict[str, Any]) -> bool:
         """Update the model weights from collective communication."""
         try:
             assert self.llm is not None, (
@@ -996,7 +1000,7 @@ def update_weights_from_collective(self, data: dict[str, Any]) -> bool:
             )
 
             result_or_coro = self.llm.collective_rpc(
-                "update_weights_from_collective", args=(data,)
+                "update_weights_from_collective", args=(info,)
             )
 
             worker_result = result_or_coro[0]
@@ -1403,7 +1407,7 @@ def init_collective(
         # Send world_size and rank for init collective to all workers
         futures = self.worker_group.run_all_workers_multiple_data(
             method_name,
-            data=rank_prefix_list,
+            rank_prefix=rank_prefix_list,
             run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
             common_kwargs={"ip": ip, "port": port, "world_size": world_size},
         )
@@ -1429,7 +1433,7 @@ def generate(
         )
         future_bundle = self.worker_group.run_all_workers_sharded_data(
             "generate",
-            sharded_data,
+            data=sharded_data,
             in_sharded_axes=["data_parallel"],
             replicate_on_axes=None,  # just run on tp rank 0
             output_is_replicated=None,
@@ -1474,7 +1478,7 @@ def generate_text(
         )
         future_bundle = self.worker_group.run_all_workers_sharded_data(
             "generate_text",
-            sharded_data,
+            data=sharded_data,
             in_sharded_axes=["data_parallel"],
             replicate_on_axes=None,  # just run on tp rank 0
             output_is_replicated=None,
@@ -1708,7 +1712,7 @@ def update_weights(self, ipc_handles: dict[str, Any]) -> bool:
         # Directly pass ipc_handles to the method
         futures = self.worker_group.run_all_workers_multiple_data(
             method_name,
-            ipc_handles_list,
+            ipc_handles=ipc_handles_list,
             run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
         )
         # Wait for all futures to complete
@@ -1735,7 +1739,7 @@ def update_weights_from_collective(
         # Use run_all_workers_single_data to send data to all workers
         futures = self.worker_group.run_all_workers_single_data(
             method_name,
-            data=info,
+            info=info,
             run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
         )
 
diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py
index 4d967a4cba..e469b32d16 100644
--- a/nemo_rl/models/policy/lm_policy.py
+++ b/nemo_rl/models/policy/lm_policy.py
@@ -207,7 +207,7 @@ def get_logprobs(
 
         futures = self.worker_group.run_all_workers_sharded_data(
             "get_logprobs",
-            sharded_data_2d,
+            data=sharded_data_2d,
             in_sharded_axes=["data_parallel", "context_parallel"],
             replicate_on_axes=["tensor_parallel", "pipeline_parallel"],
             output_is_replicated=["tensor_parallel", "pipeline_parallel"],
@@ -263,7 +263,7 @@ def get_reference_policy_logprobs(
 
         futures = self.worker_group.run_all_workers_sharded_data(
             "get_reference_policy_logprobs",
-            sharded_data_2d,
+            data=sharded_data_2d,
             in_sharded_axes=["data_parallel", "context_parallel"],
             replicate_on_axes=["tensor_parallel", "pipeline_parallel"],
             output_is_replicated=["tensor_parallel", "pipeline_parallel"],
@@ -313,7 +313,7 @@ def train(
         # Train each shard in parallel
         futures = self.worker_group.run_all_workers_sharded_data(
             "train",
-            sharded_data,
+            data=sharded_data,
             in_sharded_axes=["data_parallel"],
             replicate_on_axes=[
                 "context_parallel",
@@ -365,7 +365,7 @@ def generate(
         sharded_data = data.shard_by_batch_size(dp_size, batch_size=None)
         futures = self.worker_group.run_all_workers_sharded_data(
             "generate",
-            sharded_data,
+            data=sharded_data,
             in_sharded_axes=["data_parallel"],
             replicate_on_axes=["tensor_parallel", "pipeline_parallel"],
             output_is_replicated=["tensor_parallel", "pipeline_parallel"],
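After the renames in patch 2, a keyword passed to the worker-group helpers must match the worker method's parameter name, since it is forwarded through `**kwargs`. A hedged sketch of the resulting call shape — `StubGroup` and all values here are placeholders, not the real Ray-backed class:

    class StubGroup:
        # Placeholder with the same calling convention as the patched
        # run_all_workers_multiple_data; returns its inputs for inspection.
        def run_all_workers_multiple_data(
            self, method_name, run_rank_0_only_axes=None, common_kwargs=None, **kwargs
        ):
            return method_name, kwargs, common_kwargs

    worker_group = StubGroup()
    futures = worker_group.run_all_workers_multiple_data(
        "init_collective",
        rank_prefix=[0, 4],  # placeholder: one rank prefix per engine
        run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
        common_kwargs={"ip": "10.0.0.1", "port": 6379, "world_size": 8},
    )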
From 98defdeb84871a75451df9e0b483097241075544 Mon Sep 17 00:00:00 2001
From: Yuki Huang
Date: Tue, 24 Jun 2025 07:06:25 +0000
Subject: [PATCH 3/5] update comment

Signed-off-by: Yuki Huang
---
 nemo_rl/distributed/worker_groups.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/nemo_rl/distributed/worker_groups.py b/nemo_rl/distributed/worker_groups.py
index 187e99783f..c93b1e35da 100644
--- a/nemo_rl/distributed/worker_groups.py
+++ b/nemo_rl/distributed/worker_groups.py
@@ -598,9 +598,12 @@ def run_all_workers_multiple_data(
 
         Args:
             method_name: Name of the method to call on each worker
-            *args, **kwargs: List of arguments/keyword arguments to pass to workers/groups
+            *args: List of arguments to pass to workers/groups
+                e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]]
             run_rank_0_only_axes: List of named axes for which only rank 0 should run the method.
-            common_kwargs: Additional keyword arguments to pass to all workers
+            common_kwargs: Keyword arguments to pass to all workers
+            **kwargs: Keyword arguments to pass to workers/groups
+                e.g. {"key1": [value_for_worker_1, value_for_worker_2], "key2": [value_for_worker_1, value_for_worker_2]}
 
         Returns:
             list[ray.ObjectRef]: A list of ray futures
@@ -733,14 +736,18 @@ def run_all_workers_sharded_data(
 
         Args:
             method_name: Name of the method to call on each worker
-            *args, **kwargs: List of arguments/keyword arguments to pass to workers/groups
+            *args: List of arguments to pass to workers/groups
+                e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]]
             in_sharded_axes: List of axes that are sharded
             replicate_on_axes: List of axes that are to be replicated
             output_is_replicated: List of axes along which the output is replicated (and we should just return the first result).
                 We also just return from rank 0 of free axes.
             make_dummy_calls_to_free_axes: Whether to make dummy calls (with None) to workers that
                 aren't rank 0 on 'free axes' (axes not in in_sharded_axes or replicate_on_axes).
-            common_kwargs: Additional keyword arguments to pass to all workers
+            common_kwargs: Keyword arguments to pass to all workers
+            **kwargs: Keyword arguments to pass to workers/groups
+                e.g. {"key1": [value_for_worker_1, value_for_worker_2], "key2": [value_for_worker_1, value_for_worker_2]}
+
         Returns:
             MultiWorkerFuture: Object containing futures and their associated worker information
         """
@@ -802,9 +809,9 @@ def run_all_workers_sharded_data(
             return_from_workers.append(worker_idx)
 
             if should_receive_data:
+                # Find the appropriate data slice for this worker
                 worker_args = args
                 worker_kwargs = kwargs
-                # Find the appropriate data slice for this worker
                 for axis in in_sharded_axes:
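The updated docstrings pin down the convention: each entry of `*args` and each value in `**kwargs` is a list with one element per worker. A runnable sketch of that contract follows; `split_kwargs` is invented for illustration and is not part of the patch:

    def split_kwargs(per_worker_kwargs, num_workers):
        # {"key1": [v_w1, v_w2]} -> worker 1 gets key1=v_w1, worker 2 gets key1=v_w2
        return [
            {key: values[i] for key, values in per_worker_kwargs.items()}
            for i in range(num_workers)
        ]

    assert split_kwargs({"key1": [10, 20], "key2": ["a", "b"]}, 2) == [
        {"key1": 10, "key2": "a"},
        {"key1": 20, "key2": "b"},
    ]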
From 347d02c867418e705cf9bf14e5eeea6cc71c496a Mon Sep 17 00:00:00 2001
From: Yuki Huang
Date: Mon, 30 Jun 2025 08:31:29 +0000
Subject: [PATCH 4/5] add assert and fix rebase

Signed-off-by: Yuki Huang
---
 nemo_rl/distributed/worker_groups.py | 24 ++++++++++++++++++++++
 nemo_rl/models/generation/vllm.py    |  4 ++--
 2 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/nemo_rl/distributed/worker_groups.py b/nemo_rl/distributed/worker_groups.py
index c93b1e35da..c2e849cbee 100644
--- a/nemo_rl/distributed/worker_groups.py
+++ b/nemo_rl/distributed/worker_groups.py
@@ -582,6 +582,12 @@ def run_single_worker_single_data(
         Returns:
             ray.ObjectRef: A Ray future for the result.
         """
+        assert len(args) == 0, (
+            "run_single_worker_single_data will fail with args under certain circumstances. "
+            "Please use kwargs instead. "
+            "See https://github.com/NVIDIA-NeMo/RL/issues/582 for more details."
+        )
+
         worker = self.workers[worker_idx]
         method = getattr(worker, method_name)
         return method.remote(*args, **kwargs)
@@ -608,6 +614,12 @@ def run_all_workers_multiple_data(
         Returns:
             list[ray.ObjectRef]: A list of ray futures
         """
+        assert len(args) == 0, (
+            "run_all_workers_multiple_data will fail with args under certain circumstances. "
+            "Please use kwargs instead. "
+            "See https://github.com/NVIDIA-NeMo/RL/issues/582 for more details."
+        )
+
         # Check at least one arg or kwarg is provided
         assert len(args) > 0 or len(kwargs) > 0, (
             "At least one args (positional arguments) or kwargs (keyword arguments) must be provided in run_all_workers_multiple_data. "
             "Otherwise, please use run_all_workers_single_data."
         )
@@ -694,6 +706,12 @@ def run_all_workers_single_data(
         Returns:
             list[ray.ObjectRef]: A list of ray futures
         """
+        assert len(args) == 0, (
+            "run_all_workers_single_data will fail with args under certain circumstances. "
+            "Please use kwargs instead. "
+            "See https://github.com/NVIDIA-NeMo/RL/issues/582 for more details."
+        )
+
         futures = []
 
         if run_rank_0_only_axes is None:
@@ -751,6 +769,12 @@ def run_all_workers_sharded_data(
         Returns:
             MultiWorkerFuture: Object containing futures and their associated worker information
         """
+        assert len(args) == 0, (
+            "run_all_workers_sharded_data will fail with args under certain circumstances. "
+            "Please use kwargs instead. "
+            "See https://github.com/NVIDIA-NeMo/RL/issues/582 for more details."
+        )
+
         if self.sharding_annotations is None:
             raise ValueError(
                 "Sharding annotations must be provided to use sharded data distribution"
             )
 
diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py
index b158f53e24..7dbfbd3ea8 100644
--- a/nemo_rl/models/generation/vllm.py
+++ b/nemo_rl/models/generation/vllm.py
@@ -1017,7 +1017,7 @@ def update_weights_from_collective(self, info: dict[str, Any]) -> bool:
             traceback.print_exc()
             return False
 
-    async def update_weights_from_collective_async(self, data: dict[str, Any]) -> bool:
+    async def update_weights_from_collective_async(self, info: dict[str, Any]) -> bool:
         """Async version of update_weights_from_collective."""
         try:
             assert self.llm is not None, (
@@ -1030,7 +1030,7 @@ async def update_weights_from_collective_async(self, data: dict[str, Any]) -> bo
             )
 
             result_or_coro = await self.llm.collective_rpc(
-                "update_weights_from_collective", args=(data,)
+                "update_weights_from_collective", args=(info,)
             )
 
             if asyncio.iscoroutine(result_or_coro):
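The asserts added in patch 4 make the kwargs-only contract fail fast instead of failing later inside Ray. Below is a minimal standalone reproduction of the guard, assuming a plain function rather than the actual RayWorkerGroup method:

    def run_all_workers_single_data(method_name, *args, **kwargs):
        # Same guard as the patch: positional payloads are rejected up front.
        assert len(args) == 0, (
            "run_all_workers_single_data will fail with args under certain circumstances. "
            "Please use kwargs instead. "
            "See https://github.com/NVIDIA-NeMo/RL/issues/582 for more details."
        )
        return method_name, kwargs

    try:
        run_all_workers_single_data("record_call", {"key": "value"})  # positional: rejected
    except AssertionError as exc:
        print(f"rejected: {exc}")

    run_all_workers_single_data("record_call", data={"key": "value"})  # keyword: accepted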
From d5b6b5283df7dbe3262f644d05c2b3da691e3164 Mon Sep 17 00:00:00 2001
From: Yuki Huang
Date: Mon, 30 Jun 2025 09:38:34 +0000
Subject: [PATCH 5/5] fix unit test, add run_single_worker_single_data test

Signed-off-by: Yuki Huang
---
 tests/unit/distributed/test_worker_groups.py | 92 +++++++++++++++++---
 1 file changed, 81 insertions(+), 11 deletions(-)

diff --git a/tests/unit/distributed/test_worker_groups.py b/tests/unit/distributed/test_worker_groups.py
index 53b6133c69..12131fe4a4 100644
--- a/tests/unit/distributed/test_worker_groups.py
+++ b/tests/unit/distributed/test_worker_groups.py
@@ -328,6 +328,48 @@ def test_configure_worker_interaction(register_test_actor, virtual_cluster):
     worker_group.shutdown(force=True)
 
 
+def test_run_single_worker_single_data(worker_group_1d_sharding):
+    worker_group = worker_group_1d_sharding
+    assert len(worker_group.workers) == 2
+    ray.get([w.reset_call_records.remote() for w in worker_group.workers])
+
+    data_for_worker0 = SlicedDataDict({"id": 0, "val": "w0_val"})
+    data_for_worker1 = SlicedDataDict({"id": 1, "val": "w1_val"})
+
+    # pass through args
+    # due to https://github.com/NVIDIA-NeMo/RL/issues/582, args are not supported.
+    with pytest.raises(AssertionError):
+        future_0 = worker_group.run_single_worker_single_data(
+            "record_call", 0, data_for_worker0
+        )
+        future_1 = worker_group.run_single_worker_single_data(
+            "record_call", 1, data_for_worker1
+        )
+        ray.get([future_0, future_1])
+
+    # pass through kwargs
+    future_0 = worker_group.run_single_worker_single_data(
+        "record_call", 0, data=data_for_worker0
+    )
+    future_1 = worker_group.run_single_worker_single_data(
+        "record_call", 1, data=data_for_worker1
+    )
+    results = ray.get([future_0, future_1])
+    assert len(results) == 2
+
+    # Check worker 0
+    d, args, _, count = ray.get(worker_group.workers[0].get_recorded_data.remote())
+    assert count == 1
+    assert d == data_for_worker0
+    assert args == ()
+
+    # Check worker 1
+    d, args, _, count = ray.get(worker_group.workers[1].get_recorded_data.remote())
+    assert count == 1
+    assert d == data_for_worker1
+    assert args == ()
+
+
 def test_run_all_workers_single_data_1d_sharding(worker_group_1d_sharding):
     worker_group = worker_group_1d_sharding
     assert len(worker_group.workers) == 2
@@ -339,17 +381,26 @@ def test_run_all_workers_single_data_1d_sharding(worker_group_1d_sharding):
     test_arg1 = "arg_single"
     test_kwarg1 = "kwarg_single_val"
 
+    # pass through args
+    # due to https://github.com/NVIDIA-NeMo/RL/issues/582, args are not supported.
+    with pytest.raises(AssertionError):
+        futures = worker_group.run_all_workers_single_data(
+            "record_call", test_data, test_arg1
+        )
+        ray.get(futures)
+
+    # pass through kwargs
     futures = worker_group.run_all_workers_single_data(
-        "record_call", test_data, test_arg1, kwarg1=test_kwarg1
+        "record_call", data=test_data, kwarg1=test_kwarg1
     )
     results = ray.get(futures)
     assert len(results) == 2  # Should run on all 2 workers
 
-    for i, worker in enumerate(worker_group.workers):
+    for worker in worker_group.workers:
         data, args, kwargs, count = ray.get(worker.get_recorded_data.remote())
         assert count == 1
         assert data == test_data
-        assert args == (test_arg1,)
+        assert args == ()
         assert kwargs == {"kwarg1": test_kwarg1}
 
 
@@ -359,7 +410,7 @@ def test_run_all_workers_single_data_2d_sharding_no_filter(worker_group_2d_shard
     ray.get([w.reset_call_records.remote() for w in worker_group.workers])
 
     test_data = SlicedDataDict({"key": "value_2d_no_filter"})
-    futures = worker_group.run_all_workers_single_data("record_call", test_data)
+    futures = worker_group.run_all_workers_single_data("record_call", data=test_data)
     results = ray.get(futures)
     assert len(results) == 4  # Runs on all 4 workers
 
@@ -377,7 +428,7 @@ def test_run_all_workers_single_data_2d_sharding_filter_tp(worker_group_2d_shard
     test_data = SlicedDataDict({"key": "value_2d_filter_tp"})
     # Only run on tp rank 0 for each dp rank
     futures = worker_group.run_all_workers_single_data(
-        "record_call", test_data, run_rank_0_only_axes=["tp"]
+        "record_call", data=test_data, run_rank_0_only_axes=["tp"]
     )
     results = ray.get(futures)
     assert len(results) == 2  # Runs on 2 workers (dp0-tp0, dp1-tp0)
@@ -403,7 +454,7 @@ def test_run_all_workers_single_data_2d_sharding_filter_dp_tp(worker_group_2d_sh
     test_data = SlicedDataDict({"key": "value_2d_filter_dp_tp"})
     # Only run on dp rank 0 AND tp rank 0
     futures = worker_group.run_all_workers_single_data(
-        "record_call", test_data, run_rank_0_only_axes=["dp", "tp"]
+        "record_call", data=test_data, run_rank_0_only_axes=["dp", "tp"]
     )
     results = ray.get(futures)
     assert len(results) == 1  # Runs on 1 worker (dp0-tp0)
@@ -430,8 +481,17 @@ def test_run_all_workers_multiple_data_1d_sharding(worker_group_1d_sharding):
     multi_data = [data_for_worker0, data_for_worker1]
     common_arg = "common_arg_multi"
 
+    # pass through args
+    # due to https://github.com/NVIDIA-NeMo/RL/issues/582, args are not supported.
+    with pytest.raises(AssertionError):
+        futures = worker_group.run_all_workers_multiple_data(
+            "record_call", multi_data, common_kwargs={"common": common_arg}
+        )
+        ray.get(futures)
+
+    # pass through kwargs
     futures = worker_group.run_all_workers_multiple_data(
-        "record_call", multi_data, common_kwargs={"common": common_arg}
+        "record_call", data=multi_data, common_kwargs={"common": common_arg}
     )
     results = ray.get(futures)
     assert len(results) == 2
 
@@ -462,10 +522,11 @@ def test_run_all_workers_multiple_data_fewer_data_than_workers(
     data_for_worker1 = SlicedDataDict({"id": 1})
     multi_data = [data_for_worker0, data_for_worker1]  # Only 2 data items
 
-    with pytest.raises(
-        AssertionError, match="data length should be equal to the number of workers: "
-    ):
-        futures = worker_group.run_all_workers_multiple_data("record_call", multi_data)
+    with pytest.raises(AssertionError):
+        futures = worker_group.run_all_workers_multiple_data(
+            "record_call", data=multi_data
+        )
+        ray.get(futures)
 
 
 def test_run_all_workers_sharded_data_1d(worker_group_1d_sharding):
@@ -479,6 +540,15 @@ def test_run_all_workers_sharded_data_1d(worker_group_1d_sharding):
         SlicedDataDict({"shard": 1, "val": "val1"}),
     ]
 
+    # pass through args
+    # due to https://github.com/NVIDIA-NeMo/RL/issues/582, args are not supported.
+    with pytest.raises(AssertionError):
+        future_bundle = worker_group.run_all_workers_sharded_data(
+            "record_call", sharded_data_input, in_sharded_axes=["data"]
+        )
+        worker_group.get_all_worker_results(future_bundle)
+
+    # pass through kwargs
     future_bundle = worker_group.run_all_workers_sharded_data(
         "record_call", data=sharded_data_input, in_sharded_axes=["data"]
     )
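Taken together, the series moves every caller from positional data to keyword data. The self-contained pytest sketch below mirrors the test pattern above against a stub so it runs without Ray; `StubGroup` and its simplified length check are illustrative, not the real fixture:

    import pytest

    class StubGroup:
        def run_all_workers_multiple_data(
            self, method_name, *args, common_kwargs=None, **kwargs
        ):
            # Same contract as the patched method: kwargs only, equal lengths.
            assert len(args) == 0, "args are not supported; use kwargs instead"
            lengths = {len(value) for value in kwargs.values()}
            assert len(lengths) == 1, "All kwargs must have the same length"
            return [
                (method_name, {k: v[i] for k, v in kwargs.items()}, common_kwargs or {})
                for i in range(lengths.pop())
            ]

    def test_kwargs_only_contract():
        group = StubGroup()
        with pytest.raises(AssertionError):
            group.run_all_workers_multiple_data("record_call", [{"id": 0}, {"id": 1}])
        calls = group.run_all_workers_multiple_data(
            "record_call", data=[{"id": 0}, {"id": 1}]
        )
        assert calls[0][1] == {"data": {"id": 0}}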