diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index e466590504..5b8939184b 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -188,6 +188,7 @@ jobs: time uv run --no-sync bash ./tests/functional/sft.sh time uv run --no-sync bash ./tests/functional/grpo.sh time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh + time uv run --no-sync bash ./tests/functional/grpo_non_colocated.sh time uv run --no-sync bash ./tests/functional/dpo.sh time uv run --no-sync bash ./tests/functional/eval.sh time uv run --no-sync bash ./tests/functional/test_mcore_extra_installed_correctly.sh diff --git a/examples/configs/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/grpo-deepscaler-1.5b-8K.yaml index eef1c4e205..67086f4929 100644 --- a/examples/configs/grpo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/grpo-deepscaler-1.5b-8K.yaml @@ -101,6 +101,14 @@ policy: # For most cases, use "dummy" to load the initial weights, since they will be overwritten during refit # For Gemma models, we need to use "auto" due to a vllm bug load_format: dummy + colocated: + # true: generation shares training GPUs + # false: uses dedicated generation resources + enabled: true + # only relevant when enabled is false + resources: + gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 + num_nodes: null # Decides number of nodes to be dedicated to generation data: max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 2f9b6535e4..33ce350028 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -106,6 +106,14 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: ${policy.max_total_sequence_length} + colocated: + # true: generation shares training GPUs + # false: uses dedicated generation resources + enabled: true + # only relevant when enabled is false + resources: + gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 + num_nodes: null # Decides number of nodes to be dedicated to generation data: max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len diff --git a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml index a737c5e35e..f7079fe4a8 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml @@ -88,6 +88,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 512 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 512 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml index ac0627f487..9b4d7ffc84 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml @@ -89,6 +89,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 16384 + colocated: + 
enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 16384 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml index af375d4f14..65a909ce80 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml @@ -89,6 +89,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 4096 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 4096 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml index 90cbcdc8ff..5e2a98c022 100644 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -89,6 +89,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 512 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 512 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml index 6ec2201eaf..5b53de160a 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v3.yaml @@ -89,6 +89,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 16384 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 16384 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml index 01fbb27245..382d8b80f3 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v3.yaml @@ -89,6 +89,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 16384 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 16384 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml index 63dc2a85c3..0db76b931d 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v3.yaml @@ -86,6 +86,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 4096 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 4096 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml index 05ff7233a6..6196a7f3b4 100644 --- 
a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml @@ -89,6 +89,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 4096 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 4096 prompt_file: examples/prompts/cot.txt diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml index ecd3f1d417..38d774bd76 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -89,6 +89,11 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 512 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: max_input_seq_length: 512 prompt_file: examples/prompts/cot.txt diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index a8428bdce2..d3116bf2ef 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -13,9 +13,10 @@ # limitations under the License. import os from pathlib import Path -from typing import Any, Optional, TypedDict, cast +from typing import Any, Optional, Tuple, TypedDict, cast import numpy as np +import ray import torch from torchdata.stateful_dataloader import StatefulDataLoader from transformers import PreTrainedTokenizerBase @@ -37,7 +38,10 @@ get_keys_from_message_log, ) from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster +from nemo_rl.distributed.virtual_cluster import ( + ClusterConfig, + RayVirtualCluster, +) from nemo_rl.environments.interfaces import ( EnvironmentInterface, ) @@ -119,7 +123,7 @@ def setup( ) -> tuple[ ColocatablePolicyInterface, Optional[GenerationInterface], - RayVirtualCluster, + Tuple[RayVirtualCluster, RayVirtualCluster], StatefulDataLoader, Optional[StatefulDataLoader], ClippedPGLossFn, @@ -137,7 +141,6 @@ def setup( policy_config = master_config["policy"] generation_config = master_config["policy"]["generation"] loss_config = master_config["loss_fn"] - data_config = master_config["data"] grpo_config = master_config["grpo"] logger_config = master_config["logger"] cluster_config = master_config["cluster"] @@ -212,16 +215,91 @@ def setup( # Cluster # ========================== print("\n▶ Setting up compute cluster...") - colocated_inference = generation_config["backend"] != "hf" - cluster = RayVirtualCluster( - name="grpo_policy_cluster", - bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] - * cluster_config["num_nodes"], - use_gpus=True, - num_gpus_per_node=cluster_config["gpus_per_node"], - max_colocated_worker_groups=2 if colocated_inference else 1, - ) - print(f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes") + colocated_inference = generation_config["colocated"]["enabled"] + + if colocated_inference: + cluster = RayVirtualCluster( + name="grpo_policy_cluster", + bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] + * cluster_config["num_nodes"], + use_gpus=True, + num_gpus_per_node=cluster_config["gpus_per_node"], + max_colocated_worker_groups=1 + if generation_config["backend"] == "hf" + else 2, + ) + train_cluster = cluster + inference_cluster = cluster + print(f" ✓ Ray cluster initialized 
with {cluster_config['num_nodes']} nodes") + + else: + assert generation_config["backend"] != "hf", ( + "Non-colocated inference is not supported for HF generation backend. " + "Please use vLLM backend for generation." + ) + + # train resources will be updated through overall and inference resources below + train_gpus_per_node = cluster_config["gpus_per_node"] + train_nodes = cluster_config["num_nodes"] + + inference_resources = generation_config["colocated"]["resources"] + inference_gpus_per_node = inference_resources["gpus_per_node"] + inference_nodes = inference_resources["num_nodes"] + + # validate and configure resources + if cluster_config["num_nodes"] == 1: + assert inference_gpus_per_node > 0, ( + "policy.generation.colocated.resources.gpus_per_node must be > 0 " + "when cluster.num_nodes = 1 and inference is non-colocated, " + f"but got {inference_gpus_per_node}." + ) + assert inference_nodes is None or inference_nodes == 1, ( + "policy.generation.colocated.resources.num_nodes must be 1 or set to null " + "when cluster.num_nodes = 1 and inference is non-colocated, " + f"but got {inference_nodes}." + ) + inference_nodes = 1 + train_gpus_per_node -= inference_gpus_per_node + else: + assert inference_nodes > 0, ( + "policy.generation.colocated.resources.num_nodes must be > 0 " + "when cluster.num_nodes > 1 and inference is non-colocated, " + f"but got {inference_nodes}." + ) + assert ( + inference_gpus_per_node is None + or inference_gpus_per_node == cluster_config["gpus_per_node"] + ), ( + "policy.generation.colocated.resources.gpus_per_node must be equal to cluster.gpus_per_node or set to null " + "when cluster.num_nodes > 1 and inference is non-colocated, " + f"but got {inference_gpus_per_node}." + ) + inference_gpus_per_node = cluster_config["gpus_per_node"] + train_nodes -= inference_nodes + + # initialize train cluster + train_cluster = RayVirtualCluster( + name="grpo_train_cluster", + bundle_ct_per_node_list=[train_gpus_per_node] * train_nodes, + use_gpus=True, + num_gpus_per_node=train_gpus_per_node, + max_colocated_worker_groups=1, + ) + print( + f" ✓ Ray train cluster initialized with {train_nodes} nodes with {train_gpus_per_node} GPUs per node" + ) + + # initialize inference cluster + inference_cluster = RayVirtualCluster( + name="grpo_inference_cluster", + bundle_ct_per_node_list=[inference_gpus_per_node] * inference_nodes, + use_gpus=True, + num_gpus_per_node=inference_gpus_per_node, + max_colocated_worker_groups=1, + ) + print( + f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node" + ) # ========================== # Training and Inference @@ -237,7 +315,9 @@ def setup( print(f" ✓ Using HF backend for generation with {policy_config['model_name']}") elif backend == "vllm": generation_config = cast(VllmConfig, generation_config) - policy_generation = VllmGeneration(cluster=cluster, config=generation_config) + policy_generation = VllmGeneration( + cluster=inference_cluster, config=generation_config + ) # Worker groups are not initialized until the first call to run something on workergroups. # vllm 0.8 fails in initialization if its called in the first training step since it has no clean view of the GPU memory (HF is sharing the same memory). 
policy_generation.finish_generation() @@ -246,7 +326,7 @@ def setup( ) policy = HfPolicy( - cluster=cluster, + cluster=train_cluster, config=policy_config, tokenizer=tokenizer, weights_path=Path(last_checkpoint_path) / "policy" / "weights" @@ -258,6 +338,18 @@ def setup( init_optimizer=True, ) + # if it is not colocated inference, initialize collective communication for update weights + if not colocated_inference: + ip, port = train_cluster.get_master_address_and_port() + print(f"Using ip: {ip}, port: {port} for collective communication") + # inference cluster + head node of the train cluster + world_size = inference_nodes * inference_gpus_per_node + 1 + # init collective + futures_train = policy.init_collective(ip, port, world_size) + futures_inference = policy_generation.init_collective(ip, port, world_size) # type: ignore + # wait for all futures to complete + ray.get(futures_train + futures_inference) + loss_fn = ClippedPGLossFn(loss_config) print("\n" + "=" * 60) @@ -267,7 +359,7 @@ def setup( return ( policy, policy_generation, - cluster, + (train_cluster, inference_cluster), dataloader, val_dataloader, loss_fn, @@ -286,6 +378,7 @@ def setup( def refit_policy_generation( policy: ColocatablePolicyInterface, policy_generation: GenerationInterface, + colocated_inference: bool, _refit_buffer_size_gb: Optional[int] = None, ) -> None: """Refit the policy generation interface with the latest policy weights. @@ -297,24 +390,49 @@ def refit_policy_generation( If it is None, the buffer size will be computed by the remaining memory. This parameter is primarily used for testing. """ - policy.offload_before_refit() - policy_generation.prepare_for_generation(tags=["weights"]) - # get model param keys, which is grouped by size - grouped_param_keys = policy.prepare_weights_for_ipc( - _refit_buffer_size_gb=_refit_buffer_size_gb - ) - # do update - for keys in grouped_param_keys: - ipc_handles = policy.get_weights_ipc_handles(keys) - if not policy_generation.update_weights(ipc_handles): - error_message = ( - "❌ Error: Updating weights for the generation policy failed during refit.\n" - "This often indicates an issue with cuda-ipc or " - "a problem within the generation backend (e.g., vLLM worker).\n" - ) - raise RuntimeError(error_message) - policy.offload_after_refit() - policy_generation.prepare_for_generation(tags=["kv_cache"]) + if colocated_inference: + policy.offload_before_refit() + policy_generation.prepare_for_generation(tags=["weights"]) + + # update weights + update_success = False + if colocated_inference: + # get model param keys, which is grouped by size + grouped_param_keys = policy.prepare_weights_for_ipc( + _refit_buffer_size_gb=_refit_buffer_size_gb + ) + # do update + for keys in grouped_param_keys: + ipc_handles = policy.get_weights_ipc_handles(keys) + update_success = policy_generation.update_weights(ipc_handles) + if not update_success: + break + else: + # prepare info for update weights + state_dict_info = policy.prepare_info_for_collective() + # update weights through nccl + futures_train = policy.broadcast_weights_for_collective() + futures_inference = policy_generation.update_weights_from_collective( + state_dict_info + ) + # wait for all futures to complete + ray.get(futures_train) + results = ray.get(futures_inference) + update_success = all(result for result in results if result is not None) + + # check if update is successful + if not update_success: + error_tag = "cuda-ipc" if colocated_inference else "nccl" + error_message = ( + "❌ Error: Updating weights for the 
generation policy failed during refit.\n" + f"This often indicates an issue with {error_tag} or " + "a problem within the generation backend (e.g., vLLM worker).\n" + ) + raise RuntimeError(error_message) + + if colocated_inference: + policy.offload_after_refit() + policy_generation.prepare_for_generation(tags=["kv_cache"]) # =============================================================================== @@ -351,12 +469,13 @@ def grpo_train( consumed_samples = grpo_save_state["consumed_samples"] val_period = master_config["grpo"]["val_period"] val_at_start = master_config["grpo"]["val_at_start"] + colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] # Run validation at the start if configured if val_at_start and step == 0: print("\n🔍 Running initial validation...") if NEED_REFIT and POLICY_GENERATION_STALE: - refit_policy_generation(policy, policy_generation) + refit_policy_generation(policy, policy_generation, colocated_inference) POLICY_GENERATION_STALE = False else: policy_generation.prepare_for_generation() @@ -399,7 +518,9 @@ def grpo_train( print(f"▶ Generating responses for batch of size {repeated_batch.size}...") with timer.time("prepare_for_generation"): if NEED_REFIT and POLICY_GENERATION_STALE: - refit_policy_generation(policy, policy_generation) + refit_policy_generation( + policy, policy_generation, colocated_inference + ) POLICY_GENERATION_STALE = False else: policy_generation.prepare_for_generation() @@ -511,7 +632,9 @@ def grpo_train( # Run validation if it's a validation step if is_last_step or (val_period > 0 and (step + 1) % val_period == 0): if NEED_REFIT and POLICY_GENERATION_STALE: - refit_policy_generation(policy, policy_generation) + refit_policy_generation( + policy, policy_generation, colocated_inference + ) POLICY_GENERATION_STALE = False else: policy_generation.prepare_for_generation() diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py index 990ac46a16..697ef88826 100644 --- a/nemo_rl/distributed/ray_actor_environment_registry.py +++ b/nemo_rl/distributed/ray_actor_environment_registry.py @@ -16,7 +16,9 @@ ACTOR_ENVIRONMENT_REGISTRY: dict[str, str] = { "nemo_rl.models.generation.vllm.VllmGenerationWorker": PY_EXECUTABLES.VLLM, - "nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker": PY_EXECUTABLES.BASE, + # Temporary workaround for the coupled implementation of DTensorPolicyWorker and vLLM. + # This will be reverted to PY_EXECUTABLES.BASE once https://github.com/NVIDIA/NeMo-RL/issues/501 is resolved. 
+ "nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker": PY_EXECUTABLES.VLLM, "nemo_rl.models.policy.fsdp1_policy_worker.FSDP1PolicyWorker": PY_EXECUTABLES.BASE, "nemo_rl.environments.math_environment.MathEnvironment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.games.sliding_puzzle.SlidingPuzzleEnv": PY_EXECUTABLES.SYSTEM, diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index 8e7a0c1a7e..c670ec16b2 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -14,6 +14,7 @@ from abc import ABC, abstractmethod from typing import Any, NotRequired, Optional, TypedDict, Union +import ray import torch from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -196,6 +197,13 @@ class GenerationOutputSpec(TypedDict): class GenerationInterface(ABC): """Abstract base class defining the interface for RL policies.""" + @abstractmethod + def init_collective( + self, ip: str, port: int, world_size: int + ) -> list[ray.ObjectRef]: + """Initialize the collective communication.""" + pass + @abstractmethod def generate( self, data: BatchedDataDict["GenerationDatumSpec"], greedy: bool @@ -213,3 +221,9 @@ def finish_generation(self, *args: Any, **kwargs: Any) -> bool: def update_weights(self, ipc_handles: dict[str, Any]) -> bool: """Update the model weights from the given IPC handles.""" raise NotImplementedError + + def update_weights_from_collective( + self, info: dict[str, Any] + ) -> list[ray.ObjectRef]: + """Update the model weights from collective communication.""" + raise NotImplementedError diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index 0287fb8b63..e620b3bd92 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -342,6 +342,19 @@ def _patch_vllm_init_workers_ray(): else: self.llm = vllm.LLM(**llm_kwargs) + def init_collective( + self, rank_prefix: int, ip: str, port: int, world_size: int + ) -> None: + self.llm.collective_rpc( + "init_collective", + args=( + rank_prefix, + ip, + port, + world_size, + ), + ) + def llm(self): return self.llm @@ -925,6 +938,42 @@ async def update_weights_from_ipc_handles_async( traceback.print_exc() return False + def update_weights_from_collective(self, info: dict[str, Any]) -> bool: + """Update the model weights from collective communication.""" + try: + assert self.llm is not None, ( + "Attempting to update weights with either an uninitialized vLLM or non-model-owner" + ) + + if self.cfg["vllm_cfg"]["async_engine"]: + raise RuntimeError( + "update_weights_from_collective cannot be used with async_engine=True. Use update_weights_from_ipc_handles_async instead." + ) + + result_or_coro = self.llm.collective_rpc( + "update_weights_from_collective", args=(info,) + ) + worker_result = result_or_coro[0] + + if not worker_result: + print( + f"Error: Worker failed to update weights. 
Result: {worker_result}" + ) + return False + return True + except Exception as e: + print(f"Exception during collective_rpc for weight update: {e}") + import traceback + + traceback.print_exc() + return False + + def reset_prefix_cache(self): + """Reset the prefix cache of the vLLM engine.""" + self.llm.llm_engine.reset_prefix_cache() + gc.collect() + torch.cuda.empty_cache() + def sleep(self): """Put the vLLM engine to sleep.""" assert self.llm is not None, ( @@ -1206,6 +1255,29 @@ def _report_device_id(self) -> list[list[str]]: results = ray.get(futures) return results + def init_collective( + self, ip: str, port: int, world_size: int + ) -> list[ray.ObjectRef]: + """Initialize the collective communication.""" + if not self.worker_group or not self.worker_group.workers: + raise RuntimeError("Worker group is not initialized") + + # Prepare rank prefixes, one per tied worker group + total_workers = len(self.worker_group.workers) + workers_per_group = len(self.worker_group.tied_workers_groups[0]) + rank_prefix_list = list(range(0, total_workers, workers_per_group)) + + # Send world_size and rank prefix for collective init to all workers + futures = self.worker_group.run_all_workers_multiple_data( + "init_collective", + data=rank_prefix_list, + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], + common_kwargs={"ip": ip, "port": port, "world_size": world_size}, + ) + + # this function works in tandem with hf_policy; the caller is responsible for waiting on all futures + return futures + def generate( self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False ) -> BatchedDataDict[GenerationOutputSpec]: @@ -1349,7 +1421,11 @@ def generate_async( return combined def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool: - """Wake workers up.""" + """Wake workers up for colocated inference.""" + # non-colocated inference has no need to wake workers up + if not self.cfg["colocated"]["enabled"]: + return True + try: # Choose the appropriate method based on async_engine setting method_name = ( @@ -1369,12 +1445,16 @@ def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool: return False def finish_generation(self, *args: Any, **kwargs: Any) -> bool: - """Sleep workers.""" + """Sleep workers (colocated) or reset the prefix cache (non-colocated).""" + try: - # Choose the appropriate method based on async_engine setting - method_name = ( - "sleep_async" if self.cfg["vllm_cfg"]["async_engine"] else "sleep" - ) + # Choose the appropriate method based on the colocation setting + # non-colocated inference only needs to reset the prefix cache; there is no need to sleep.
+ if not self.cfg["colocated"]["enabled"]: + method_name = "reset_prefix_cache" + else: + method_name = ( + "sleep_async" if self.cfg["vllm_cfg"]["async_engine"] else "sleep" + ) # Use run_all_workers_single_data for methods that don't need data futures = self.worker_group.run_all_workers_single_data( method_name, @@ -1437,9 +1517,26 @@ def update_weights(self, ipc_handles: dict[str, Any]) -> bool: results = ray.get(futures) return all(result for result in results if result is not None) except Exception as e: - print(f"Error updating weights: {e}") + print(f"Error during weight update: {e}") return False + def update_weights_from_collective( + self, info: dict[str, Any] + ) -> list[ray.ObjectRef]: + """Update weights of the policy using collective communication.""" + if not self.worker_group or not self.worker_group.workers: + raise RuntimeError("Worker group is not initialized") + + # Use run_all_workers_single_data to send data to all workers + futures = self.worker_group.run_all_workers_single_data( + "update_weights_from_collective", + info=info, + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], + ) + + # this function works in tandem with hf_policy; the caller is responsible for waiting on all futures + return futures + def __del__(self) -> None: """Shuts down the worker groups when the object is deleted or is garbage collected. diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index 5305fd6bce..c40aea4418 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
+from typing import Any + import torch try: @@ -24,6 +26,21 @@ class VllmInternalWorkerExtension: + def init_collective( + self, rank_prefix: int, ip: str, port: int, world_size: int + ) -> None: + """Initialize the collective communication.""" + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + + local_rank = torch.distributed.get_rank() + rank = rank_prefix + local_rank + 1 # 1 is the head node of the train cluster + + pg = StatelessProcessGroup.create( + host=ip, port=port, rank=rank, world_size=world_size + ) + self.model_update_group = PyNcclCommunicator(pg, device=self.device) + def report_device_id(self) -> str: from nemo_rl.utils.nvml import get_device_uuid @@ -63,3 +80,18 @@ def update_weights_from_ipc_handles(self, ipc_handles): f"Error in VllmInternalWorkerExtension.update_weights_from_ipc_handles: {e}" ) return False + + def update_weights_from_collective(self, info: dict[str, Any]) -> bool: + """Update the model weights from collective communication.""" + try: + for name, (shape, dtype) in info.items(): + weight = torch.empty(shape, dtype=dtype, device="cuda") + self.model_update_group.broadcast(weight, src=0) + self.model_runner.model.load_weights(weights=[(name, weight)]) + except Exception as e: + print( + f"Error in VllmInternalWorkerExtension.update_weights_from_collective: {e}" + ) + return False + + return True diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index d13fc23c67..673cf3df20 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -128,7 +128,7 @@ def __init__( self.cfg = config # torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call torch.distributed.init_process_group(backend="nccl") - rank = torch.distributed.get_rank() + self.rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() model_name = self.cfg["model_name"] @@ -144,7 +144,7 @@ def __init__( else: raise ValueError(f"Unknown precision: {self.cfg['precision']}") - print(f"[Rank {rank}] Loading model {model_name} on CPU...") + print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") self.model = AutoModelForCausalLM.from_pretrained( model_name, device_map="cpu", # load weights onto CPU initially @@ -262,6 +262,23 @@ def __init__( "No weights path provided. Starting from scratch (default policy init)" ) + def init_collective(self, ip: str, port: int, world_size: int) -> None: + """Initialize the collective communication.""" + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + + # keep the same behavior as vllm + # see https://github.com/vllm-project/vllm/blob/v0.8.5/vllm/env_override.py#L25 + if not os.path.exists("/dev/nvidia-caps-imex-channels"): + os.environ["NCCL_CUMEM_ENABLE"] = "0" + + if self.rank == 0: + pg = StatelessProcessGroup.create( + host=ip, port=port, rank=0, world_size=world_size + ) + device = torch.cuda.current_device() + self.model_update_group = PyNcclCommunicator(pg, device=device) + def is_alive(self) -> bool: return True @@ -753,6 +770,34 @@ def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]: return {device_uuid: all_handles} + @torch.no_grad() + def prepare_info_for_collective(self) -> dict[str, Any]: + """Prepare the info for collective communication. 
+ + Returns: + dict: A dictionary containing the info for collective communication. + """ + # Get state_dict + self.model = self.move_to_cuda(self.model) + state_dict = self.model.state_dict() + + # Collect info for collective communication + state_dict_info = {} + for name, tensor in state_dict.items(): + state_dict_info[name] = (tensor.shape, self.dtype) + + return state_dict_info + + @torch.no_grad() + def broadcast_weights_for_collective(self) -> None: + """Broadcast the weights for collective communication.""" + for _, tensor in self.model.state_dict().items(): + if isinstance(tensor, DTensor): + tensor = tensor.full_tensor() + if self.rank == 0: + tensor = tensor.to(self.dtype, non_blocking=True) + self.model_update_group.broadcast(tensor.data, src=0) + def prepare_for_lp_inference(self) -> None: if not self.cpu_offload: self.move_to_cuda(self.model) diff --git a/nemo_rl/models/policy/hf_policy.py b/nemo_rl/models/policy/hf_policy.py index 8cfe0a6e0c..b1928c2fc5 100644 --- a/nemo_rl/models/policy/hf_policy.py +++ b/nemo_rl/models/policy/hf_policy.py @@ -119,6 +119,16 @@ def __init__( self.cfg = config + def init_collective( + self, ip: str, port: int, world_size: int + ) -> list[ray.ObjectRef]: + """Initialize the collective communication.""" + futures = self.worker_group.run_all_workers_single_data( + "init_collective", ip=ip, port=port, world_size=world_size + ) + # this function should co-work with vllm, so we should wait for all futures to complete outside + return futures + def get_logprobs( self, data: BatchedDataDict[GenerationDatumSpec] ) -> BatchedDataDict[LogprobOutputSpec]: @@ -402,6 +412,27 @@ def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: return all_handles + def prepare_info_for_collective(self) -> dict[str, Any]: + """Prepare the info for collective communication. + + Returns: + dict: A dictionary containing the info for collective communication. 
+ """ + futures = self.worker_group.run_all_workers_single_data( + "prepare_info_for_collective" + ) + results = ray.get(futures) + # Only get the first worker's info since all workers will have the same result + return results[0] + + def broadcast_weights_for_collective(self) -> list[ray.ObjectRef]: + """Broadcast the weights for collective communication.""" + futures = self.worker_group.run_all_workers_single_data( + "broadcast_weights_for_collective" + ) + # this function should co-work with vllm, so we should wait for all futures to complete outside + return futures + def offload_before_refit(self) -> None: """Offload the optimizer and buffers to the CPU.""" futures = self.worker_group.run_all_workers_single_data("offload_before_refit") diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index d4b51b0cb3..614340c67b 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -14,6 +14,7 @@ from abc import ABC, abstractmethod from typing import Any, TypedDict +import ray import torch from nemo_rl.algorithms.interfaces import LossFunction @@ -93,6 +94,12 @@ def shutdown(self) -> bool: class ColocatablePolicyInterface(PolicyInterface): + @abstractmethod + def init_collective( + self, ip: str, port: int, world_size: int + ) -> list[ray.ObjectRef]: + pass + @abstractmethod def offload_before_refit(self) -> None: pass @@ -108,3 +115,11 @@ def prepare_weights_for_ipc(self, *args: Any, **kwargs: Any) -> list[list[str]]: @abstractmethod def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: pass + + @abstractmethod + def prepare_info_for_collective(self) -> dict[str, Any]: + pass + + @abstractmethod + def broadcast_weights_for_collective(self) -> list[ray.ObjectRef]: + pass diff --git a/tests/functional/grpo_non_colocated.sh b/tests/functional/grpo_non_colocated.sh new file mode 100755 index 0000000000..2067779fd4 --- /dev/null +++ b/tests/functional/grpo_non_colocated.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) 
+# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run $PROJECT_ROOT/examples/run_grpo_math.py \ + policy.model_name=Qwen/Qwen3-0.6B \ + grpo.num_prompts_per_step=2 \ + grpo.num_generations_per_prompt=4 \ + policy.train_global_batch_size=4 \ + policy.train_micro_batch_size=1 \ + policy.generation.colocated.enabled=false \ + policy.generation.colocated.resources.gpus_per_node=1 \ + cluster.gpus_per_node=2 \ + grpo.max_num_steps=2 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'max(data["train/token_mult_prob_error"]) < 1.05' \ + diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py index 2898b262ae..7516b0ba6e 100644 --- a/tests/unit/experience/test_rollouts.py +++ b/tests/unit/experience/test_rollouts.py @@ -230,6 +230,13 @@ def initial_multi_step_calculator_batch(rollout_tokenizer): "disable_log_requests": True, "gpu_memory_utilization": 0.6, }, + "colocated": { + "enabled": True, + "resources": { + "gpus_per_node": None, + "num_nodes": None, + }, + }, } diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index b8225e62ee..b00b3fb6d5 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -14,7 +14,6 @@ import os from copy import deepcopy -from unittest import mock import pytest import ray @@ -23,7 +22,10 @@ from nemo_rl.algorithms.grpo import refit_policy_generation from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.distributed.virtual_cluster import ( + RayVirtualCluster, + _get_node_ip_and_free_port, +) from nemo_rl.models.generation import configure_generation_config from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration from nemo_rl.models.policy import PolicyConfig @@ -54,6 +56,13 @@ "skip_tokenizer_init": False, "load_format": "auto", }, + "colocated": { + "enabled": True, + "resources": { + "gpus_per_node": None, + "num_nodes": None, + }, + }, "vllm_kwargs": {}, } @@ -100,9 +109,7 @@ def get_basic_hf_test_config(enable_dtensor: bool = False) -> PolicyConfig: }, "max_grad_norm": 1.0, "make_sequence_length_divisible_by": 1, - "generation": { - "temperature": 0.8, - }, + "generation": basic_vllm_test_config, } @@ -311,7 +318,9 @@ async def test_vllm_policy_generation_async( print("creating hf policy...") hf_policy = HfPolicy(cluster, hf_config, tokenizer) - refit_policy_generation(hf_policy, async_policy) + refit_policy_generation( + hf_policy, async_policy, vllm_config["colocated"]["enabled"] + ) outputs = async_policy.generate_async(test_input_data) # Validate outputs format @@ -406,7 +415,7 @@ def test_vllm_worker_seed_behavior(cluster, tokenizer): hf_policy = HfPolicy(cluster, hf_config, tokenizer) print("refitting vllm policy...") - 
refit_policy_generation(hf_policy, policy) + refit_policy_generation(hf_policy, policy, vllm_config["colocated"]["enabled"]) try: # Generate with duplicated prompts @@ -562,7 +571,9 @@ def test_vllm_generation_with_hf_training( hf_policy = HfPolicy(cluster, hf_config, tokenizer) print("refitting vllm policy...") - refit_policy_generation(hf_policy, vllm_policy) + refit_policy_generation( + hf_policy, vllm_policy, vllm_config["colocated"]["enabled"] + ) # Step 1: Use vLLM for generation print("Using vLLM policy for fast generation...") @@ -914,7 +925,12 @@ def test_vllm_weight_update_memory(cluster, tokenizer, enable_dtensor): # reset peak memory stats before refit workers = hf_policy.worker_group.workers ray.get([w.reset_peak_memory_stats.remote() for w in workers]) - refit_policy_generation(hf_policy, vllm_policy, _refit_buffer_size_gb=1) + refit_policy_generation( + hf_policy, + vllm_policy, + vllm_config["colocated"]["enabled"], + _refit_buffer_size_gb=1, + ) gpu_infos = ray.get([w.get_gpu_info.remote() for w in workers]) # Gather memory stats @@ -981,7 +997,9 @@ def test_vllm_generation_with_stop( hf_policy = HfPolicy(cluster, hf_config, tokenizer) print("refitting vllm policy...") - refit_policy_generation(hf_policy, vllm_generation) + refit_policy_generation( + hf_policy, vllm_generation, vllm_config["colocated"]["enabled"] + ) # test generate outputs = vllm_generation.generate(test_input_data, greedy=True) @@ -1040,7 +1058,7 @@ def test_vllm_non_divisible_batch_handling(policy): ) -def test_vllm_refit_non_collocated_handles_update_failure( +def test_vllm_refit_non_collocated_handles_update( policy_cluster_separate, generation_cluster_separate, tokenizer, @@ -1057,51 +1075,36 @@ def test_vllm_refit_non_collocated_handles_update_failure( # Create HfPolicy on its own cluster hf_config = get_basic_hf_test_config(enable_dtensor=True) hf_config["dtensor_cfg"]["tensor_parallel_size"] = 1 + hf_config["generation"]["colocated"]["enabled"] = False hf_policy = HfPolicy(policy_cluster_separate, hf_config, tokenizer) # Create VllmGeneration policy on its own cluster vllm_config = deepcopy(basic_vllm_test_config) vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True) vllm_config["vllm_cfg"]["tensor_parallel_size"] = 1 - vllm_policy = VllmGeneration(generation_cluster_separate, vllm_config) + vllm_config["colocated"]["enabled"] = False + vllm_generation = VllmGeneration(generation_cluster_separate, vllm_config) - hf_policy_instance = None - vllm_policy_instance = None + # initialize collective communication for update weights + ip, port = ray.get(_get_node_ip_and_free_port.remote()) + futures_train = hf_policy.init_collective(ip, port, world_size=2) + futures_inference = vllm_generation.init_collective(ip, port, world_size=2) + ray.get(futures_train + futures_inference) - try: - hf_policy_instance = hf_policy - vllm_policy_instance = vllm_policy - ray.get( - [ - worker._add_noise_to_weights.remote() - for worker in hf_policy_instance.worker_group.workers - ] - ) - print("Refitting vLLM policy from HF policy (non-collocated)") - with mock.patch.object( - vllm_policy_instance, "update_weights", return_value=False - ): - with pytest.raises(RuntimeError): - refit_policy_generation( - hf_policy_instance, - vllm_policy_instance, - ) - print("RuntimeError during refit correctly caught.") + print("refitting vllm policy...") + refit_policy_generation( + hf_policy, vllm_generation, vllm_config["colocated"]["enabled"] + ) - finally: - print("Cleaning up non-collocated test 
resources...") - if hf_policy_instance: - try: - hf_policy_instance.shutdown() - except Exception as e: - print(f"Error during HfPolicy cleanup: {e}") - if vllm_policy_instance: - try: - vllm_policy_instance.shutdown() - except Exception as e: - print(f"Error during VllmPolicy cleanup: {e}") - # Force garbage collection - import gc + # test generate + outputs = vllm_generation.generate(test_input_data, greedy=True) + output_ids = outputs["output_ids"] + generated_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + assert generated_texts == [ + "Hello, my name is Lina. I'm", + "The capital of France is Paris. The capital of", + ], "Output should be the same as the expected output" - gc.collect() - torch.cuda.empty_cache() + # Clean up + vllm_generation.shutdown() + hf_policy.shutdown() diff --git a/tests/unit/models/generation/test_vllm_large_model.py b/tests/unit/models/generation/test_vllm_large_model.py index 2a860d7e6f..97bc9dc66e 100644 --- a/tests/unit/models/generation/test_vllm_large_model.py +++ b/tests/unit/models/generation/test_vllm_large_model.py @@ -51,6 +51,13 @@ "skip_tokenizer_init": False, "load_format": "auto", }, + "colocated": { + "enabled": True, + "resources": { + "gpus_per_node": None, + "num_nodes": None, + }, + }, "vllm_kwargs": {}, } diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py index 68a1f3e217..f77a14914f 100755 --- a/tests/unit/utils/test_native_checkpoint.py +++ b/tests/unit/utils/test_native_checkpoint.py @@ -66,7 +66,9 @@ }, "max_grad_norm": 1.0, "generation": { + "backend": "vllm", "temperature": 1.0, + "colocated": {"enabled": True}, }, }