diff --git a/.gitmodules b/.gitmodules
index 2a588f3a89..09342d3495 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,7 +1,7 @@
 [submodule "3rdparty/NeMo"]
 	path = 3rdparty/NeMo-workspace/NeMo
 	url = https://github.com/NVIDIA/NeMo.git
-	branch = ashors/nemorl-qwen3
+	branch = zhiyul/yukih/prepare-refit-info
 	shallow = true
 [submodule "3rdparty/Megatron-LM"]
 	path = 3rdparty/Megatron-LM-workspace/Megatron-LM
diff --git a/3rdparty/NeMo-workspace/NeMo b/3rdparty/NeMo-workspace/NeMo
index 33259f2540..8ddf438734 160000
--- a/3rdparty/NeMo-workspace/NeMo
+++ b/3rdparty/NeMo-workspace/NeMo
@@ -1 +1 @@
-Subproject commit 33259f2540af6eef375d43fc48bdcbd7ec490c29
+Subproject commit 8ddf4387344c6423763ec9ee0c9a755cbb5d8d35
diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index 96faec96cf..9f9273ac8c 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast
 
@@ -400,6 +401,7 @@ def refit_policy_generation(
     policy_generation: GenerationInterface,
     colocated_inference: bool,
     _refit_buffer_size_gb: Optional[int] = None,
+    timer: Optional[Timer] = None,
 ) -> None:
     """Refit the policy generation interface with the latest policy weights.
 
@@ -414,43 +416,50 @@
         policy.offload_before_refit()
         policy_generation.prepare_for_generation(tags=["weights"])
 
-    # update weights
-    update_success = False
-    if colocated_inference:
-        # get model param keys, which is grouped by size
-        grouped_param_keys = policy.prepare_weights_for_ipc(
-            _refit_buffer_size_gb=_refit_buffer_size_gb
-        )
-        total_num_keys = sum(len(k) for k in grouped_param_keys)
-        print(
-            f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups"
-        )
-        # do update
-        for keys in grouped_param_keys:
-            ipc_handles = policy.get_weights_ipc_handles(keys)
-            update_success = policy_generation.update_weights_from_ipc_handles(
-                ipc_handles
+    # Create a context manager that does nothing when timer is None
+    timer_context = (
+        timer.time("prepare_for_generation/transfer_and_update_weights")
+        if timer is not None
+        else nullcontext()
+    )
+    with timer_context:
+        # update weights
+        update_success = False
+        if colocated_inference:
+            # get model param keys, which is grouped by size
+            grouped_param_keys = policy.prepare_weights_for_ipc(
+                _refit_buffer_size_gb=_refit_buffer_size_gb
             )
-            if not update_success:
-                break
-    else:
-        # update weights through nccl
-        futures_train = policy.broadcast_weights_for_collective()
-        futures_inference = policy_generation.update_weights_from_collective()
-        # wait for all futures to complete
-        ray.get(futures_train)
-        results = ray.get(futures_inference)
-        update_success = all(result for result in results if result is not None)
-
-    # check if update is successful
-    if not update_success:
-        error_tag = "cuda-ipc" if colocated_inference else "nccl"
-        error_message = (
-            "❌ Error: Updating weights for the generation policy failed during refit.\n"
-            f"This often indicates an issue with {error_tag} or "
-            "a problem within the generation backend (e.g., vLLM worker).\n"
-        )
-        raise RuntimeError(error_message)
+            total_num_keys = sum(len(k) for k in grouped_param_keys)
+            print(
+                f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups"
+            )
+            # do update
+            for keys in grouped_param_keys:
+                ipc_handles = policy.get_weights_ipc_handles(keys)
+                update_success = policy_generation.update_weights_from_ipc_handles(
+                    ipc_handles
+                )
+                if not update_success:
+                    break
+        else:
+            # update weights through nccl
+            futures_train = policy.broadcast_weights_for_collective()
+            futures_inference = policy_generation.update_weights_from_collective()
+            # wait for all futures to complete
+            ray.get(futures_train)
+            results = ray.get(futures_inference)
+            update_success = all(result for result in results if result is not None)
+
+        # check if update is successful
+        if not update_success:
+            error_tag = "cuda-ipc" if colocated_inference else "nccl"
+            error_message = (
+                "❌ Error: Updating weights for the generation policy failed during refit.\n"
+                f"This often indicates an issue with {error_tag} or "
+                "a problem within the generation backend (e.g., vLLM worker).\n"
+            )
+            raise RuntimeError(error_message)
 
     if colocated_inference:
         policy.offload_after_refit()
@@ -544,7 +553,7 @@ def grpo_train(
         with timer.time("prepare_for_generation"):
             if NEED_REFIT and POLICY_GENERATION_STALE:
                 refit_policy_generation(
-                    policy, policy_generation, colocated_inference
+                    policy, policy_generation, colocated_inference, timer=timer
                 )
                 POLICY_GENERATION_STALE = False
             else:
diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py
index af7e69d046..090393c0a3 100644
--- a/nemo_rl/models/generation/vllm_backend.py
+++ b/nemo_rl/models/generation/vllm_backend.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from collections import defaultdict
 from typing import Any, Iterable, Optional
 
 import torch
+from torch.multiprocessing.reductions import rebuild_cuda_tensor
 
 try:
     import vllm  # noqa: F401
@@ -136,7 +138,7 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
         try:
             is_tensor_packed = local_device_ipc_handles[0]
             if is_tensor_packed:
-                _, all_handles, tensor_metadata = local_device_ipc_handles
+                _, all_handles, list_keys = local_device_ipc_handles
             else:
                 _, name_and_handle_list = local_device_ipc_handles
 
@@ -152,33 +154,40 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
                 # Extract packed tensor from IPC handle
                 dtype_to_packed_tensor = {}
                 for dtype, tensor_handle in all_handles:
-                    func, args = tensor_handle
+                    func = rebuild_cuda_tensor
+                    args = tensor_handle[0]
                     list_args = list(args)
                     list_args[6] = device_id
                     tensor = func(*list_args)
                     dtype_to_packed_tensor[dtype] = tensor
 
-                # Unpack tensor to weights. Here we only return a view of the tensor to avoid
-                # using extra memory.
-                for key, metadata in tensor_metadata.items():
-                    # dtype for the 1st and 2nd steps may be different (e.g. e_score_correction_bias)
-                    if isinstance(metadata, tuple):
-                        # use dtype of current step
-                        offset, dtype = metadata
-                        shape, _, size = self.state_dict_info[key]
-                        # update record
-                        self.state_dict_info[key] = (shape, dtype, size)
-                    else:
-                        offset = metadata
-                        shape, dtype, size = self.state_dict_info[key]
-                    tensor = dtype_to_packed_tensor[dtype][offset : offset + size].view(
-                        *shape
+                weights = []
+                dtype_to_offset = defaultdict(lambda: 0)
+                for key in list_keys:
+                    shape, dtype, size = self.state_dict_info[key]
+                    weights.append(
+                        (
+                            key,
+                            dtype_to_packed_tensor[dtype][
+                                dtype_to_offset[dtype] : dtype_to_offset[dtype] + size
+                            ].view(*shape),
+                        )
                     )
-                    weights.append((key, tensor))
+                    dtype_to_offset[dtype] += size
+
+                expected_sizes = {
+                    dtype: tensor.numel()
+                    for dtype, tensor in dtype_to_packed_tensor.items()
+                }
+                assert dtype_to_offset == expected_sizes, (
+                    f"Packed tensor size mismatch: expected sizes from keys list {expected_sizes} != actual packed tensor sizes {dtype_to_offset}. "
+                    f"This indicates the keys list order doesn't match the order used when packing tensors."
+                )
             else:
                 # Process each handle to get the tensor
                 for name, handle in name_and_handle_list:
-                    func, args = handle
+                    func = rebuild_cuda_tensor
+                    args = handle[0]
                     list_args = list(args)
                     list_args[6] = device_id
                     tensor = func(*list_args)
diff --git a/nemo_rl/models/megatron/refit_utils.py b/nemo_rl/models/megatron/refit_utils.py
index 2463d6ae4f..af6ffa2634 100644
--- a/nemo_rl/models/megatron/refit_utils.py
+++ b/nemo_rl/models/megatron/refit_utils.py
@@ -156,7 +156,6 @@ def gather_params(model, keys: list[str], key_to_global_keys: dict[str, list[str
         if k is not None:
             gathered_params[k] = p
 
-    print(f"Time taken to gather params: {time.perf_counter() - st}")
     return gathered_params
 
diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py
index b590032408..50ea698d18 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker.py
@@ -63,6 +63,7 @@ from nemo_rl.models.policy.utils import (
     configure_expandable_segments,
     get_gpu_info,
+    get_handle_from_tensor,
     get_runtime_env_for_policy_worker,
     import_class_from_path,
     is_vllm_v1_engine_enabled,
@@ -1186,8 +1187,6 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
 
     @torch.no_grad()
     def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]:
-        from torch.multiprocessing.reductions import reduce_tensor
-
         assert self._held_sharded_state_dict_reference is not None, (
             "prepare_weights_for_ipc must be called before get_weights_ipc_handles"
         )
@@ -1217,7 +1216,7 @@ def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]:
         # Create handles for the tensors
         all_handles = []
         for key, p in converted_params.items():
-            handle = reduce_tensor(p.detach())
+            handle = get_handle_from_tensor(p)
            all_handles.append((key, handle))
 
         # (pack_tensor_for_ipc: bool, handles: list)
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 9793ea8e9c..d50ffa0aa0 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -51,6 +51,7 @@
 )
 from megatron.core.pipeline_parallel import get_forward_backward_func
 from megatron.core.rerun_state_machine import get_rerun_state_machine
+from megatron.core.transformer.module import Float16Module
 from megatron.inference.text_generation.mcore_engine_server import (
     run_mcore_engine,
 )
@@ -118,6 +119,7 @@ from nemo_rl.models.policy.utils import (
     configure_expandable_segments,
     get_gpu_info,
+    get_handle_from_tensor,
     get_megatron_checkpoint_dir,
     get_runtime_env_for_policy_worker,
 )
@@ -185,11 +187,25 @@ def setup_megatron_model(
     if policy_cfg["megatron_cfg"]["freeze_moe_router"]:
 
         def freeze_moe_router(model_module):
+            # Handle both wrapped (Float16Module) and unwrapped models
+            if isinstance(model_module, Float16Module):
+                model_module = model_module.module
             for layer in model_module.decoder.layers:
                 if hasattr(layer.mlp, "router"):
                     layer.mlp.router.weight.requires_grad = False
 
+        # Re-enable float32 expert bias for moe router to avoid parameter dtype inconsistency
+        # see https://github.com/NVIDIA/Megatron-LM/blob/e6c510ff3c1159f8955589b26f7c395bdf0607d9/megatron/core/transformer/moe/router.py#L149
+        def re_enable_float32_expert_bias(model_module):
+            # Handle both wrapped (Float16Module) and unwrapped models
+            if isinstance(model_module, Float16Module):
+                model_module = model_module.module
+            for layer in model_module.decoder.layers:
+                if hasattr(layer.mlp, "router"):
+                    layer.mlp.router._maintain_float32_expert_bias()
+
         model_post_init_fns.append(freeze_moe_router)
+        model_post_init_fns.append(re_enable_float32_expert_bias)
 
     # Model, optimizer, and learning rate.
     model = get_model_from_config(
@@ -692,7 +708,7 @@ def __init__(
 
         # vars used for refit
         ## will be initialized in prepare_refit_info
-        self.refit_param_info_hf = None
+        self.refit_param_info_mcore = None
         self.local_key_to_global_keys = None
         ## used for streaming update inference engine weights
         self._held_gather_buffer = None
@@ -1355,20 +1371,18 @@ def report_device_id(self) -> str:
     @torch.no_grad()
     def prepare_refit_info(self) -> None:
         # Get parameter info for refit
-        ## param_info: list of ((name, shape, dtype), size_in_bytes) tuples
-        # Cannot cache refit_param_info_mcore since dtype and size_in_bytes for the 1st and 2nd steps may be different
-        ## e.g. e_score_correction_bias
-        refit_param_info_mcore = get_param_info(self.model, self.dtype)
+        # param_info: list of ((name, shape, dtype), size_in_bytes) tuples
+        self.refit_param_info_mcore = get_param_info(self.model, self.dtype)
 
         # Create a map that maps any local parameter name to a list of global parameter names.
         # This map is repeatedly used by parameter gatherring phase during refit of every step.
         self.local_key_to_global_keys = get_local_key_to_global_keys(
-            self.model, state_dict_info=refit_param_info_mcore
+            self.model, state_dict_info=self.refit_param_info_mcore
         )
 
         # Collect tensor metadata for refit
-        self.refit_param_info_hf = {}
-        for key, _ in refit_param_info_mcore:
+        refit_param_info_hf = {}
+        for key, _ in self.refit_param_info_mcore:
             # gather megatron params
             gathered_megatron_params = gather_params(
                 self.model,
@@ -1381,15 +1395,14 @@ def prepare_refit_info(self) -> None:
             )
             # collect tensor metadata
             for name, tensor in gathered_hf_params.items():
-                self.refit_param_info_hf[name] = (
+                refit_param_info_hf[name] = (
                     tensor.shape,
                     tensor.dtype,
                     tensor.numel(),
                 )
 
-        return self.refit_param_info_hf
+        return refit_param_info_hf
 
-    @torch.no_grad()
     def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
         """Prepare Megatron model weights for IPC transfer to vLLM.
 
@@ -1398,12 +1411,6 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
         """
         from nemo_rl.utils.nvml import get_free_memory_bytes
 
-        # Get parameter info for refit
-        ## param_info: list of ((name, shape, dtype), size_in_bytes) tuples
-        # Cannot cache refit_param_info_mcore since dtype and size_in_bytes for the 1st and 2nd steps may be different
-        ## e.g. e_score_correction_bias
-        refit_param_info_mcore = get_param_info(self.model, self.dtype)
-
         # Collect current available memory for refit
         ## Get current device index from torch
         device_idx = torch.cuda.current_device()
@@ -1413,7 +1420,7 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
         memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.2")
         total_available_bytes *= float(memory_ratio)
 
-        return refit_param_info_mcore, total_available_bytes
+        return self.refit_param_info_mcore, total_available_bytes
 
     # Temporary fix, 'keys' is a kwarg due to some sort of ray bug
     @torch.no_grad()
@@ -1441,7 +1448,6 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
         # Get device UUID for IPC handles
         device_uuid = self.report_device_id()
 
-        from torch.multiprocessing.reductions import reduce_tensor
 
         # Create IPC handles for each parameter
         tensor_number_threshold = os.getenv(
@@ -1456,25 +1462,9 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
 
             # Pack tensors in gathered_hf_params into consolidated tensors by dtype
             # First calculate total size needed for each dtype
             type_to_total_size = defaultdict(lambda: 0)
-            tensor_metadata = dict()  # Record offset of the tensor
             for key, tensor in gathered_hf_params.items():
-                # dtype for the 1st and 2nd steps may be different (e.g. e_score_correction_bias)
-                if tensor.dtype == self.refit_param_info_hf[key][1]:
-                    tensor_metadata[key] = type_to_total_size[tensor.dtype]
-                else:
-                    # also send dtype if it changes
-                    tensor_metadata[key] = (
-                        type_to_total_size[tensor.dtype],
-                        tensor.dtype,
-                    )
-                    # update record
-                    self.refit_param_info_hf[key] = (
-                        tensor.shape,
-                        tensor.dtype,
-                        tensor.numel(),
-                    )
                 type_to_total_size[tensor.dtype] += tensor.numel()
 
             # Allocate consolidated tensors for each dtype
@@ -1488,31 +1478,34 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
                 for dtype, total_size in type_to_total_size.items()
             }
 
+            dtype_to_offset = defaultdict(lambda: 0)
             # Copy tensors into consolidated buffers
             for key, tensor in gathered_hf_params.items():
-                offset = tensor_metadata[key]
-                if isinstance(offset, tuple):
-                    offset, _ = offset
                 dtype = tensor.dtype
                 size = tensor.numel()
-                packed_tensors[dtype][offset : offset + size].copy_(
-                    tensor.detach().view(-1)
-                )
+                packed_tensors[dtype][
+                    dtype_to_offset[dtype] : dtype_to_offset[dtype] + size
+                ].copy_(tensor.detach().view(-1))
+                dtype_to_offset[dtype] += size
 
             # Create IPC handles for consolidated tensors
             all_handles = [
-                (dtype, reduce_tensor(tensor.detach()))
+                (dtype, get_handle_from_tensor(tensor))
                 for dtype, tensor in packed_tensors.items()
             ]
 
             # Store reference to prevent garbage collection
             self._held_gather_buffer = packed_tensors
 
-            serialized = (pack_tensor_for_ipc, all_handles, tensor_metadata)
+            serialized = (
+                pack_tensor_for_ipc,
+                all_handles,
+                tuple(gathered_hf_params.keys()),
+            )
         else:
             all_handles = []
             for key, tensor in gathered_hf_params.items():
-                handle = reduce_tensor(tensor.detach())
+                handle = get_handle_from_tensor(tensor)
                 all_handles.append((key, handle))
             self._held_gather_buffer = gathered_hf_params
             serialized = (False, all_handles)
diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py
index a61e5e20b7..678d5d89c0 100644
--- a/nemo_rl/models/policy/utils.py
+++ b/nemo_rl/models/policy/utils.py
@@ -220,3 +220,11 @@ def get_megatron_checkpoint_dir() -> str:
         )
         print(f"Using default megatron checkpoint dir: {checkpoint_dir}")
     return checkpoint_dir
+
+
+def get_handle_from_tensor(tensor: torch.Tensor) -> tuple[Any]:
+    """Get IPC handle from a tensor."""
+    from torch.multiprocessing.reductions import reduce_tensor
+
+    # skip serializing the function for better refit performance
+    return reduce_tensor(tensor.detach())[1:]
diff --git a/pyrefly.toml b/pyrefly.toml
index f3bc05a639..4038ce9737 100644
--- a/pyrefly.toml
+++ b/pyrefly.toml
@@ -84,6 +84,7 @@ project-includes = [
     "nemo_rl/models/dtensor/parallelize.py",
     "nemo_rl/models/generation/__init__.py",
     "nemo_rl/models/generation/interfaces.py",
+    "nemo_rl/models/generation/vllm_backend.py",
     "nemo_rl/models/huggingface/__init__.py",
     "nemo_rl/models/megatron/__init__.py",
     "nemo_rl/models/megatron/community_import.py",
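Note on the timing change in nemo_rl/algorithms/grpo.py: refit_policy_generation now accepts an optional Timer and wraps the weight transfer in either timer.time(...) or contextlib.nullcontext(), so callers without a timer pay no extra cost. Below is a minimal, self-contained sketch of the same pattern; SimpleTimer is a hypothetical stand-in for nemo_rl's Timer, not the repo's API.

import time
from contextlib import contextmanager, nullcontext
from typing import Optional


class SimpleTimer:
    """Hypothetical stand-in for a timer exposing a context-manager `time(name)` API."""

    def __init__(self) -> None:
        self.measurements: dict[str, float] = {}

    @contextmanager
    def time(self, name: str):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.measurements[name] = time.perf_counter() - start


def refit(timer: Optional[SimpleTimer] = None) -> None:
    # nullcontext() turns the `with` block into a no-op when no timer is passed,
    # which is the same trick the updated refit_policy_generation uses.
    ctx = timer.time("transfer_and_update_weights") if timer is not None else nullcontext()
    with ctx:
        time.sleep(0.01)  # stands in for the actual weight transfer


refit()  # works with no timer
t = SimpleTimer()
refit(t)  # records a measurement when a timer is supplied
print(t.measurements)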
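Note on the refit payload change (megatron_policy_worker.get_weights_ipc_handles and vllm_backend.update_weights_from_local_ipc_handles): instead of shipping a per-key offset/dtype metadata dict, the sender now packs all tensors of the same dtype into one flat buffer and sends only the key order; the receiver rebuilds each weight as a zero-copy view using the (shape, dtype, numel) info cached at prepare_refit_info time plus running per-dtype offsets. A CPU-only sketch of the round trip (no CUDA IPC; the variable names are illustrative, not the repo's API):

from collections import defaultdict

import torch

params = {
    "layer.weight": torch.randn(4, 8, dtype=torch.bfloat16),
    "layer.bias": torch.randn(8, dtype=torch.bfloat16),
    "router.expert_bias": torch.randn(8, dtype=torch.float32),
}

# Cached once per run, analogous to self.state_dict_info on the vLLM side.
state_dict_info = {k: (v.shape, v.dtype, v.numel()) for k, v in params.items()}

# --- sender side: pack by dtype ---
type_to_total_size = defaultdict(int)
for tensor in params.values():
    type_to_total_size[tensor.dtype] += tensor.numel()

packed = {
    dtype: torch.empty(total, dtype=dtype)
    for dtype, total in type_to_total_size.items()
}
offsets = defaultdict(int)
for tensor in params.values():
    dtype, size = tensor.dtype, tensor.numel()
    packed[dtype][offsets[dtype] : offsets[dtype] + size].copy_(tensor.view(-1))
    offsets[dtype] += size

payload = (True, packed, tuple(params.keys()))  # the real code ships IPC handles, not tensors

# --- receiver side: unpack as views in the same key order ---
is_packed, buffers, keys = payload
weights, dtype_to_offset = [], defaultdict(int)
for key in keys:
    shape, dtype, size = state_dict_info[key]
    view = buffers[dtype][dtype_to_offset[dtype] : dtype_to_offset[dtype] + size].view(*shape)
    weights.append((key, view))
    dtype_to_offset[dtype] += size

# Round trip preserves every tensor without allocating per-key copies.
assert all(torch.equal(v, params[k]) for k, v in weights)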
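Note on get_handle_from_tensor in nemo_rl/models/policy/utils.py: torch's reduce_tensor(t) returns a (rebuild_fn, args) pair, and for CUDA tensors the rebuild function is rebuild_cuda_tensor, so the patch drops it from the payload ([1:]) and the vLLM side re-attaches it, patching args[6] (the device index) to its local GPU. A hedged sketch of the two halves follows; the function names are illustrative, and the rebuild must run in a different process from the one that created the handle, on a CUDA tensor.

from typing import Any

import torch
from torch.multiprocessing.reductions import rebuild_cuda_tensor, reduce_tensor


def make_handle(tensor: torch.Tensor) -> tuple[Any, ...]:
    # reduce_tensor returns (rebuild_fn, args); sending only args keeps the
    # serialized payload smaller, which is the point of skipping the function.
    return reduce_tensor(tensor.detach())[1:]


def rebuild_on_receiver(handle: tuple[Any, ...], device_id: int) -> torch.Tensor:
    # handle is a one-element tuple holding the args tuple, matching how
    # vllm_backend now reads `args = tensor_handle[0]`.
    args = list(handle[0])
    args[6] = device_id  # args[6] of rebuild_cuda_tensor is the device index
    return rebuild_cuda_tensor(*args)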