From 00bc8c3d39d6d934d6008063832f0cf81a7f4d03 Mon Sep 17 00:00:00 2001
From: Yuki Huang
Date: Thu, 3 Jul 2025 08:06:41 +0000
Subject: [PATCH 1/3] fix env vars of vllm worker

Signed-off-by: Yuki Huang
---
 nemo_rl/distributed/worker_groups.py       | 19 ++++++++++++++-----
 nemo_rl/models/generation/vllm.py          | 12 ++++++++----
 .../models/policy/dtensor_policy_worker.py |  2 +-
 3 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/nemo_rl/distributed/worker_groups.py b/nemo_rl/distributed/worker_groups.py
index c2e849cbee..b008452f1c 100644
--- a/nemo_rl/distributed/worker_groups.py
+++ b/nemo_rl/distributed/worker_groups.py
@@ -317,6 +317,7 @@ def __init__(
         name_prefix: str = "",
         bundle_indices_list: Optional[list[tuple[int, list[int]]]] = None,
         sharding_annotations: Optional[NamedSharding] = None,
+        env_vars: dict[str, str] = {},
     ):
         """Initialize a group of distributed Ray workers.

@@ -391,7 +392,7 @@ def __init__(

         # Create workers based on the bundle_indices_list
         self._create_workers_from_bundle_indices(
-            remote_worker_builder, bundle_indices_list
+            remote_worker_builder, bundle_indices_list, env_vars=env_vars
         )

     def get_dp_leader_worker_idx(self, dp_shard_idx: int) -> int:
@@ -407,6 +408,7 @@ def _create_workers_from_bundle_indices(
         self,
         remote_worker_builder: RayWorkerBuilder,
         bundle_indices_list: list[tuple[int, list[int]]],
+        env_vars: dict[str, str] = {},
     ) -> None:
         """Create workers based on explicit bundle indices for tied worker groups.

@@ -421,6 +423,10 @@ def _create_workers_from_bundle_indices(
             self.cluster.get_master_address_and_port()
         )

+        # Merge the current process environment into env_vars (os.environ entries take precedence)
+        env_vars.update(dict(os.environ))
+
+        # Get the python environment for the actor
         actor_python_env = get_actor_python_env(
             remote_worker_builder.ray_actor_class_fqn
         )
@@ -459,8 +465,8 @@ def _create_workers_from_bundle_indices(

             for local_rank, bundle_idx in enumerate(local_bundle_indices):
                 # Set up basic distributed environment variables
-                env_vars = dict(os.environ)
-                env_vars.update(
+                worker_env_vars = deepcopy(env_vars)
+                worker_env_vars.update(
                     {
                         "RANK": str(global_rank),
                         "LOCAL_RANK": str(bundle_idx),
@@ -470,7 +476,7 @@ def _create_workers_from_bundle_indices(
                         "NODE_RANK": str(pg_idx),
                     }
                 )
-                env_vars.pop("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", None)
+                worker_env_vars.pop("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", None)

                 # Only the first worker in each group gets bundle_indices
                 # This ensures only one worker per group is the model owner
@@ -494,7 +500,10 @@ def _create_workers_from_bundle_indices(
                 )

                 # Pass these options to the remote_worker_builder
-                runtime_env = {"env_vars": env_vars, "py_executable": py_executable}
+                runtime_env = {
+                    "env_vars": worker_env_vars,
+                    "py_executable": py_executable,
+                }
                 runtime_env["env_vars"]["VIRTUAL_ENV"] = py_executable
                 runtime_env["env_vars"]["UV_PROJECT_ENVIRONMENT"] = py_executable

diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py
index 64e97c3314..84948ae77e 100644
--- a/nemo_rl/models/generation/vllm.py
+++ b/nemo_rl/models/generation/vllm.py
@@ -319,10 +319,6 @@ def _patch_vllm_init_workers_ray():
         os.environ["VLLM_USE_V1"] = os.environ.get("NRL_VLLM_USE_V1", "1")
         os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

-        if not self.cfg["colocated"]["enabled"]:
-            os.environ["NCCL_SHM_DISABLE"] = "1"
-            os.environ["NCCL_P2P_DISABLE"] = "1"
-
         load_format = self.cfg["vllm_cfg"]["load_format"]
         if ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(self.model_name):
             load_format = "auto"
@@ -1225,6 +1221,12 @@ def __init__(
             "nemo_rl.models.generation.vllm.VllmGenerationWorker", config
         )

+        # Disable NCCL SHM if training and generation are not co-located: https://github.com/NVIDIA-NeMo/RL/issues/564
+        env_vars = {}
+        if not self.cfg["colocated"]["enabled"]:
+            env_vars["NCCL_SHM_DISABLE"] = "1"
+            env_vars["NCCL_P2P_DISABLE"] = "1"
+
         # Check if we need parallelism-aware worker group creation
         if self.model_parallel_size > 1:
             # For parallelism, create node-aware worker groups
@@ -1236,6 +1238,7 @@ def __init__(
                 name_prefix=name_prefix,
                 bundle_indices_list=node_bundle_indices,
                 sharding_annotations=self.sharding_annotations,
+                env_vars=env_vars,
             )
         else:
             # Use standard worker group creation for non-parallel case
@@ -1245,6 +1248,7 @@ def __init__(
                 name_prefix=name_prefix,
                 workers_per_node=workers_per_node,
                 sharding_annotations=self.sharding_annotations,
+                env_vars=env_vars,
             )

         # Number of data parallel groups is the number of tied worker groups
diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py
index 6872250d10..b0d5e213a1 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker.py
@@ -387,7 +387,7 @@ def init_collective(self, ip: str, port: int, world_size: int) -> None:
         from vllm.distributed.utils import StatelessProcessGroup

         # keep the same behavior as vllm
-        # see https://github.com/vllm-project/vllm/blob/v0.8.5/vllm/env_override.py#L25
+        # see https://github.com/vllm-project/vllm/blob/v0.9.0/vllm/env_override.py#L25
         if not os.path.exists("/dev/nvidia-caps-imex-channels"):
             os.environ["NCCL_CUMEM_ENABLE"] = "0"

From b517e6e483cd8ecfaedea1ee58c283a596a886f8 Mon Sep 17 00:00:00 2001
From: Yuki Huang
Date: Thu, 3 Jul 2025 08:06:54 +0000
Subject: [PATCH 2/3] add unit test

Signed-off-by: Yuki Huang
---
 .../models/generation/test_vllm_generation.py | 36 ++++++++++++------
 1 file changed, 23 insertions(+), 13 deletions(-)

diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py
index 8371fababb..75c74dbd79 100644
--- a/tests/unit/models/generation/test_vllm_generation.py
+++ b/tests/unit/models/generation/test_vllm_generation.py
@@ -264,17 +264,15 @@ def policy_cluster_separate():
         print(f"Error during policy_cluster_separate shutdown: {e}")


-@pytest.fixture(scope="function")
-def generation_cluster_separate():
-    """Create a virtual cluster for the VllmGeneration policy, using 1 GPU."""
-    cluster = _create_ray_virtual_cluster_for_test(
-        "vllm-test-generation-cluster-separate"
+def get_generation_cluster_separate(num_gpus_per_node: int = 1) -> RayVirtualCluster:
+    """Create a virtual cluster for the VllmGeneration policy, using `num_gpus_per_node` GPUs."""
+    return RayVirtualCluster(
+        bundle_ct_per_node_list=[num_gpus_per_node],
+        use_gpus=True,
+        max_colocated_worker_groups=1,
+        num_gpus_per_node=num_gpus_per_node,
+        name="vllm-test-generation-cluster-separate",
     )
-    yield cluster
-    try:
-        cluster.shutdown()
-    except Exception as e:
-        print(f"Error during generation_cluster_separate shutdown: {e}")


 @pytest.fixture(scope="function")
@@ -1177,13 +1175,22 @@ def test_vllm_non_divisible_batch_handling(policy):

 @pytest.mark.asyncio
 @pytest.mark.parametrize("async_engine", [True, False])
+@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
 async def test_vllm_refit_non_collocated_update_weights(
     policy_cluster_separate,
-    generation_cluster_separate,
     tokenizer,
     test_input_data,
     async_engine,
+    tensor_parallel_size,
 ):
+    # Skip tensor_parallel_size == 2 until we have resources in CI
+    if tensor_parallel_size == 2:
+        pytest.skip(
+            "Test requires at least three GPUs to run with tensor_parallel_size == 2 on separate clusters."
+        )
+
+    generation_cluster_separate = get_generation_cluster_separate(tensor_parallel_size)
+
     if (
         policy_cluster_separate.num_gpus_per_node < 1
         or generation_cluster_separate.num_gpus_per_node < 1
@@ -1194,7 +1201,6 @@ async def test_vllm_refit_non_collocated_update_weights(

     # Create Policy on its own cluster
     hf_config = get_basic_hf_test_config(enable_dtensor=True)
-    hf_config["dtensor_cfg"]["tensor_parallel_size"] = 1
     hf_config["generation"]["colocated"]["enabled"] = False
     lm_policy = Policy(policy_cluster_separate, hf_config, tokenizer)

@@ -1202,7 +1208,7 @@ async def test_vllm_refit_non_collocated_update_weights(
     vllm_config = deepcopy(basic_vllm_test_config)
     vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True)
     vllm_config["vllm_cfg"]["async_engine"] = async_engine
-    vllm_config["vllm_cfg"]["tensor_parallel_size"] = 1
+    vllm_config["vllm_cfg"]["tensor_parallel_size"] = tensor_parallel_size
     vllm_config["colocated"]["enabled"] = False
     vllm_generation = VllmGeneration(generation_cluster_separate, vllm_config)

@@ -1234,6 +1240,10 @@ async def test_vllm_refit_non_collocated_update_weights(
     # Clean up
     vllm_generation.shutdown()
     lm_policy.shutdown()
+    try:
+        generation_cluster_separate.shutdown()
+    except Exception as e:
+        print(f"Error during generation_cluster_separate shutdown: {e}")


 @pytest.mark.timeout(210)

From 7b3c7cc0b33de70008547d84dbc09670328e5965 Mon Sep 17 00:00:00 2001
From: Yuki Huang
Date: Fri, 4 Jul 2025 05:11:34 +0000
Subject: [PATCH 3/3] add comment

Signed-off-by: Yuki Huang
---
 nemo_rl/models/generation/vllm.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py
index 84948ae77e..9d7c7873e9 100644
--- a/nemo_rl/models/generation/vllm.py
+++ b/nemo_rl/models/generation/vllm.py
@@ -1221,6 +1221,7 @@ def __init__(
             "nemo_rl.models.generation.vllm.VllmGenerationWorker", config
         )

+        # Set env_vars here, at worker-group creation, so that vLLM non-leader workers also receive them
         # Disable NCCL SHM if training and generation are not co-located: https://github.com/NVIDIA-NeMo/RL/issues/564
         env_vars = {}
         if not self.cfg["colocated"]["enabled"]:
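
Reviewer note (editorial, not part of the patch series): the snippet below is a minimal, self-contained sketch of the env-var plumbing this series introduces. `build_worker_envs` is a hypothetical stand-in for the logic in `RayWorkerGroup._create_workers_from_bundle_indices`; the real code also handles placement groups, bundle indices, and py_executables. It illustrates why caller-supplied NCCL overrides now reach every worker (including vLLM non-leader ranks) while per-rank variables do not leak between workers, since each worker gets a deepcopy.

import os
from copy import deepcopy


def build_worker_envs(
    env_vars: dict[str, str], world_size: int
) -> list[dict[str, str]]:
    """Hypothetical model of the per-worker env construction in worker_groups.py."""
    # Merge in the current process environment; as in the patch, os.environ
    # entries take precedence over caller-supplied values.
    merged = dict(env_vars)
    merged.update(os.environ)

    worker_envs = []
    for rank in range(world_size):
        # deepcopy per worker so RANK/WORLD_SIZE edits never leak across workers
        worker_env = deepcopy(merged)
        worker_env.update({"RANK": str(rank), "WORLD_SIZE": str(world_size)})
        worker_env.pop("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", None)
        worker_envs.append(worker_env)
    return worker_envs


if __name__ == "__main__":
    # Non-colocated case: NCCL SHM/P2P are disabled for every rank, which is the
    # behavior issue #564 asks for (assuming the host environment does not
    # already set these variables to other values).
    envs = build_worker_envs(
        {"NCCL_SHM_DISABLE": "1", "NCCL_P2P_DISABLE": "1"}, world_size=2
    )
    assert all(env["NCCL_SHM_DISABLE"] == "1" for env in envs)
    assert [env["RANK"] for env in envs] == ["0", "1"]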