From 43fd588263de4815e83ad805d8f6b6a9228db941 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 9 Jun 2025 22:47:12 -0700 Subject: [PATCH 1/8] support non colocated in unit test with fake rank Signed-off-by: Yuki Huang --- nemo_rl/algorithms/grpo.py | 42 +++++++++--- nemo_rl/models/generation/interfaces.py | 10 +++ nemo_rl/models/generation/vllm.py | 66 ++++++++++++++++++- nemo_rl/models/generation/vllm_backend.py | 24 +++++++ .../models/policy/dtensor_policy_worker.py | 35 ++++++++++ nemo_rl/models/policy/hf_policy.py | 23 +++++++ nemo_rl/models/policy/interfaces.py | 12 ++++ .../models/generation/test_vllm_generation.py | 58 ++++++---------- 8 files changed, 220 insertions(+), 50 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index a8428bdce2..4b76ba0bee 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -212,7 +212,9 @@ def setup( # Cluster # ========================== print("\n▶ Setting up compute cluster...") - colocated_inference = generation_config["backend"] != "hf" + colocated_inference = generation_config["colocated"] + if generation_config["backend"] == "hf": + colocated_inference = False cluster = RayVirtualCluster( name="grpo_policy_cluster", bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] @@ -258,6 +260,11 @@ def setup( init_optimizer=True, ) + if generation_config["colocated"]: + world_size = cluster_config["num_nodes"] * cluster_config["gpus_per_node"] + policy.init_collective(world_size) + policy_generation.init_collective(world_size) + loss_fn = ClippedPGLossFn(loss_config) print("\n" + "=" * 60) @@ -286,6 +293,7 @@ def setup( def refit_policy_generation( policy: ColocatablePolicyInterface, policy_generation: GenerationInterface, + is_colocated: bool = True, _refit_buffer_size_gb: Optional[int] = None, ) -> None: """Refit the policy generation interface with the latest policy weights. @@ -299,20 +307,34 @@ def refit_policy_generation( """ policy.offload_before_refit() policy_generation.prepare_for_generation(tags=["weights"]) - # get model param keys, which is grouped by size - grouped_param_keys = policy.prepare_weights_for_ipc( - _refit_buffer_size_gb=_refit_buffer_size_gb - ) - # do update - for keys in grouped_param_keys: - ipc_handles = policy.get_weights_ipc_handles(keys) - if not policy_generation.update_weights(ipc_handles): + + # update weights + if is_colocated: + # get model param keys, which is grouped by size + grouped_param_keys = policy.prepare_weights_for_ipc( + _refit_buffer_size_gb=_refit_buffer_size_gb + ) + # do update + for keys in grouped_param_keys: + ipc_handles = policy.get_weights_ipc_handles(keys) + if not policy_generation.update_weights(ipc_handles): + error_message = ( + "❌ Error: Updating weights for the generation policy failed during refit.\n" + "This often indicates an issue with cuda-ipc or " + "a problem within the generation backend (e.g., vLLM worker).\n" + ) + raise RuntimeError(error_message) + else: + state_dict_info = policy.prepare_info_for_collective() + policy.broadcast_weights_for_collective() + if not policy_generation.update_weights_from_collective(state_dict_info): error_message = ( "❌ Error: Updating weights for the generation policy failed during refit.\n" - "This often indicates an issue with cuda-ipc or " + "This often indicates an issue with nccl or " "a problem within the generation backend (e.g., vLLM worker).\n" ) raise RuntimeError(error_message) + policy.offload_after_refit() policy_generation.prepare_for_generation(tags=["kv_cache"]) diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index 8e7a0c1a7e..9e4fdbca52 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -102,6 +102,7 @@ class GenerationConfig(TypedDict): """Configuration for generation.""" backend: str + colocated: bool max_new_tokens: int temperature: float top_p: float @@ -196,6 +197,11 @@ class GenerationOutputSpec(TypedDict): class GenerationInterface(ABC): """Abstract base class defining the interface for RL policies.""" + @abstractmethod + def init_collective(self, world_size: int) -> None: + """Initialize the collective communication.""" + pass + @abstractmethod def generate( self, data: BatchedDataDict["GenerationDatumSpec"], greedy: bool @@ -213,3 +219,7 @@ def finish_generation(self, *args: Any, **kwargs: Any) -> bool: def update_weights(self, ipc_handles: dict[str, Any]) -> bool: """Update the model weights from the given IPC handles.""" raise NotImplementedError + + def update_weights_from_collective(self, info: dict[str, Any]) -> bool: + """Update the model weights from collective communication.""" + raise NotImplementedError diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index 0287fb8b63..658741a4de 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -342,6 +342,9 @@ def _patch_vllm_init_workers_ray(): else: self.llm = vllm.LLM(**llm_kwargs) + def init_collective(self, world_size: int) -> None: + self.llm.collective_rpc("init_collective", args=(world_size,)) + def llm(self): return self.llm @@ -925,6 +928,34 @@ async def update_weights_from_ipc_handles_async( traceback.print_exc() return False + def update_weights_from_collective(self, info: dict[str, Any]) -> bool: + """Update the model weights from collective communication.""" + try: + assert self.llm is not None, ( + "Attempting to update weights with either an uninitialized vLLM or non-model-owner" + ) + + if self.cfg["vllm_cfg"]["async_engine"]: + raise RuntimeError( + "update_weights_from_collective cannot be used with async_engine=True. Use update_weights_from_ipc_handles_async instead." + ) + + result_or_coro = self.llm.collective_rpc("update_weights_from_collective", args=(info,)) + worker_result = result_or_coro[0] + + if not worker_result: + print( + f"Error: Worker failed to update weights. Result: {worker_result}" + ) + return False + return True + except Exception as e: + print(f"Exception during collective_rpc for weight update: {e}") + import traceback + + traceback.print_exc() + return False + def sleep(self): """Put the vLLM engine to sleep.""" assert self.llm is not None, ( @@ -1206,6 +1237,23 @@ def _report_device_id(self) -> list[list[str]]: results = ray.get(futures) return results + def init_collective(self, world_size: int) -> None: + """Initialize the collective communication.""" + try: + # Use run_all_workers_single_data to send data to all workers + print("[init_collective] in vllm") + futures = self.worker_group.run_all_workers_single_data( + "init_collective", + world_size=world_size, + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"] + ) + # Wait for all futures to complete + results = ray.get(futures) + return all(result for result in results if result is not None) + except Exception as e: + print(f"Error during init collective: {e}") + return False + def generate( self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False ) -> BatchedDataDict[GenerationOutputSpec]: @@ -1437,7 +1485,23 @@ def update_weights(self, ipc_handles: dict[str, Any]) -> bool: results = ray.get(futures) return all(result for result in results if result is not None) except Exception as e: - print(f"Error updating weights: {e}") + print(f"Error during update weights: {e}") + return False + + def update_weights_from_collective(self, info: dict[str, Any]) -> bool: + """Update weights of the policy using collective communication.""" + try: + # Use run_all_workers_single_data to send data to all workers + futures = self.worker_group.run_all_workers_single_data( + "update_weights_from_collective", + info=info, + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"] + ) + # Wait for all futures to complete + results = ray.get(futures) + return all(result for result in results if result is not None) + except Exception as e: + print(f"Error during update weights from collective: {e}") return False def __del__(self) -> None: diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index 5305fd6bce..455c454469 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch +from typing import Any try: import vllm # noqa: F401 @@ -24,6 +25,12 @@ class VllmInternalWorkerExtension: + def init_collective(self, world_size: int) -> None: + """Initialize the collective communication.""" + import ray.util.collective as collective + + collective.init_collective_group(world_size=world_size, rank=1, backend="nccl", group_name="refit") + def report_device_id(self) -> str: from nemo_rl.utils.nvml import get_device_uuid @@ -63,3 +70,20 @@ def update_weights_from_ipc_handles(self, ipc_handles): f"Error in VllmInternalWorkerExtension.update_weights_from_ipc_handles: {e}" ) return False + + def update_weights_from_collective(self, info: dict[str, Any]) -> None: + """Update the model weights from collective communication.""" + import ray.util.collective as collective + + try: + for name, (shape, dtype) in info.items(): + weight = torch.empty(shape, dtype=dtype, device="cuda") + collective.broadcast(weight, 0, group_name="refit") + self.model_runner.model.load_weights(weights=[(name, weight)]) + except Exception as e: + print( + f"Error in VllmInternalWorkerExtension.update_weights_from_collective: {e}" + ) + return False + + return True diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index d13fc23c67..e263dfbf62 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -262,6 +262,12 @@ def __init__( "No weights path provided. Starting from scratch (default policy init)" ) + def init_collective(self, world_size: int) -> None: + """Initialize the collective communication.""" + import ray.util.collective as collective + + collective.init_collective_group(world_size=world_size, rank=0, backend="nccl", group_name="refit") + def is_alive(self) -> bool: return True @@ -753,6 +759,35 @@ def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]: return {device_uuid: all_handles} + @torch.no_grad() + def prepare_info_for_collective(self) -> dict[str, Any]: + """Prepare the info for collective communication. + + Returns: + dict: A dictionary containing the info for collective communication. + """ + # Get state_dict + self.model = self.move_to_cuda(self.model) + state_dict = self.model.state_dict() + + # Collect info for collective communication + state_dict_info = {} + for name, tensor in state_dict.items(): + state_dict_info[name] = (tensor.shape, self.dtype) + + return state_dict_info + + @torch.no_grad() + def broadcast_weights_for_collective(self) -> None: + """Broadcast the weights for collective communication.""" + import ray.util.collective as collective + + for _, tensor in self.model.state_dict().items(): + if isinstance(tensor, DTensor): + tensor = tensor.full_tensor() + tensor = tensor.to(self.dtype, non_blocking=True) + collective.broadcast(tensor.data, 0, group_name="refit") + def prepare_for_lp_inference(self) -> None: if not self.cpu_offload: self.move_to_cuda(self.model) diff --git a/nemo_rl/models/policy/hf_policy.py b/nemo_rl/models/policy/hf_policy.py index 8cfe0a6e0c..0af89d61f8 100644 --- a/nemo_rl/models/policy/hf_policy.py +++ b/nemo_rl/models/policy/hf_policy.py @@ -119,6 +119,10 @@ def __init__( self.cfg = config + def init_collective(self, world_size: int) -> None: + """Initialize the collective communication.""" + self.worker_group.run_all_workers_single_data("init_collective", world_size=world_size) + def get_logprobs( self, data: BatchedDataDict[GenerationDatumSpec] ) -> BatchedDataDict[LogprobOutputSpec]: @@ -402,6 +406,25 @@ def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: return all_handles + def prepare_info_for_collective(self) -> dict[str, Any]: + """Prepare the info for collective communication. + + Returns: + dict: A dictionary containing the info for collective communication. + """ + futures = self.worker_group.run_all_workers_single_data( + "prepare_info_for_collective" + ) + results = ray.get(futures) + # Only get the first worker's info since all workers will have the same result + return results[0] + + def broadcast_weights_for_collective(self) -> None: + """Broadcast the weights for collective communication.""" + self.worker_group.run_all_workers_single_data( + "broadcast_weights_for_collective" + ) + def offload_before_refit(self) -> None: """Offload the optimizer and buffers to the CPU.""" futures = self.worker_group.run_all_workers_single_data("offload_before_refit") diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index d4b51b0cb3..ef3747ce6f 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -93,6 +93,10 @@ def shutdown(self) -> bool: class ColocatablePolicyInterface(PolicyInterface): + @abstractmethod + def init_collective(self, world_size: int) -> None: + pass + @abstractmethod def offload_before_refit(self) -> None: pass @@ -108,3 +112,11 @@ def prepare_weights_for_ipc(self, *args: Any, **kwargs: Any) -> list[list[str]]: @abstractmethod def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: pass + + @abstractmethod + def prepare_info_for_collective(self) -> dict[str, Any]: + pass + + @abstractmethod + def broadcast_weights_for_collective(self) -> None: + pass diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index b8225e62ee..717b44ff0a 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -33,6 +33,7 @@ # Define basic vLLM test config basic_vllm_test_config: VllmConfig = { "backend": "vllm", + "colocated": True, "model_name": model_name, "tokenizer": { "name": model_name, @@ -1040,7 +1041,7 @@ def test_vllm_non_divisible_batch_handling(policy): ) -def test_vllm_refit_non_collocated_handles_update_failure( +def test_vllm_refit_non_collocated_handles_update( policy_cluster_separate, generation_cluster_separate, tokenizer, @@ -1055,6 +1056,7 @@ def test_vllm_refit_non_collocated_handles_update_failure( ) # Create HfPolicy on its own cluster + os.environ["NCCL_CUMEM_ENABLE"] = "0" hf_config = get_basic_hf_test_config(enable_dtensor=True) hf_config["dtensor_cfg"]["tensor_parallel_size"] = 1 hf_policy = HfPolicy(policy_cluster_separate, hf_config, tokenizer) @@ -1063,45 +1065,23 @@ def test_vllm_refit_non_collocated_handles_update_failure( vllm_config = deepcopy(basic_vllm_test_config) vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True) vllm_config["vllm_cfg"]["tensor_parallel_size"] = 1 - vllm_policy = VllmGeneration(generation_cluster_separate, vllm_config) + vllm_generation = VllmGeneration(generation_cluster_separate, vllm_config) - hf_policy_instance = None - vllm_policy_instance = None + hf_policy.init_collective(2) + vllm_generation.init_collective(2) - try: - hf_policy_instance = hf_policy - vllm_policy_instance = vllm_policy - ray.get( - [ - worker._add_noise_to_weights.remote() - for worker in hf_policy_instance.worker_group.workers - ] - ) - print("Refitting vLLM policy from HF policy (non-collocated)") - with mock.patch.object( - vllm_policy_instance, "update_weights", return_value=False - ): - with pytest.raises(RuntimeError): - refit_policy_generation( - hf_policy_instance, - vllm_policy_instance, - ) - print("RuntimeError during refit correctly caught.") + print("refitting vllm policy...") + refit_policy_generation(hf_policy, vllm_generation, is_colocated=False) - finally: - print("Cleaning up non-collocated test resources...") - if hf_policy_instance: - try: - hf_policy_instance.shutdown() - except Exception as e: - print(f"Error during HfPolicy cleanup: {e}") - if vllm_policy_instance: - try: - vllm_policy_instance.shutdown() - except Exception as e: - print(f"Error during VllmPolicy cleanup: {e}") - # Force garbage collection - import gc + # test generate + outputs = vllm_generation.generate(test_input_data, greedy=True) + output_ids = outputs["output_ids"] + generated_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + assert generated_texts == [ + "Hello, my name is Lina. I'm", + "The capital of France is Paris. The capital of", + ], "Output should be the same as the expected output" - gc.collect() - torch.cuda.empty_cache() + # Clean up + vllm_generation.shutdown() + hf_policy.shutdown() From 10f5b3e546f32847100c33ab6535a3d657c2fc92 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 9 Jun 2025 22:48:37 -0700 Subject: [PATCH 2/8] correct rank and add inference gpu settings Signed-off-by: Yuki Huang --- examples/configs/grpo_math_1B.yaml | 6 + nemo_rl/algorithms/grpo.py | 157 +++++++++++++----- nemo_rl/models/generation/interfaces.py | 8 +- nemo_rl/models/generation/vllm.py | 74 +++++---- nemo_rl/models/generation/vllm_backend.py | 14 +- .../models/policy/dtensor_policy_worker.py | 17 +- nemo_rl/models/policy/hf_policy.py | 14 +- nemo_rl/models/policy/interfaces.py | 5 +- .../models/generation/test_vllm_generation.py | 34 +++- 9 files changed, 229 insertions(+), 100 deletions(-) diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 2f9b6535e4..6200d223c4 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -106,6 +106,12 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: ${policy.max_total_sequence_length} + colocated: + enabled: true + # decides how many GPUs to use for inference + # used when train and inference is not colocated + gpus_per_node: -1 # used when cluster.num_nodes is 1 + num_nodes: -1 # used when cluster.num_nodes > 1 data: max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 4b76ba0bee..a7a97110e9 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -13,9 +13,10 @@ # limitations under the License. import os from pathlib import Path -from typing import Any, Optional, TypedDict, cast +from typing import Any, Optional, Tuple, TypedDict, cast import numpy as np +import ray import torch from torchdata.stateful_dataloader import StatefulDataLoader from transformers import PreTrainedTokenizerBase @@ -119,7 +120,7 @@ def setup( ) -> tuple[ ColocatablePolicyInterface, Optional[GenerationInterface], - RayVirtualCluster, + Tuple[RayVirtualCluster, RayVirtualCluster], StatefulDataLoader, Optional[StatefulDataLoader], ClippedPGLossFn, @@ -137,7 +138,6 @@ def setup( policy_config = master_config["policy"] generation_config = master_config["policy"]["generation"] loss_config = master_config["loss_fn"] - data_config = master_config["data"] grpo_config = master_config["grpo"] logger_config = master_config["logger"] cluster_config = master_config["cluster"] @@ -212,18 +212,74 @@ def setup( # Cluster # ========================== print("\n▶ Setting up compute cluster...") - colocated_inference = generation_config["colocated"] - if generation_config["backend"] == "hf": - colocated_inference = False - cluster = RayVirtualCluster( - name="grpo_policy_cluster", - bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] - * cluster_config["num_nodes"], - use_gpus=True, - num_gpus_per_node=cluster_config["gpus_per_node"], - max_colocated_worker_groups=2 if colocated_inference else 1, - ) - print(f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes") + colocated_inference = generation_config["colocated"]["enabled"] + + if colocated_inference: + cluster = RayVirtualCluster( + name="grpo_policy_cluster", + bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] + * cluster_config["num_nodes"], + use_gpus=True, + num_gpus_per_node=cluster_config["gpus_per_node"], + max_colocated_worker_groups=1 + if generation_config["backend"] == "hf" + else 2, + ) + train_cluster = cluster + inference_cluster = cluster + print(f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes") + + else: + assert generation_config["backend"] != "hf", ( + "Non-colocated inference is not supported for HF backend" + ) + + train_gpus_per_node = cluster_config["gpus_per_node"] + train_nodes = cluster_config["num_nodes"] + inference_gpus_per_node = generation_config["colocated"]["gpus_per_node"] + inference_nodes = generation_config["colocated"]["num_nodes"] + + # validate and configure generation.colocated config + if cluster_config["num_nodes"] == 1: + assert inference_gpus_per_node > 0 and inference_nodes == -1, ( + "policy.generation.colocated.gpus_per_node must be set " + "and policy.generation.colocated.num_nodes must be -1 " + "when cluster.num_nodes=1 and inference is not colocated" + ) + inference_nodes = 1 + train_gpus_per_node -= inference_gpus_per_node + else: + assert inference_gpus_per_node == -1 and inference_nodes > 0, ( + "policy.generation.colocated.gpus_per_node must be -1 and " + "policy.generation.colocated.num_nodes must be set " + "when cluster.num_nodes > 1 and inference is not colocated" + ) + inference_gpus_per_node = cluster_config["gpus_per_node"] + train_nodes -= inference_nodes + + # initialize train cluster + train_cluster = RayVirtualCluster( + name="grpo_train_cluster", + bundle_ct_per_node_list=[train_gpus_per_node] * train_nodes, + use_gpus=True, + num_gpus_per_node=train_gpus_per_node, + max_colocated_worker_groups=1, + ) + print( + f" ✓ Ray train cluster initialized with {train_nodes} nodes with {train_gpus_per_node} GPUs per node" + ) + + # initialize inference cluster + inference_cluster = RayVirtualCluster( + name="grpo_inference_cluster", + bundle_ct_per_node_list=[inference_gpus_per_node] * inference_nodes, + use_gpus=True, + num_gpus_per_node=inference_gpus_per_node, + max_colocated_worker_groups=1, + ) + print( + f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node" + ) # ========================== # Training and Inference @@ -239,7 +295,9 @@ def setup( print(f" ✓ Using HF backend for generation with {policy_config['model_name']}") elif backend == "vllm": generation_config = cast(VllmConfig, generation_config) - policy_generation = VllmGeneration(cluster=cluster, config=generation_config) + policy_generation = VllmGeneration( + cluster=inference_cluster, config=generation_config + ) # Worker groups are not initialized until the first call to run something on workergroups. # vllm 0.8 fails in initialization if its called in the first training step since it has no clean view of the GPU memory (HF is sharing the same memory). policy_generation.finish_generation() @@ -248,7 +306,7 @@ def setup( ) policy = HfPolicy( - cluster=cluster, + cluster=train_cluster, config=policy_config, tokenizer=tokenizer, weights_path=Path(last_checkpoint_path) / "policy" / "weights" @@ -260,10 +318,14 @@ def setup( init_optimizer=True, ) - if generation_config["colocated"]: - world_size = cluster_config["num_nodes"] * cluster_config["gpus_per_node"] - policy.init_collective(world_size) - policy_generation.init_collective(world_size) + # if it is not colocated inference, initialize collective communication for update weights + if not colocated_inference: + world_size = ( + inference_nodes * inference_gpus_per_node + 1 + ) # inference cluster + head node of the train cluster + futures_train = policy.init_collective(world_size) + futures_inference = policy_generation.init_collective(world_size) # type: ignore + ray.get(futures_train + futures_inference) loss_fn = ClippedPGLossFn(loss_config) @@ -274,7 +336,7 @@ def setup( return ( policy, policy_generation, - cluster, + (train_cluster, inference_cluster), dataloader, val_dataloader, loss_fn, @@ -293,7 +355,7 @@ def setup( def refit_policy_generation( policy: ColocatablePolicyInterface, policy_generation: GenerationInterface, - is_colocated: bool = True, + colocated_inference: bool, _refit_buffer_size_gb: Optional[int] = None, ) -> None: """Refit the policy generation interface with the latest policy weights. @@ -309,7 +371,8 @@ def refit_policy_generation( policy_generation.prepare_for_generation(tags=["weights"]) # update weights - if is_colocated: + update_success = False + if colocated_inference: # get model param keys, which is grouped by size grouped_param_keys = policy.prepare_weights_for_ipc( _refit_buffer_size_gb=_refit_buffer_size_gb @@ -317,23 +380,28 @@ def refit_policy_generation( # do update for keys in grouped_param_keys: ipc_handles = policy.get_weights_ipc_handles(keys) - if not policy_generation.update_weights(ipc_handles): - error_message = ( - "❌ Error: Updating weights for the generation policy failed during refit.\n" - "This often indicates an issue with cuda-ipc or " - "a problem within the generation backend (e.g., vLLM worker).\n" - ) - raise RuntimeError(error_message) + update_success = policy_generation.update_weights(ipc_handles) + if not update_success: + break else: state_dict_info = policy.prepare_info_for_collective() - policy.broadcast_weights_for_collective() - if not policy_generation.update_weights_from_collective(state_dict_info): - error_message = ( - "❌ Error: Updating weights for the generation policy failed during refit.\n" - "This often indicates an issue with nccl or " - "a problem within the generation backend (e.g., vLLM worker).\n" - ) - raise RuntimeError(error_message) + futures_train = policy.broadcast_weights_for_collective() + futures_inference = policy_generation.update_weights_from_collective( + state_dict_info + ) + ray.get(futures_train) + results = ray.get(futures_inference) + update_success = all(result for result in results if result is not None) + + # check if update is successful + if not update_success: + tag = "cuda-ipc" if colocated_inference else "nccl" + error_message = ( + "❌ Error: Updating weights for the generation policy failed during refit.\n" + f"This often indicates an issue with {tag} or " + "a problem within the generation backend (e.g., vLLM worker).\n" + ) + raise RuntimeError(error_message) policy.offload_after_refit() policy_generation.prepare_for_generation(tags=["kv_cache"]) @@ -373,12 +441,13 @@ def grpo_train( consumed_samples = grpo_save_state["consumed_samples"] val_period = master_config["grpo"]["val_period"] val_at_start = master_config["grpo"]["val_at_start"] + colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] # Run validation at the start if configured if val_at_start and step == 0: print("\n🔍 Running initial validation...") if NEED_REFIT and POLICY_GENERATION_STALE: - refit_policy_generation(policy, policy_generation) + refit_policy_generation(policy, policy_generation, colocated_inference) POLICY_GENERATION_STALE = False else: policy_generation.prepare_for_generation() @@ -421,7 +490,9 @@ def grpo_train( print(f"▶ Generating responses for batch of size {repeated_batch.size}...") with timer.time("prepare_for_generation"): if NEED_REFIT and POLICY_GENERATION_STALE: - refit_policy_generation(policy, policy_generation) + refit_policy_generation( + policy, policy_generation, colocated_inference + ) POLICY_GENERATION_STALE = False else: policy_generation.prepare_for_generation() @@ -533,7 +604,9 @@ def grpo_train( # Run validation if it's a validation step if is_last_step or (val_period > 0 and (step + 1) % val_period == 0): if NEED_REFIT and POLICY_GENERATION_STALE: - refit_policy_generation(policy, policy_generation) + refit_policy_generation( + policy, policy_generation, colocated_inference + ) POLICY_GENERATION_STALE = False else: policy_generation.prepare_for_generation() diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index 9e4fdbca52..239ea61847 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -14,6 +14,7 @@ from abc import ABC, abstractmethod from typing import Any, NotRequired, Optional, TypedDict, Union +import ray import torch from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -102,7 +103,6 @@ class GenerationConfig(TypedDict): """Configuration for generation.""" backend: str - colocated: bool max_new_tokens: int temperature: float top_p: float @@ -198,7 +198,7 @@ class GenerationInterface(ABC): """Abstract base class defining the interface for RL policies.""" @abstractmethod - def init_collective(self, world_size: int) -> None: + def init_collective(self, world_size: int) -> list[ray.ObjectRef]: """Initialize the collective communication.""" pass @@ -220,6 +220,8 @@ def update_weights(self, ipc_handles: dict[str, Any]) -> bool: """Update the model weights from the given IPC handles.""" raise NotImplementedError - def update_weights_from_collective(self, info: dict[str, Any]) -> bool: + def update_weights_from_collective( + self, info: dict[str, Any] + ) -> list[ray.ObjectRef]: """Update the model weights from collective communication.""" raise NotImplementedError diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index 658741a4de..4fe8b86ad8 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -342,8 +342,11 @@ def _patch_vllm_init_workers_ray(): else: self.llm = vllm.LLM(**llm_kwargs) - def init_collective(self, world_size: int) -> None: - self.llm.collective_rpc("init_collective", args=(world_size,)) + def init_collective(self, rank_prefix: int, world_size: int) -> None: + self.llm.collective_rpc( + "init_collective", + args=(rank_prefix, world_size,), + ) def llm(self): return self.llm @@ -940,7 +943,9 @@ def update_weights_from_collective(self, info: dict[str, Any]) -> bool: "update_weights_from_collective cannot be used with async_engine=True. Use update_weights_from_ipc_handles_async instead." ) - result_or_coro = self.llm.collective_rpc("update_weights_from_collective", args=(info,)) + result_or_coro = self.llm.collective_rpc( + "update_weights_from_collective", args=(info,) + ) worker_result = result_or_coro[0] if not worker_result: @@ -1237,22 +1242,26 @@ def _report_device_id(self) -> list[list[str]]: results = ray.get(futures) return results - def init_collective(self, world_size: int) -> None: + def init_collective(self, world_size: int) -> list[ray.ObjectRef]: """Initialize the collective communication.""" - try: - # Use run_all_workers_single_data to send data to all workers - print("[init_collective] in vllm") - futures = self.worker_group.run_all_workers_single_data( - "init_collective", - world_size=world_size, - run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"] - ) - # Wait for all futures to complete - results = ray.get(futures) - return all(result for result in results if result is not None) - except Exception as e: - print(f"Error during init collective: {e}") - return False + if not self.worker_group or not self.worker_group.workers: + raise RuntimeError("Worker group is not initialized") + + # Prepare rank + total_workers = len(self.worker_group.workers) + workers_per_group = len(self.worker_group.tied_workers_groups[0]) + rank_prefix_list = list(range(0, total_workers, workers_per_group)) + + # Send world_size and rank for init collective to all workers + futures = self.worker_group.run_all_workers_multiple_data( + "init_collective", + data=rank_prefix_list, + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], + common_kwargs={"world_size": world_size}, + ) + + # this function should co-work with hf_policy, so we should wait for all futures to complete outside + return futures def generate( self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False @@ -1488,21 +1497,22 @@ def update_weights(self, ipc_handles: dict[str, Any]) -> bool: print(f"Error during update weights: {e}") return False - def update_weights_from_collective(self, info: dict[str, Any]) -> bool: + def update_weights_from_collective( + self, info: dict[str, Any] + ) -> list[ray.ObjectRef]: """Update weights of the policy using collective communication.""" - try: - # Use run_all_workers_single_data to send data to all workers - futures = self.worker_group.run_all_workers_single_data( - "update_weights_from_collective", - info=info, - run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"] - ) - # Wait for all futures to complete - results = ray.get(futures) - return all(result for result in results if result is not None) - except Exception as e: - print(f"Error during update weights from collective: {e}") - return False + if not self.worker_group or not self.worker_group.workers: + raise RuntimeError("Worker group is not initialized") + + # Use run_all_workers_single_data to send data to all workers + futures = self.worker_group.run_all_workers_single_data( + "update_weights_from_collective", + info=info, + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], + ) + + # this function should co-work with hf_policy, so we should wait for all futures to complete outside + return futures def __del__(self) -> None: """Shuts down the worker groups when the object is deleted or is garbage collected. diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index 455c454469..c5430910bf 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch from typing import Any +import torch + try: import vllm # noqa: F401 except ImportError: @@ -25,11 +26,16 @@ class VllmInternalWorkerExtension: - def init_collective(self, world_size: int) -> None: + def init_collective(self, rank_prefix: int, world_size: int) -> None: """Initialize the collective communication.""" import ray.util.collective as collective - collective.init_collective_group(world_size=world_size, rank=1, backend="nccl", group_name="refit") + local_rank = torch.distributed.get_rank() + rank = rank_prefix + local_rank + 1 # 1 is the head node of the train cluster + + collective.init_collective_group( + world_size=world_size, rank=rank, backend="nccl", group_name="refit" + ) def report_device_id(self) -> str: from nemo_rl.utils.nvml import get_device_uuid @@ -71,7 +77,7 @@ def update_weights_from_ipc_handles(self, ipc_handles): ) return False - def update_weights_from_collective(self, info: dict[str, Any]) -> None: + def update_weights_from_collective(self, info: dict[str, Any]) -> bool: """Update the model weights from collective communication.""" import ray.util.collective as collective diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index e263dfbf62..18de242f15 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -128,7 +128,7 @@ def __init__( self.cfg = config # torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call torch.distributed.init_process_group(backend="nccl") - rank = torch.distributed.get_rank() + self.rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() model_name = self.cfg["model_name"] @@ -144,7 +144,7 @@ def __init__( else: raise ValueError(f"Unknown precision: {self.cfg['precision']}") - print(f"[Rank {rank}] Loading model {model_name} on CPU...") + print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") self.model = AutoModelForCausalLM.from_pretrained( model_name, device_map="cpu", # load weights onto CPU initially @@ -266,7 +266,15 @@ def init_collective(self, world_size: int) -> None: """Initialize the collective communication.""" import ray.util.collective as collective - collective.init_collective_group(world_size=world_size, rank=0, backend="nccl", group_name="refit") + # keep the same behavior as vllm + # see https://github.com/vllm-project/vllm/blob/v0.8.5/vllm/env_override.py#L25 + if not os.path.exists("/dev/nvidia-caps-imex-channels"): + os.environ["NCCL_CUMEM_ENABLE"] = "0" + + if self.rank == 0: + collective.init_collective_group( + world_size=world_size, rank=0, backend="nccl", group_name="refit" + ) def is_alive(self) -> bool: return True @@ -786,7 +794,8 @@ def broadcast_weights_for_collective(self) -> None: if isinstance(tensor, DTensor): tensor = tensor.full_tensor() tensor = tensor.to(self.dtype, non_blocking=True) - collective.broadcast(tensor.data, 0, group_name="refit") + if self.rank == 0: + collective.broadcast(tensor.data, 0, group_name="refit") def prepare_for_lp_inference(self) -> None: if not self.cpu_offload: diff --git a/nemo_rl/models/policy/hf_policy.py b/nemo_rl/models/policy/hf_policy.py index 0af89d61f8..b573226a47 100644 --- a/nemo_rl/models/policy/hf_policy.py +++ b/nemo_rl/models/policy/hf_policy.py @@ -119,9 +119,13 @@ def __init__( self.cfg = config - def init_collective(self, world_size: int) -> None: + def init_collective(self, world_size: int) -> list[ray.ObjectRef]: """Initialize the collective communication.""" - self.worker_group.run_all_workers_single_data("init_collective", world_size=world_size) + futures = self.worker_group.run_all_workers_single_data( + "init_collective", world_size=world_size + ) + # this function should co-work with vllm, so we should wait for all futures to complete outside + return futures def get_logprobs( self, data: BatchedDataDict[GenerationDatumSpec] @@ -419,11 +423,13 @@ def prepare_info_for_collective(self) -> dict[str, Any]: # Only get the first worker's info since all workers will have the same result return results[0] - def broadcast_weights_for_collective(self) -> None: + def broadcast_weights_for_collective(self) -> list[ray.ObjectRef]: """Broadcast the weights for collective communication.""" - self.worker_group.run_all_workers_single_data( + futures = self.worker_group.run_all_workers_single_data( "broadcast_weights_for_collective" ) + # this function should co-work with vllm, so we should wait for all futures to complete outside + return futures def offload_before_refit(self) -> None: """Offload the optimizer and buffers to the CPU.""" diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index ef3747ce6f..5d6214d8c0 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -14,6 +14,7 @@ from abc import ABC, abstractmethod from typing import Any, TypedDict +import ray import torch from nemo_rl.algorithms.interfaces import LossFunction @@ -94,7 +95,7 @@ def shutdown(self) -> bool: class ColocatablePolicyInterface(PolicyInterface): @abstractmethod - def init_collective(self, world_size: int) -> None: + def init_collective(self, world_size: int) -> list[ray.ObjectRef]: pass @abstractmethod @@ -118,5 +119,5 @@ def prepare_info_for_collective(self) -> dict[str, Any]: pass @abstractmethod - def broadcast_weights_for_collective(self) -> None: + def broadcast_weights_for_collective(self) -> list[ray.ObjectRef]: pass diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 717b44ff0a..00002f6838 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -14,7 +14,6 @@ import os from copy import deepcopy -from unittest import mock import pytest import ray @@ -33,7 +32,6 @@ # Define basic vLLM test config basic_vllm_test_config: VllmConfig = { "backend": "vllm", - "colocated": True, "model_name": model_name, "tokenizer": { "name": model_name, @@ -55,6 +53,11 @@ "skip_tokenizer_init": False, "load_format": "auto", }, + "colocated": { + "enabled": True, + "gpus_for_inference": -1, + "nodes_for_inference": -1, + }, "vllm_kwargs": {}, } @@ -312,7 +315,9 @@ async def test_vllm_policy_generation_async( print("creating hf policy...") hf_policy = HfPolicy(cluster, hf_config, tokenizer) - refit_policy_generation(hf_policy, async_policy) + refit_policy_generation( + hf_policy, async_policy, vllm_config["colocated"]["enabled"] + ) outputs = async_policy.generate_async(test_input_data) # Validate outputs format @@ -407,7 +412,7 @@ def test_vllm_worker_seed_behavior(cluster, tokenizer): hf_policy = HfPolicy(cluster, hf_config, tokenizer) print("refitting vllm policy...") - refit_policy_generation(hf_policy, policy) + refit_policy_generation(hf_policy, policy, vllm_config["colocated"]["enabled"]) try: # Generate with duplicated prompts @@ -563,7 +568,9 @@ def test_vllm_generation_with_hf_training( hf_policy = HfPolicy(cluster, hf_config, tokenizer) print("refitting vllm policy...") - refit_policy_generation(hf_policy, vllm_policy) + refit_policy_generation( + hf_policy, vllm_policy, vllm_config["colocated"]["enabled"] + ) # Step 1: Use vLLM for generation print("Using vLLM policy for fast generation...") @@ -915,7 +922,12 @@ def test_vllm_weight_update_memory(cluster, tokenizer, enable_dtensor): # reset peak memory stats before refit workers = hf_policy.worker_group.workers ray.get([w.reset_peak_memory_stats.remote() for w in workers]) - refit_policy_generation(hf_policy, vllm_policy, _refit_buffer_size_gb=1) + refit_policy_generation( + hf_policy, + vllm_policy, + vllm_config["colocated"]["enabled"], + _refit_buffer_size_gb=1, + ) gpu_infos = ray.get([w.get_gpu_info.remote() for w in workers]) # Gather memory stats @@ -982,7 +994,9 @@ def test_vllm_generation_with_stop( hf_policy = HfPolicy(cluster, hf_config, tokenizer) print("refitting vllm policy...") - refit_policy_generation(hf_policy, vllm_generation) + refit_policy_generation( + hf_policy, vllm_generation, vllm_config["colocated"]["enabled"] + ) # test generate outputs = vllm_generation.generate(test_input_data, greedy=True) @@ -1056,7 +1070,6 @@ def test_vllm_refit_non_collocated_handles_update( ) # Create HfPolicy on its own cluster - os.environ["NCCL_CUMEM_ENABLE"] = "0" hf_config = get_basic_hf_test_config(enable_dtensor=True) hf_config["dtensor_cfg"]["tensor_parallel_size"] = 1 hf_policy = HfPolicy(policy_cluster_separate, hf_config, tokenizer) @@ -1065,13 +1078,16 @@ def test_vllm_refit_non_collocated_handles_update( vllm_config = deepcopy(basic_vllm_test_config) vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True) vllm_config["vllm_cfg"]["tensor_parallel_size"] = 1 + vllm_config["colocated"]["enabled"] = False vllm_generation = VllmGeneration(generation_cluster_separate, vllm_config) hf_policy.init_collective(2) vllm_generation.init_collective(2) print("refitting vllm policy...") - refit_policy_generation(hf_policy, vllm_generation, is_colocated=False) + refit_policy_generation( + hf_policy, vllm_generation, vllm_config["colocated"]["enabled"] + ) # test generate outputs = vllm_generation.generate(test_input_data, greedy=True) From fe519c82da029505d9d5eb8f99b74d3b564b2701 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Fri, 6 Jun 2025 09:02:52 -0700 Subject: [PATCH 3/8] no need to offload when non-colocated Signed-off-by: Yuki Huang --- nemo_rl/algorithms/grpo.py | 12 +++++++----- nemo_rl/models/generation/vllm.py | 22 +++++++++++++++++----- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index a7a97110e9..47bab0d9fe 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -324,7 +324,7 @@ def setup( inference_nodes * inference_gpus_per_node + 1 ) # inference cluster + head node of the train cluster futures_train = policy.init_collective(world_size) - futures_inference = policy_generation.init_collective(world_size) # type: ignore + futures_inference = policy_generation.init_collective(world_size) # type: ignore ray.get(futures_train + futures_inference) loss_fn = ClippedPGLossFn(loss_config) @@ -367,8 +367,9 @@ def refit_policy_generation( If it is None, the buffer size will be computed by the remaining memory. This parameter is primarily used for testing. """ - policy.offload_before_refit() - policy_generation.prepare_for_generation(tags=["weights"]) + if colocated_inference: + policy.offload_before_refit() + policy_generation.prepare_for_generation(tags=["weights"]) # update weights update_success = False @@ -403,8 +404,9 @@ def refit_policy_generation( ) raise RuntimeError(error_message) - policy.offload_after_refit() - policy_generation.prepare_for_generation(tags=["kv_cache"]) + if colocated_inference: + policy.offload_after_refit() + policy_generation.prepare_for_generation(tags=["kv_cache"]) # =============================================================================== diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index 4fe8b86ad8..620d207122 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -345,7 +345,10 @@ def _patch_vllm_init_workers_ray(): def init_collective(self, rank_prefix: int, world_size: int) -> None: self.llm.collective_rpc( "init_collective", - args=(rank_prefix, world_size,), + args=( + rank_prefix, + world_size, + ), ) def llm(self): @@ -961,6 +964,12 @@ def update_weights_from_collective(self, info: dict[str, Any]) -> bool: traceback.print_exc() return False + def reset_prefix_cache(self): + """Reset the prefix cache of vLLM engine.""" + self.llm.llm_engine.reset_prefix_cache() + gc.collect() + torch.cuda.empty_cache() + def sleep(self): """Put the vLLM engine to sleep.""" assert self.llm is not None, ( @@ -1426,12 +1435,15 @@ def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool: return False def finish_generation(self, *args: Any, **kwargs: Any) -> bool: - """Sleep workers.""" + """Sleep workers and reset prefix cache.""" try: # Choose the appropriate method based on async_engine setting - method_name = ( - "sleep_async" if self.cfg["vllm_cfg"]["async_engine"] else "sleep" - ) + if not self.cfg["colocated"]["enabled"]: + method_name = "reset_prefix_cache" + else: + method_name = ( + "sleep_async" if self.cfg["vllm_cfg"]["async_engine"] else "sleep" + ) # Use run_all_workers_single_data for methods that don't need data futures = self.worker_group.run_all_workers_single_data( method_name, From 2c195bf09ec8a97efcec63f4628cbf4a3f5fef1f Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sun, 8 Jun 2025 20:17:56 -0700 Subject: [PATCH 4/8] fix vllm tp-size>1 by using vllm collective Signed-off-by: Yuki Huang lint Signed-off-by: Yuki Huang fix ip and config Signed-off-by: Yuki Huang --- nemo_rl/algorithms/grpo.py | 25 +++++++++++++------ nemo_rl/models/generation/interfaces.py | 4 ++- nemo_rl/models/generation/vllm.py | 21 ++++++++++++---- nemo_rl/models/generation/vllm_backend.py | 16 ++++++------ .../models/policy/dtensor_policy_worker.py | 17 +++++++------ nemo_rl/models/policy/hf_policy.py | 6 +++-- nemo_rl/models/policy/interfaces.py | 4 ++- .../models/generation/test_vllm_generation.py | 12 ++++++--- 8 files changed, 70 insertions(+), 35 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 47bab0d9fe..bef38826e3 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -38,7 +38,10 @@ get_keys_from_message_log, ) from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster +from nemo_rl.distributed.virtual_cluster import ( + ClusterConfig, + RayVirtualCluster, +) from nemo_rl.environments.interfaces import ( EnvironmentInterface, ) @@ -320,11 +323,14 @@ def setup( # if it is not colocated inference, initialize collective communication for update weights if not colocated_inference: - world_size = ( - inference_nodes * inference_gpus_per_node + 1 - ) # inference cluster + head node of the train cluster - futures_train = policy.init_collective(world_size) - futures_inference = policy_generation.init_collective(world_size) # type: ignore + ip, port = train_cluster.get_master_address_and_port() + print(f"Using ip: {ip}, port: {port} for collective communication") + # inference cluster + head node of the train cluster + world_size = inference_nodes * inference_gpus_per_node + 1 + # init collective + futures_train = policy.init_collective(ip, port, world_size) + futures_inference = policy_generation.init_collective(ip, port, world_size) # type: ignore + # wait for all futures to complete ray.get(futures_train + futures_inference) loss_fn = ClippedPGLossFn(loss_config) @@ -385,21 +391,24 @@ def refit_policy_generation( if not update_success: break else: + # prepare info for update weights state_dict_info = policy.prepare_info_for_collective() + # update weights through nccl futures_train = policy.broadcast_weights_for_collective() futures_inference = policy_generation.update_weights_from_collective( state_dict_info ) + # wait for all futures to complete ray.get(futures_train) results = ray.get(futures_inference) update_success = all(result for result in results if result is not None) # check if update is successful if not update_success: - tag = "cuda-ipc" if colocated_inference else "nccl" + error_tag = "cuda-ipc" if colocated_inference else "nccl" error_message = ( "❌ Error: Updating weights for the generation policy failed during refit.\n" - f"This often indicates an issue with {tag} or " + f"This often indicates an issue with {error_tag} or " "a problem within the generation backend (e.g., vLLM worker).\n" ) raise RuntimeError(error_message) diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index 239ea61847..c670ec16b2 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -198,7 +198,9 @@ class GenerationInterface(ABC): """Abstract base class defining the interface for RL policies.""" @abstractmethod - def init_collective(self, world_size: int) -> list[ray.ObjectRef]: + def init_collective( + self, ip: str, port: int, world_size: int + ) -> list[ray.ObjectRef]: """Initialize the collective communication.""" pass diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index 620d207122..e620b3bd92 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -342,11 +342,15 @@ def _patch_vllm_init_workers_ray(): else: self.llm = vllm.LLM(**llm_kwargs) - def init_collective(self, rank_prefix: int, world_size: int) -> None: + def init_collective( + self, rank_prefix: int, ip: str, port: int, world_size: int + ) -> None: self.llm.collective_rpc( "init_collective", args=( rank_prefix, + ip, + port, world_size, ), ) @@ -1251,7 +1255,9 @@ def _report_device_id(self) -> list[list[str]]: results = ray.get(futures) return results - def init_collective(self, world_size: int) -> list[ray.ObjectRef]: + def init_collective( + self, ip: str, port: int, world_size: int + ) -> list[ray.ObjectRef]: """Initialize the collective communication.""" if not self.worker_group or not self.worker_group.workers: raise RuntimeError("Worker group is not initialized") @@ -1266,7 +1272,7 @@ def init_collective(self, world_size: int) -> list[ray.ObjectRef]: "init_collective", data=rank_prefix_list, run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], - common_kwargs={"world_size": world_size}, + common_kwargs={"ip": ip, "port": port, "world_size": world_size}, ) # this function should co-work with hf_policy, so we should wait for all futures to complete outside @@ -1415,7 +1421,11 @@ def generate_async( return combined def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool: - """Wake workers up.""" + """Wake workers up for colocated inference.""" + # non-colocated no need to wake up + if not self.cfg["colocated"]["enabled"]: + return True + try: # Choose the appropriate method based on async_engine setting method_name = ( @@ -1437,7 +1447,8 @@ def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool: def finish_generation(self, *args: Any, **kwargs: Any) -> bool: """Sleep workers and reset prefix cache.""" try: - # Choose the appropriate method based on async_engine setting + # Choose the appropriate method based on setting + # non-colocated only needs reset prefix cache, no need to sleep. if not self.cfg["colocated"]["enabled"]: method_name = "reset_prefix_cache" else: diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index c5430910bf..c40aea4418 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -26,16 +26,20 @@ class VllmInternalWorkerExtension: - def init_collective(self, rank_prefix: int, world_size: int) -> None: + def init_collective( + self, rank_prefix: int, ip: str, port: int, world_size: int + ) -> None: """Initialize the collective communication.""" - import ray.util.collective as collective + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup local_rank = torch.distributed.get_rank() rank = rank_prefix + local_rank + 1 # 1 is the head node of the train cluster - collective.init_collective_group( - world_size=world_size, rank=rank, backend="nccl", group_name="refit" + pg = StatelessProcessGroup.create( + host=ip, port=port, rank=rank, world_size=world_size ) + self.model_update_group = PyNcclCommunicator(pg, device=self.device) def report_device_id(self) -> str: from nemo_rl.utils.nvml import get_device_uuid @@ -79,12 +83,10 @@ def update_weights_from_ipc_handles(self, ipc_handles): def update_weights_from_collective(self, info: dict[str, Any]) -> bool: """Update the model weights from collective communication.""" - import ray.util.collective as collective - try: for name, (shape, dtype) in info.items(): weight = torch.empty(shape, dtype=dtype, device="cuda") - collective.broadcast(weight, 0, group_name="refit") + self.model_update_group.broadcast(weight, src=0) self.model_runner.model.load_weights(weights=[(name, weight)]) except Exception as e: print( diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 18de242f15..673cf3df20 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -262,9 +262,10 @@ def __init__( "No weights path provided. Starting from scratch (default policy init)" ) - def init_collective(self, world_size: int) -> None: + def init_collective(self, ip: str, port: int, world_size: int) -> None: """Initialize the collective communication.""" - import ray.util.collective as collective + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup # keep the same behavior as vllm # see https://github.com/vllm-project/vllm/blob/v0.8.5/vllm/env_override.py#L25 @@ -272,9 +273,11 @@ def init_collective(self, world_size: int) -> None: os.environ["NCCL_CUMEM_ENABLE"] = "0" if self.rank == 0: - collective.init_collective_group( - world_size=world_size, rank=0, backend="nccl", group_name="refit" + pg = StatelessProcessGroup.create( + host=ip, port=port, rank=0, world_size=world_size ) + device = torch.cuda.current_device() + self.model_update_group = PyNcclCommunicator(pg, device=device) def is_alive(self) -> bool: return True @@ -788,14 +791,12 @@ def prepare_info_for_collective(self) -> dict[str, Any]: @torch.no_grad() def broadcast_weights_for_collective(self) -> None: """Broadcast the weights for collective communication.""" - import ray.util.collective as collective - for _, tensor in self.model.state_dict().items(): if isinstance(tensor, DTensor): tensor = tensor.full_tensor() - tensor = tensor.to(self.dtype, non_blocking=True) if self.rank == 0: - collective.broadcast(tensor.data, 0, group_name="refit") + tensor = tensor.to(self.dtype, non_blocking=True) + self.model_update_group.broadcast(tensor.data, src=0) def prepare_for_lp_inference(self) -> None: if not self.cpu_offload: diff --git a/nemo_rl/models/policy/hf_policy.py b/nemo_rl/models/policy/hf_policy.py index b573226a47..b1928c2fc5 100644 --- a/nemo_rl/models/policy/hf_policy.py +++ b/nemo_rl/models/policy/hf_policy.py @@ -119,10 +119,12 @@ def __init__( self.cfg = config - def init_collective(self, world_size: int) -> list[ray.ObjectRef]: + def init_collective( + self, ip: str, port: int, world_size: int + ) -> list[ray.ObjectRef]: """Initialize the collective communication.""" futures = self.worker_group.run_all_workers_single_data( - "init_collective", world_size=world_size + "init_collective", ip=ip, port=port, world_size=world_size ) # this function should co-work with vllm, so we should wait for all futures to complete outside return futures diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index 5d6214d8c0..614340c67b 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -95,7 +95,9 @@ def shutdown(self) -> bool: class ColocatablePolicyInterface(PolicyInterface): @abstractmethod - def init_collective(self, world_size: int) -> list[ray.ObjectRef]: + def init_collective( + self, ip: str, port: int, world_size: int + ) -> list[ray.ObjectRef]: pass @abstractmethod diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 00002f6838..a07b8d545b 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -22,7 +22,10 @@ from nemo_rl.algorithms.grpo import refit_policy_generation from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.distributed.virtual_cluster import ( + RayVirtualCluster, + _get_node_ip_and_free_port, +) from nemo_rl.models.generation import configure_generation_config from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration from nemo_rl.models.policy import PolicyConfig @@ -1081,8 +1084,11 @@ def test_vllm_refit_non_collocated_handles_update( vllm_config["colocated"]["enabled"] = False vllm_generation = VllmGeneration(generation_cluster_separate, vllm_config) - hf_policy.init_collective(2) - vllm_generation.init_collective(2) + # initialize collective communication for update weights + ip, port = ray.get(_get_node_ip_and_free_port.remote()) + futures_train = hf_policy.init_collective(ip, port, world_size=2) + futures_inference = vllm_generation.init_collective(ip, port, world_size=2) + ray.get(futures_train + futures_inference) print("refitting vllm policy...") refit_policy_generation( From a803dbca9144880c7e3d27c193568bd5daa98095 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 10 Jun 2025 08:18:11 -0700 Subject: [PATCH 5/8] update config Signed-off-by: Yuki Huang update config structure Signed-off-by: Yuki Huang --- examples/configs/grpo-deepscaler-1.5b-8K.yaml | 8 +++++ examples/configs/grpo_math_1B.yaml | 10 ++++--- .../llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml | 5 ++++ ...-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml | 5 ++++ ...3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml | 5 ++++ ...llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml | 5 ++++ ...-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml | 5 ++++ ...en2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml | 5 ++++ ...rpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml | 5 ++++ ...wen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml | 5 ++++ ...5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml | 5 ++++ nemo_rl/algorithms/grpo.py | 30 ++++++++++++------- tests/unit/experience/test_rollouts.py | 7 +++++ .../models/generation/test_vllm_generation.py | 11 +++---- .../generation/test_vllm_large_model.py | 7 +++++ tests/unit/utils/test_native_checkpoint.py | 2 ++ 16 files changed, 100 insertions(+), 20 deletions(-) diff --git a/examples/configs/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/grpo-deepscaler-1.5b-8K.yaml index eef1c4e205..3055a106f9 100644 --- a/examples/configs/grpo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/grpo-deepscaler-1.5b-8K.yaml @@ -101,6 +101,14 @@ policy: # For most cases, use "dummy" to load the initial weights, since they will be overwritten during refit # For Gemma models, we need to use "auto" due to a vllm bug load_format: dummy + colocated: + # true: generation shares training GPUs + # false: uses dedicated generation resources + enabled: true + # only relevant when enabled is false + resources: + gpus_per_node: null # used when cluster.num_nodes is 1 + num_nodes: null # used when cluster.num_nodes > 1 data: max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 6200d223c4..b0d240bdc7 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -107,11 +107,13 @@ policy: gpu_memory_utilization: 0.6 max_model_len: ${policy.max_total_sequence_length} colocated: + # true: generation shares training GPUs + # false: uses dedicated generation resources enabled: true - # decides how many GPUs to use for inference - # used when train and inference is not colocated - gpus_per_node: -1 # used when cluster.num_nodes is 1 - num_nodes: -1 # used when cluster.num_nodes > 1 + # only relevant when enabled is false + resources: + gpus_per_node: null # used when cluster.num_nodes is 1 + num_nodes: null # used when cluster.num_nodes > 1 data: max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len diff --git a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml index a737c5e35e..f7079fe4a8 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml @@ -88,6 +88,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 512 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 512 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml index ac0627f487..9b4d7ffc84 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml @@ -89,6 +89,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 16384 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 16384 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml index af375d4f14..65a909ce80 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml @@ -89,6 +89,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 4096 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 4096 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml index 90cbcdc8ff..5e2a98c022 100644 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -89,6 +89,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 512 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 512 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml index 6ec2201eaf..5b53de160a 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml @@ -89,6 +89,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 16384 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 16384 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml index 01fbb27245..382d8b80f3 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml @@ -89,6 +89,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 16384 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 16384 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml index 63dc2a85c3..0db76b931d 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml @@ -86,6 +86,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 4096 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 4096 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml index 05ff7233a6..6196a7f3b4 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml @@ -89,6 +89,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 4096 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 4096 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml index ecd3f1d417..38d774bd76 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -89,6 +89,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 512 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 512 prompt_file: examples/prompts/cot.txt diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index bef38826e3..8b576089b6 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -237,25 +237,33 @@ def setup( "Non-colocated inference is not supported for HF backend" ) + # train resources will be updated through overall and inference resources below train_gpus_per_node = cluster_config["gpus_per_node"] train_nodes = cluster_config["num_nodes"] - inference_gpus_per_node = generation_config["colocated"]["gpus_per_node"] - inference_nodes = generation_config["colocated"]["num_nodes"] - # validate and configure generation.colocated config + inference_resources = generation_config["colocated"]["resources"] + inference_gpus_per_node = inference_resources["gpus_per_node"] + inference_nodes = inference_resources["num_nodes"] + + # validate and configure resources if cluster_config["num_nodes"] == 1: - assert inference_gpus_per_node > 0 and inference_nodes == -1, ( - "policy.generation.colocated.gpus_per_node must be set " - "and policy.generation.colocated.num_nodes must be -1 " - "when cluster.num_nodes=1 and inference is not colocated" + assert inference_gpus_per_node > 0 and ( + inference_nodes is None or inference_nodes == 1 + ), ( + "policy.generation.colocated.resources.gpus_per_node must be set and " + "policy.generation.colocated.resources.num_nodes must be 1 or set to null " + "when cluster.num_nodes = 1 and inference is non-colocated" ) inference_nodes = 1 train_gpus_per_node -= inference_gpus_per_node else: - assert inference_gpus_per_node == -1 and inference_nodes > 0, ( - "policy.generation.colocated.gpus_per_node must be -1 and " - "policy.generation.colocated.num_nodes must be set " - "when cluster.num_nodes > 1 and inference is not colocated" + assert inference_nodes > 0 and ( + inference_gpus_per_node is None + or inference_gpus_per_node == cluster_config["gpus_per_node"] + ), ( + "policy.generation.colocated.resources.num_nodes must be set and " + "policy.generation.colocated.resources.gpus_per_node must be equal to cluster.gpus_per_node or set to null " + "when cluster.num_nodes > 1 and inference is non-colocated" ) inference_gpus_per_node = cluster_config["gpus_per_node"] train_nodes -= inference_nodes diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py index 2898b262ae..7516b0ba6e 100644 --- a/tests/unit/experience/test_rollouts.py +++ b/tests/unit/experience/test_rollouts.py @@ -230,6 +230,13 @@ def initial_multi_step_calculator_batch(rollout_tokenizer): "disable_log_requests": True, "gpu_memory_utilization": 0.6, }, + "colocated": { + "enabled": True, + "resources": { + "gpus_per_node": None, + "num_nodes": None, + }, + }, } diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index a07b8d545b..b00b3fb6d5 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -58,8 +58,10 @@ }, "colocated": { "enabled": True, - "gpus_for_inference": -1, - "nodes_for_inference": -1, + "resources": { + "gpus_per_node": None, + "num_nodes": None, + }, }, "vllm_kwargs": {}, } @@ -107,9 +109,7 @@ def get_basic_hf_test_config(enable_dtensor: bool = False) -> PolicyConfig: }, "max_grad_norm": 1.0, "make_sequence_length_divisible_by": 1, - "generation": { - "temperature": 0.8, - }, + "generation": basic_vllm_test_config, } @@ -1075,6 +1075,7 @@ def test_vllm_refit_non_collocated_handles_update( # Create HfPolicy on its own cluster hf_config = get_basic_hf_test_config(enable_dtensor=True) hf_config["dtensor_cfg"]["tensor_parallel_size"] = 1 + hf_config["generation"]["colocated"]["enabled"] = False hf_policy = HfPolicy(policy_cluster_separate, hf_config, tokenizer) # Create VllmGeneration policy on its own cluster diff --git a/tests/unit/models/generation/test_vllm_large_model.py b/tests/unit/models/generation/test_vllm_large_model.py index 2a860d7e6f..97bc9dc66e 100644 --- a/tests/unit/models/generation/test_vllm_large_model.py +++ b/tests/unit/models/generation/test_vllm_large_model.py @@ -51,6 +51,13 @@ "skip_tokenizer_init": False, "load_format": "auto", }, + "colocated": { + "enabled": True, + "resources": { + "gpus_per_node": None, + "num_nodes": None, + }, + }, "vllm_kwargs": {}, } diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py index 68a1f3e217..f77a14914f 100755 --- a/tests/unit/utils/test_native_checkpoint.py +++ b/tests/unit/utils/test_native_checkpoint.py @@ -66,7 +66,9 @@ }, "max_grad_norm": 1.0, "generation": { + "backend": "vllm", "temperature": 1.0, + "colocated": {"enabled": True}, }, } From 2c7f7658501fcd81f80a2875f32546e9a98c8e5c Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 11 Jun 2025 17:21:44 -0700 Subject: [PATCH 6/8] update comment Signed-off-by: Yuki Huang --- examples/configs/grpo-deepscaler-1.5b-8K.yaml | 4 +-- examples/configs/grpo_math_1B.yaml | 4 +-- nemo_rl/algorithms/grpo.py | 27 ++++++++++++------- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/examples/configs/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/grpo-deepscaler-1.5b-8K.yaml index 3055a106f9..67086f4929 100644 --- a/examples/configs/grpo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/grpo-deepscaler-1.5b-8K.yaml @@ -107,8 +107,8 @@ policy: enabled: true # only relevant when enabled is false resources: - gpus_per_node: null # used when cluster.num_nodes is 1 - num_nodes: null # used when cluster.num_nodes > 1 + gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 + num_nodes: null # Decides number of nodes to be dedicated to generation data: max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index b0d240bdc7..33ce350028 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -112,8 +112,8 @@ policy: enabled: true # only relevant when enabled is false resources: - gpus_per_node: null # used when cluster.num_nodes is 1 - num_nodes: null # used when cluster.num_nodes > 1 + gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 + num_nodes: null # Decides number of nodes to be dedicated to generation data: max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 8b576089b6..d3116bf2ef 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -234,7 +234,8 @@ def setup( else: assert generation_config["backend"] != "hf", ( - "Non-colocated inference is not supported for HF backend" + "Non-colocated inference is not supported for HF generation backend. " + "Please use vLLM backend for generation." ) # train resources will be updated through overall and inference resources below @@ -247,23 +248,31 @@ def setup( # validate and configure resources if cluster_config["num_nodes"] == 1: - assert inference_gpus_per_node > 0 and ( - inference_nodes is None or inference_nodes == 1 - ), ( - "policy.generation.colocated.resources.gpus_per_node must be set and " + assert inference_gpus_per_node > 0, ( + "policy.generation.colocated.resources.gpus_per_node must be > 0 " + "when cluster.num_nodes = 1 and inference is non-colocated, " + f"but got {inference_gpus_per_node}." + ) + assert inference_nodes is None or inference_nodes == 1, ( "policy.generation.colocated.resources.num_nodes must be 1 or set to null " - "when cluster.num_nodes = 1 and inference is non-colocated" + "when cluster.num_nodes = 1 and inference is non-colocated, " + f"but got {inference_nodes}." ) inference_nodes = 1 train_gpus_per_node -= inference_gpus_per_node else: - assert inference_nodes > 0 and ( + assert inference_nodes > 0, ( + "policy.generation.colocated.resources.num_nodes must be > 0 " + "when cluster.num_nodes > 1 and inference is non-colocated, " + f"but got {inference_nodes}." + ) + assert ( inference_gpus_per_node is None or inference_gpus_per_node == cluster_config["gpus_per_node"] ), ( - "policy.generation.colocated.resources.num_nodes must be set and " "policy.generation.colocated.resources.gpus_per_node must be equal to cluster.gpus_per_node or set to null " - "when cluster.num_nodes > 1 and inference is non-colocated" + "when cluster.num_nodes > 1 and inference is non-colocated, " + f"but got {inference_gpus_per_node}." ) inference_gpus_per_node = cluster_config["gpus_per_node"] train_nodes -= inference_nodes From baf3f298e6b85a038a5bd451b8f71ffddd36f6bd Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Thu, 12 Jun 2025 03:39:22 +0000 Subject: [PATCH 7/8] dtensor worker use vllm Signed-off-by: Yuki Huang --- nemo_rl/distributed/ray_actor_environment_registry.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py index 990ac46a16..697ef88826 100644 --- a/nemo_rl/distributed/ray_actor_environment_registry.py +++ b/nemo_rl/distributed/ray_actor_environment_registry.py @@ -16,7 +16,9 @@ ACTOR_ENVIRONMENT_REGISTRY: dict[str, str] = { "nemo_rl.models.generation.vllm.VllmGenerationWorker": PY_EXECUTABLES.VLLM, - "nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker": PY_EXECUTABLES.BASE, + # Temporary workaround for the coupled implementation of DTensorPolicyWorker and vLLM. + # This will be reverted to PY_EXECUTABLES.BASE once https://github.com/NVIDIA/NeMo-RL/issues/501 is resolved. + "nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker": PY_EXECUTABLES.VLLM, "nemo_rl.models.policy.fsdp1_policy_worker.FSDP1PolicyWorker": PY_EXECUTABLES.BASE, "nemo_rl.environments.math_environment.MathEnvironment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.games.sliding_puzzle.SlidingPuzzleEnv": PY_EXECUTABLES.SYSTEM, From 21b61938fc76ff71df833f6dcc0b383d856781d5 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Thu, 12 Jun 2025 03:49:12 +0000 Subject: [PATCH 8/8] add functional test Signed-off-by: Yuki Huang --- .github/workflows/cicd-main.yml | 1 + tests/functional/grpo_non_colocated.sh | 42 ++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100755 tests/functional/grpo_non_colocated.sh diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index e466590504..5b8939184b 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -188,6 +188,7 @@ jobs: time uv run --no-sync bash ./tests/functional/sft.sh time uv run --no-sync bash ./tests/functional/grpo.sh time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh + time uv run --no-sync bash ./tests/functional/grpo_non_colocated.sh time uv run --no-sync bash ./tests/functional/dpo.sh time uv run --no-sync bash ./tests/functional/eval.sh time uv run --no-sync bash ./tests/functional/test_mcore_extra_installed_correctly.sh diff --git a/tests/functional/grpo_non_colocated.sh b/tests/functional/grpo_non_colocated.sh new file mode 100755 index 0000000000..2067779fd4 --- /dev/null +++ b/tests/functional/grpo_non_colocated.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run $PROJECT_ROOT/examples/run_grpo_math.py \ + policy.model_name=Qwen/Qwen3-0.6B \ + grpo.num_prompts_per_step=2 \ + grpo.num_generations_per_prompt=4 \ + policy.train_global_batch_size=4 \ + policy.train_micro_batch_size=1 \ + policy.generation.colocated.enabled=false \ + policy.generation.colocated.resources.gpus_per_node=1 \ + cluster.gpus_per_node=2 \ + grpo.max_num_steps=2 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'max(data["train/token_mult_prob_error"]) < 1.05' \ +