From f47d3da9a0caf80e443595f7af1e7c4642f64bb7 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 15 Jul 2025 19:43:09 -0700 Subject: [PATCH 01/14] fix: maintain fp32 mlp.router.expert_bias even with bf16 enabled Signed-off-by: Zhiyu Li Signed-off-by: Zhiyu Li --- 3rdparty/NeMo-workspace/NeMo | 2 +- nemo_rl/models/policy/megatron_policy_worker.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/3rdparty/NeMo-workspace/NeMo b/3rdparty/NeMo-workspace/NeMo index 33259f2540..8ddf438734 160000 --- a/3rdparty/NeMo-workspace/NeMo +++ b/3rdparty/NeMo-workspace/NeMo @@ -1 +1 @@ -Subproject commit 33259f2540af6eef375d43fc48bdcbd7ec490c29 +Subproject commit 8ddf4387344c6423763ec9ee0c9a755cbb5d8d35 diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 9793ea8e9c..5f63c107e9 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -27,6 +27,7 @@ from megatron.core.distributed.custom_fsdp import ( FullyShardedDataParallel as custom_FSDP, ) +from megatron.core.transformer.module import Float16Module from megatron.core.inference.engines import ( StaticInferenceEngine, ) @@ -185,11 +186,25 @@ def setup_megatron_model( if policy_cfg["megatron_cfg"]["freeze_moe_router"]: def freeze_moe_router(model_module): + # Handle both wrapped (Float16Module) and unwrapped models + if isinstance(model_module, Float16Module): + model_module = model_module.module for layer in model_module.decoder.layers: if hasattr(layer.mlp, "router"): layer.mlp.router.weight.requires_grad = False + # Re-enable float32 expert bias for moe router to avoid parameter dtype inconsistency + # see https://github.com/NVIDIA/Megatron-LM/blob/e6c510ff3c1159f8955589b26f7c395bdf0607d9/megatron/core/transformer/moe/router.py#L149 + def re_enable_float32_expert_bias(model_module): + # Handle both wrapped (Float16Module) and unwrapped models + if isinstance(model_module, Float16Module): + model_module = model_module.module + for layer in model_module.decoder.layers: + if hasattr(layer.mlp, "router"): + layer.mlp.router._maintain_float32_expert_bias() + model_post_init_fns.append(freeze_moe_router) + model_post_init_fns.append(re_enable_float32_expert_bias) # Model, optimizer, and learning rate. model = get_model_from_config( From a9ffcee51b25af435891020e5c3e1179a60df856 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 15 Jul 2025 19:01:31 -0700 Subject: [PATCH 02/14] avoid serializing rebuild_cuda_tensor function Signed-off-by: Zhiyu Li Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm_backend.py | 7 +++++-- nemo_rl/models/policy/megatron_policy_worker.py | 12 +++++++++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index af7e69d046..0e917b13ee 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -15,6 +15,7 @@ from typing import Any, Iterable, Optional import torch +from torch.multiprocessing.reductions import rebuild_cuda_tensor try: import vllm # noqa: F401 @@ -152,7 +153,8 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): # Extract packed tensor from IPC handle dtype_to_packed_tensor = {} for dtype, tensor_handle in all_handles: - func, args = tensor_handle + func = rebuild_cuda_tensor + args = tensor_handle[0] list_args = list(args) list_args[6] = device_id tensor = func(*list_args) @@ -178,7 +180,8 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): else: # Process each handle to get the tensor for name, handle in name_and_handle_list: - func, args = handle + func = rebuild_cuda_tensor + args = handle[0] list_args = list(args) list_args[6] = device_id tensor = func(*list_args) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 5f63c107e9..a4118bed84 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1430,6 +1430,13 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: return refit_param_info_mcore, total_available_bytes + def get_handle_from_tensor(self, tensor: torch.Tensor) -> tuple[str, Any]: + """Get IPC handle from a tensor.""" + from torch.multiprocessing.reductions import reduce_tensor + + # skip serializing the function for better refit performance + return reduce_tensor(tensor.detach())[1:] + # Temporary fix, 'keys' is a kwarg due to some sort of ray bug @torch.no_grad() def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: @@ -1456,7 +1463,6 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: # Get device UUID for IPC handles device_uuid = self.report_device_id() - from torch.multiprocessing.reductions import reduce_tensor # Create IPC handles for each parameter tensor_number_threshold = os.getenv( @@ -1516,7 +1522,7 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: # Create IPC handles for consolidated tensors all_handles = [ - (dtype, reduce_tensor(tensor.detach())) + (dtype, self.get_handle_from_tensor(tensor)) for dtype, tensor in packed_tensors.items() ] @@ -1527,7 +1533,7 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: else: all_handles = [] for key, tensor in gathered_hf_params.items(): - handle = reduce_tensor(tensor.detach()) + handle = self.get_handle_from_tensor(tensor) all_handles.append((key, handle)) self._held_gather_buffer = gathered_hf_params serialized = (False, all_handles) From dc19e7f6c302b69de5f016003cb34162670d0d3f Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 16 Jul 2025 05:09:40 -0700 Subject: [PATCH 03/14] assert dtype Signed-off-by: Yuki Huang Signed-off-by: Zhiyu Li --- nemo_rl/models/policy/megatron_policy_worker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index a4118bed84..e80b78d7a1 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -27,7 +27,6 @@ from megatron.core.distributed.custom_fsdp import ( FullyShardedDataParallel as custom_FSDP, ) -from megatron.core.transformer.module import Float16Module from megatron.core.inference.engines import ( StaticInferenceEngine, ) @@ -52,6 +51,7 @@ ) from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.rerun_state_machine import get_rerun_state_machine +from megatron.core.transformer.module import Float16Module from megatron.inference.text_generation.mcore_engine_server import ( run_mcore_engine, ) @@ -1485,6 +1485,9 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: if tensor.dtype == self.refit_param_info_hf[key][1]: tensor_metadata[key] = type_to_total_size[tensor.dtype] else: + assert False, ( + f"{key} dtype mismatch: {tensor.dtype} vs {self.refit_param_info_hf[key][1]}" + ) # also send dtype if it changes tensor_metadata[key] = ( type_to_total_size[tensor.dtype], From 36d2fca7d329d0e0ccc803654974ad8f35b4d809 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Thu, 17 Jul 2025 10:18:40 -0700 Subject: [PATCH 04/14] track refitting time inside prepare_for_generation Signed-off-by: Zhiyu Li fix Signed-off-by: Zhiyu Li Signed-off-by: Zhiyu Li --- nemo_rl/algorithms/grpo.py | 79 ++++++++++++++++++++------------------ 1 file changed, 42 insertions(+), 37 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 96faec96cf..82cb681081 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -13,6 +13,7 @@ # limitations under the License. import os import warnings +from contextlib import nullcontext from pathlib import Path from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast @@ -400,6 +401,7 @@ def refit_policy_generation( policy_generation: GenerationInterface, colocated_inference: bool, _refit_buffer_size_gb: Optional[int] = None, + timer: Optional[Timer] = None, ) -> None: """Refit the policy generation interface with the latest policy weights. @@ -414,43 +416,46 @@ def refit_policy_generation( policy.offload_before_refit() policy_generation.prepare_for_generation(tags=["weights"]) - # update weights - update_success = False - if colocated_inference: - # get model param keys, which is grouped by size - grouped_param_keys = policy.prepare_weights_for_ipc( - _refit_buffer_size_gb=_refit_buffer_size_gb - ) - total_num_keys = sum(len(k) for k in grouped_param_keys) - print( - f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups" - ) - # do update - for keys in grouped_param_keys: - ipc_handles = policy.get_weights_ipc_handles(keys) - update_success = policy_generation.update_weights_from_ipc_handles( - ipc_handles + # Create a context manager that does nothing when timer is None + timer_context = timer.time("prepare_for_generation/refit_policy_generation") if timer is not None else nullcontext() + with timer_context: + # update weights + update_success = False + if colocated_inference: + # get model param keys, which is grouped by size + grouped_param_keys = policy.prepare_weights_for_ipc( + _refit_buffer_size_gb=_refit_buffer_size_gb ) - if not update_success: - break - else: - # update weights through nccl - futures_train = policy.broadcast_weights_for_collective() - futures_inference = policy_generation.update_weights_from_collective() - # wait for all futures to complete - ray.get(futures_train) - results = ray.get(futures_inference) - update_success = all(result for result in results if result is not None) - - # check if update is successful - if not update_success: - error_tag = "cuda-ipc" if colocated_inference else "nccl" - error_message = ( - "❌ Error: Updating weights for the generation policy failed during refit.\n" - f"This often indicates an issue with {error_tag} or " - "a problem within the generation backend (e.g., vLLM worker).\n" - ) - raise RuntimeError(error_message) + total_num_keys = sum(len(k) for k in grouped_param_keys) + print( + f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups" + ) + # do update + for keys in grouped_param_keys: + ipc_handles = policy.get_weights_ipc_handles(keys) + update_success = policy_generation.update_weights_from_ipc_handles( + ipc_handles + ) + if not update_success: + break + else: + # update weights through nccl + futures_train = policy.broadcast_weights_for_collective() + futures_inference = policy_generation.update_weights_from_collective() + # wait for all futures to complete + ray.get(futures_train) + results = ray.get(futures_inference) + update_success = all(result for result in results if result is not None) + + # check if update is successful + if not update_success: + error_tag = "cuda-ipc" if colocated_inference else "nccl" + error_message = ( + "❌ Error: Updating weights for the generation policy failed during refit.\n" + f"This often indicates an issue with {error_tag} or " + "a problem within the generation backend (e.g., vLLM worker).\n" + ) + raise RuntimeError(error_message) if colocated_inference: policy.offload_after_refit() @@ -544,7 +549,7 @@ def grpo_train( with timer.time("prepare_for_generation"): if NEED_REFIT and POLICY_GENERATION_STALE: refit_policy_generation( - policy, policy_generation, colocated_inference + policy, policy_generation, colocated_inference, timer=timer ) POLICY_GENERATION_STALE = False else: From ddc4cc25bd228a39f86f094510bcba4378882f39 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Thu, 17 Jul 2025 10:48:03 -0700 Subject: [PATCH 05/14] pass list of keys only during refitting Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm_backend.py | 33 ++++++++--------- .../models/policy/megatron_policy_worker.py | 36 ++++++------------- 2 files changed, 25 insertions(+), 44 deletions(-) diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index 0e917b13ee..d5b6adb9d0 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from collections import defaultdict from typing import Any, Iterable, Optional import torch @@ -137,7 +138,7 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): try: is_tensor_packed = local_device_ipc_handles[0] if is_tensor_packed: - _, all_handles, tensor_metadata = local_device_ipc_handles + _, all_handles, list_keys = local_device_ipc_handles else: _, name_and_handle_list = local_device_ipc_handles @@ -160,23 +161,19 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): tensor = func(*list_args) dtype_to_packed_tensor[dtype] = tensor - # Unpack tensor to weights. Here we only return a view of the tensor to avoid - # using extra memory. - for key, metadata in tensor_metadata.items(): - # dtype for the 1st and 2nd steps may be different (e.g. e_score_correction_bias) - if isinstance(metadata, tuple): - # use dtype of current step - offset, dtype = metadata - shape, _, size = self.state_dict_info[key] - # update record - self.state_dict_info[key] = (shape, dtype, size) - else: - offset = metadata - shape, dtype, size = self.state_dict_info[key] - tensor = dtype_to_packed_tensor[dtype][offset : offset + size].view( - *shape - ) - weights.append((key, tensor)) + # Unpack tensor to weights using pre-calculated offsets + weights = [] + dtype_to_offset = defaultdict(lambda: 0) + for key in list_keys: + shape, dtype, size = self.state_dict_info[key] + weights.append((key, dtype_to_packed_tensor[dtype][dtype_to_offset[dtype]:dtype_to_offset[dtype]+size].view(*shape))) + dtype_to_offset[dtype] += size + + expected_sizes = {dtype: tensor.numel() for dtype, tensor in dtype_to_packed_tensor.items()} + assert dtype_to_offset == expected_sizes, ( + f"Packed tensor size mismatch: expected sizes from keys list {expected_sizes} != actual packed tensor sizes {dtype_to_offset}. " + f"This indicates the keys list order doesn't match the order used when packing tensors." + ) else: # Process each handle to get the tensor for name, handle in name_and_handle_list: diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index e80b78d7a1..144b00db81 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1477,28 +1477,9 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: # Pack tensors in gathered_hf_params into consolidated tensors by dtype # First calculate total size needed for each dtype type_to_total_size = defaultdict(lambda: 0) - tensor_metadata = dict() # Record offset of the tensor for key, tensor in gathered_hf_params.items(): - # dtype for the 1st and 2nd steps may be different (e.g. e_score_correction_bias) - if tensor.dtype == self.refit_param_info_hf[key][1]: - tensor_metadata[key] = type_to_total_size[tensor.dtype] - else: - assert False, ( - f"{key} dtype mismatch: {tensor.dtype} vs {self.refit_param_info_hf[key][1]}" - ) - # also send dtype if it changes - tensor_metadata[key] = ( - type_to_total_size[tensor.dtype], - tensor.dtype, - ) - # update record - self.refit_param_info_hf[key] = ( - tensor.shape, - tensor.dtype, - tensor.numel(), - ) type_to_total_size[tensor.dtype] += tensor.numel() # Allocate consolidated tensors for each dtype @@ -1512,16 +1493,15 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: for dtype, total_size in type_to_total_size.items() } + dtype_to_offset = defaultdict(lambda: 0) # Copy tensors into consolidated buffers for key, tensor in gathered_hf_params.items(): - offset = tensor_metadata[key] - if isinstance(offset, tuple): - offset, _ = offset dtype = tensor.dtype size = tensor.numel() - packed_tensors[dtype][offset : offset + size].copy_( - tensor.detach().view(-1) - ) + packed_tensors[dtype][ + dtype_to_offset[dtype] : dtype_to_offset[dtype] + size + ].copy_(tensor.detach().view(-1)) + dtype_to_offset[dtype] += size # Create IPC handles for consolidated tensors all_handles = [ @@ -1532,7 +1512,11 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: # Store reference to prevent garbage collection self._held_gather_buffer = packed_tensors - serialized = (pack_tensor_for_ipc, all_handles, tensor_metadata) + serialized = ( + pack_tensor_for_ipc, + all_handles, + tuple(gathered_hf_params.keys()), + ) else: all_handles = [] for key, tensor in gathered_hf_params.items(): From 6a02d010bf4ca0319d6d80c477df8523a85b183d Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Fri, 18 Jul 2025 23:11:16 +0000 Subject: [PATCH 06/14] remove print and unnecessary comments Signed-off-by: Zhiyu Li --- nemo_rl/models/megatron/refit_utils.py | 1 - nemo_rl/models/policy/megatron_policy_worker.py | 4 ---- 2 files changed, 5 deletions(-) diff --git a/nemo_rl/models/megatron/refit_utils.py b/nemo_rl/models/megatron/refit_utils.py index 2463d6ae4f..af6ffa2634 100644 --- a/nemo_rl/models/megatron/refit_utils.py +++ b/nemo_rl/models/megatron/refit_utils.py @@ -156,7 +156,6 @@ def gather_params(model, keys: list[str], key_to_global_keys: dict[str, list[str if k is not None: gathered_params[k] = p - print(f"Time taken to gather params: {time.perf_counter() - st}") return gathered_params diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 144b00db81..54b450cfa5 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1371,8 +1371,6 @@ def report_device_id(self) -> str: def prepare_refit_info(self) -> None: # Get parameter info for refit ## param_info: list of ((name, shape, dtype), size_in_bytes) tuples - # Cannot cache refit_param_info_mcore since dtype and size_in_bytes for the 1st and 2nd steps may be different - ## e.g. e_score_correction_bias refit_param_info_mcore = get_param_info(self.model, self.dtype) # Create a map that maps any local parameter name to a list of global parameter names. @@ -1415,8 +1413,6 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: # Get parameter info for refit ## param_info: list of ((name, shape, dtype), size_in_bytes) tuples - # Cannot cache refit_param_info_mcore since dtype and size_in_bytes for the 1st and 2nd steps may be different - ## e.g. e_score_correction_bias refit_param_info_mcore = get_param_info(self.model, self.dtype) # Collect current available memory for refit From 715f5228b1072d5361de9d697bc23529876501d8 Mon Sep 17 00:00:00 2001 From: yuki <48991475+yuki-666@users.noreply.github.com> Date: Mon, 21 Jul 2025 12:43:42 +0800 Subject: [PATCH 07/14] feat: cache refit_param_info_mcore (#698) Signed-off-by: Yuki Huang Signed-off-by: Zhiyu Li --- .../models/policy/megatron_policy_worker.py | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 54b450cfa5..8171e371a9 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -707,7 +707,7 @@ def __init__( # vars used for refit ## will be initialized in prepare_refit_info - self.refit_param_info_hf = None + self.refit_param_info_mcore = None self.local_key_to_global_keys = None ## used for streaming update inference engine weights self._held_gather_buffer = None @@ -1370,18 +1370,18 @@ def report_device_id(self) -> str: @torch.no_grad() def prepare_refit_info(self) -> None: # Get parameter info for refit - ## param_info: list of ((name, shape, dtype), size_in_bytes) tuples - refit_param_info_mcore = get_param_info(self.model, self.dtype) + # param_info: list of ((name, shape, dtype), size_in_bytes) tuples + self.refit_param_info_mcore = get_param_info(self.model, self.dtype) # Create a map that maps any local parameter name to a list of global parameter names. # This map is repeatedly used by parameter gatherring phase during refit of every step. self.local_key_to_global_keys = get_local_key_to_global_keys( - self.model, state_dict_info=refit_param_info_mcore + self.model, state_dict_info=self.refit_param_info_mcore ) # Collect tensor metadata for refit - self.refit_param_info_hf = {} - for key, _ in refit_param_info_mcore: + refit_param_info_hf = {} + for key, _ in self.refit_param_info_mcore: # gather megatron params gathered_megatron_params = gather_params( self.model, @@ -1394,15 +1394,14 @@ def prepare_refit_info(self) -> None: ) # collect tensor metadata for name, tensor in gathered_hf_params.items(): - self.refit_param_info_hf[name] = ( + refit_param_info_hf[name] = ( tensor.shape, tensor.dtype, tensor.numel(), ) - return self.refit_param_info_hf + return refit_param_info_hf - @torch.no_grad() def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: """Prepare Megatron model weights for IPC transfer to vLLM. @@ -1411,10 +1410,6 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: """ from nemo_rl.utils.nvml import get_free_memory_bytes - # Get parameter info for refit - ## param_info: list of ((name, shape, dtype), size_in_bytes) tuples - refit_param_info_mcore = get_param_info(self.model, self.dtype) - # Collect current available memory for refit ## Get current device index from torch device_idx = torch.cuda.current_device() @@ -1424,7 +1419,7 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.2") total_available_bytes *= float(memory_ratio) - return refit_param_info_mcore, total_available_bytes + return self.refit_param_info_mcore, total_available_bytes def get_handle_from_tensor(self, tensor: torch.Tensor) -> tuple[str, Any]: """Get IPC handle from a tensor.""" From 35cffa728cfb4dc58b4a35a32b4e0e5753f0983b Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Mon, 21 Jul 2025 21:06:08 +0000 Subject: [PATCH 08/14] better context manager name Signed-off-by: Zhiyu Li --- nemo_rl/algorithms/grpo.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 82cb681081..9f9273ac8c 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -417,7 +417,11 @@ def refit_policy_generation( policy_generation.prepare_for_generation(tags=["weights"]) # Create a context manager that does nothing when timer is None - timer_context = timer.time("prepare_for_generation/refit_policy_generation") if timer is not None else nullcontext() + timer_context = ( + timer.time("prepare_for_generation/transfer_and_update_weights") + if timer is not None + else nullcontext() + ) with timer_context: # update weights update_success = False From c3f015ad57616fbc0ed0096c9e6310cd17aec95c Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Mon, 21 Jul 2025 21:18:30 +0000 Subject: [PATCH 09/14] update .gitmodules Signed-off-by: Zhiyu Li --- .gitmodules | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index 2a588f3a89..09342d3495 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,7 @@ [submodule "3rdparty/NeMo"] path = 3rdparty/NeMo-workspace/NeMo url = https://github.com/NVIDIA/NeMo.git - branch = ashors/nemorl-qwen3 + branch = zhiyul/yukih/prepare-refit-info shallow = true [submodule "3rdparty/Megatron-LM"] path = 3rdparty/Megatron-LM-workspace/Megatron-LM From 3a21a4e934447e6e723fe97649078968fdfa5fe4 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Mon, 21 Jul 2025 21:24:50 +0000 Subject: [PATCH 10/14] lint Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm_backend.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index d5b6adb9d0..16c6526a0c 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -166,10 +166,20 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): dtype_to_offset = defaultdict(lambda: 0) for key in list_keys: shape, dtype, size = self.state_dict_info[key] - weights.append((key, dtype_to_packed_tensor[dtype][dtype_to_offset[dtype]:dtype_to_offset[dtype]+size].view(*shape))) + weights.append( + ( + key, + dtype_to_packed_tensor[dtype][ + dtype_to_offset[dtype] : dtype_to_offset[dtype] + size + ].view(*shape), + ) + ) dtype_to_offset[dtype] += size - expected_sizes = {dtype: tensor.numel() for dtype, tensor in dtype_to_packed_tensor.items()} + expected_sizes = { + dtype: tensor.numel() + for dtype, tensor in dtype_to_packed_tensor.items() + } assert dtype_to_offset == expected_sizes, ( f"Packed tensor size mismatch: expected sizes from keys list {expected_sizes} != actual packed tensor sizes {dtype_to_offset}. " f"This indicates the keys list order doesn't match the order used when packing tensors." From 4eb076fad6cf8ca415c0290d5757e9f9526e4a9b Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Mon, 21 Jul 2025 21:38:06 +0000 Subject: [PATCH 11/14] remove unused comment Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm_backend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index 16c6526a0c..090393c0a3 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -161,7 +161,6 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): tensor = func(*list_args) dtype_to_packed_tensor[dtype] = tensor - # Unpack tensor to weights using pre-calculated offsets weights = [] dtype_to_offset = defaultdict(lambda: 0) for key in list_keys: From 898903a4c68eddbf3a5fbbc6cdac7063ca4b289c Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 29 Jul 2025 15:02:49 -0700 Subject: [PATCH 12/14] fix tests failure in dtensor Signed-off-by: Zhiyu Li --- nemo_rl/models/policy/dtensor_policy_worker.py | 5 ++--- nemo_rl/models/policy/megatron_policy_worker.py | 12 +++--------- nemo_rl/models/policy/utils.py | 7 +++++++ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index b590032408..58199dc8a9 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -67,6 +67,7 @@ import_class_from_path, is_vllm_v1_engine_enabled, sliding_window_overwrite, + get_handle_from_tensor, ) from nemo_rl.utils.native_checkpoint import ( load_checkpoint, @@ -1186,8 +1187,6 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: @torch.no_grad() def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]: - from torch.multiprocessing.reductions import reduce_tensor - assert self._held_sharded_state_dict_reference is not None, ( "prepare_weights_for_ipc must be called before get_weights_ipc_handles" ) @@ -1217,7 +1216,7 @@ def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]: # Create handles for the tensors all_handles = [] for key, p in converted_params.items(): - handle = reduce_tensor(p.detach()) + handle = get_handle_from_tensor(p) all_handles.append((key, handle)) # (pack_tensor_for_ipc: bool, handles: list) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 8171e371a9..d2bfc48ab6 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -121,6 +121,7 @@ get_gpu_info, get_megatron_checkpoint_dir, get_runtime_env_for_policy_worker, + get_handle_from_tensor, ) TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) @@ -1421,13 +1422,6 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: return self.refit_param_info_mcore, total_available_bytes - def get_handle_from_tensor(self, tensor: torch.Tensor) -> tuple[str, Any]: - """Get IPC handle from a tensor.""" - from torch.multiprocessing.reductions import reduce_tensor - - # skip serializing the function for better refit performance - return reduce_tensor(tensor.detach())[1:] - # Temporary fix, 'keys' is a kwarg due to some sort of ray bug @torch.no_grad() def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: @@ -1496,7 +1490,7 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: # Create IPC handles for consolidated tensors all_handles = [ - (dtype, self.get_handle_from_tensor(tensor)) + (dtype, get_handle_from_tensor(tensor)) for dtype, tensor in packed_tensors.items() ] @@ -1511,7 +1505,7 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: else: all_handles = [] for key, tensor in gathered_hf_params.items(): - handle = self.get_handle_from_tensor(tensor) + handle = get_handle_from_tensor(tensor) all_handles.append((key, handle)) self._held_gather_buffer = gathered_hf_params serialized = (False, all_handles) diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index a61e5e20b7..9a84eeb4c0 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -220,3 +220,10 @@ def get_megatron_checkpoint_dir() -> str: ) print(f"Using default megatron checkpoint dir: {checkpoint_dir}") return checkpoint_dir + +def get_handle_from_tensor(tensor: torch.Tensor) -> tuple[str, Any]: + """Get IPC handle from a tensor.""" + from torch.multiprocessing.reductions import reduce_tensor + + # skip serializing the function for better refit performance + return reduce_tensor(tensor.detach())[1:] From 0d267655f6d59cfa3ca03fa84defed819ae0785b Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 29 Jul 2025 16:23:15 -0700 Subject: [PATCH 13/14] lint Signed-off-by: Zhiyu Li --- nemo_rl/models/policy/dtensor_policy_worker.py | 2 +- nemo_rl/models/policy/megatron_policy_worker.py | 2 +- nemo_rl/models/policy/utils.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 58199dc8a9..50ea698d18 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -63,11 +63,11 @@ from nemo_rl.models.policy.utils import ( configure_expandable_segments, get_gpu_info, + get_handle_from_tensor, get_runtime_env_for_policy_worker, import_class_from_path, is_vllm_v1_engine_enabled, sliding_window_overwrite, - get_handle_from_tensor, ) from nemo_rl.utils.native_checkpoint import ( load_checkpoint, diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index d2bfc48ab6..d50ffa0aa0 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -119,9 +119,9 @@ from nemo_rl.models.policy.utils import ( configure_expandable_segments, get_gpu_info, + get_handle_from_tensor, get_megatron_checkpoint_dir, get_runtime_env_for_policy_worker, - get_handle_from_tensor, ) TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 9a84eeb4c0..678d5d89c0 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -221,7 +221,8 @@ def get_megatron_checkpoint_dir() -> str: print(f"Using default megatron checkpoint dir: {checkpoint_dir}") return checkpoint_dir -def get_handle_from_tensor(tensor: torch.Tensor) -> tuple[str, Any]: + +def get_handle_from_tensor(tensor: torch.Tensor) -> tuple[Any]: """Get IPC handle from a tensor.""" from torch.multiprocessing.reductions import reduce_tensor From dd7485001a99cbc5862e4e1416830009dc38d20f Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 29 Jul 2025 16:39:04 -0700 Subject: [PATCH 14/14] add nemo_rl/models/generation/vllm_backend.py to pyrefly whitelist Signed-off-by: Zhiyu Li --- pyrefly.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyrefly.toml b/pyrefly.toml index f3bc05a639..4038ce9737 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -84,6 +84,7 @@ project-includes = [ "nemo_rl/models/dtensor/parallelize.py", "nemo_rl/models/generation/__init__.py", "nemo_rl/models/generation/interfaces.py", + "nemo_rl/models/generation/vllm_backend.py", "nemo_rl/models/huggingface/__init__.py", "nemo_rl/models/megatron/__init__.py", "nemo_rl/models/megatron/community_import.py",