From 36a2d68a6a5889d985e52cb1b323b5846d779f6e Mon Sep 17 00:00:00 2001
From: Guyue Huang
Date: Mon, 23 Jun 2025 20:43:30 -0700
Subject: [PATCH 1/6] Pack tensor at megatron worker and unpack at vllm worker

Signed-off-by: Guyue Huang
---
 nemo_rl/models/generation/vllm_backend.py    | 17 +++++---
 .../models/policy/megatron_policy_worker.py  | 42 +++++++++++++++----
 2 files changed, 46 insertions(+), 13 deletions(-)

diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py
index c40aea4418..25fd941016 100644
--- a/nemo_rl/models/generation/vllm_backend.py
+++ b/nemo_rl/models/generation/vllm_backend.py
@@ -58,18 +58,23 @@ def update_weights_from_ipc_handles(self, ipc_handles):
         try:
             # Get handles for this device
             device_uuid = self.report_device_id()
-            handles = ipc_handles[device_uuid]
+            deserialized = ipc_handles[device_uuid]
             device_id = self.device.index
             weights = []
-            # Process each handle to get the tensor
-            for name, handle in handles:
-                func, args = handle
+            all_handles, key_to_type_and_offset_and_size_in_big_tensor = deserialized
+            type_to_packed_big_tensor_size = {}
+            for k, tensor_handle in all_handles:
+                func, args = tensor_handle
                 list_args = list(args)
-                # Update device ID to match the current device
                 list_args[6] = device_id
                 tensor = func(*list_args)
-                weights.append((name, tensor))
+                type_to_packed_big_tensor_size[k] = tensor
+
+            for key, shape, type, offset, size in key_to_type_and_offset_and_size_in_big_tensor:
+                assert offset+size <= type_to_packed_big_tensor_size[type].numel()
+                tensor = type_to_packed_big_tensor_size[type][offset:offset+size].clone().reshape(shape)
+                weights.append((key, tensor))
 
             # Load weights into the model
             self.model_runner.model.load_weights(weights=weights)
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 40944a0f84..680ceffe78 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -1367,18 +1367,46 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
         from torch.multiprocessing.reductions import reduce_tensor
 
         # Create IPC handles for each parameter
-        all_handles = []
+
+        # pack tensors in gathered_hf_params to a big tensor
+        type_to_packed_big_tensor_size = defaultdict(lambda : 0)
+        key_to_type_and_offset_and_size_in_big_tensor = []
         for key, tensor in gathered_hf_params.items():
+            key_to_type_and_offset_and_size_in_big_tensor.append(
+                (
+                    key,
+                    tensor.shape,
+                    tensor.dtype,
+                    type_to_packed_big_tensor_size[tensor.dtype],
+                    tensor.numel()
+                )
+            )
+            type_to_packed_big_tensor_size[tensor.dtype] += tensor.numel()
+
+        type_to_packed_big_tensor_size = {
+            k: torch.empty(v, device=tensor.device, dtype=k, requires_grad=False)
+            for k, v in type_to_packed_big_tensor_size.items()
+        }
+        for i, (key, tensor) in enumerate(gathered_hf_params.items()):
+            k, shape, dtype, offset, size = key_to_type_and_offset_and_size_in_big_tensor[i]
+            assert k == key
+            type_to_packed_big_tensor_size[dtype][offset:offset+size] = tensor.detach().view(-1)
+
+        all_handles = []
+        for dtype, tensor in type_to_packed_big_tensor_size.items():
             handle = reduce_tensor(tensor.detach())
-            all_handles.append((key, handle))
+            all_handles.append((dtype, handle))
 
         # Store references to avoid premature garbage collection
-        self._held_gather_buffer = gathered_hf_params
-        shapes = {}
-        for key, tensor in gathered_hf_params.items():
-            shapes[key] = tensor.shape
-        return {device_uuid: all_handles}
+        self._held_gather_buffer = type_to_packed_big_tensor_size
+        # shapes = {}
+        # for key, tensor in gathered_hf_params.items():
+        #     shapes[key] = tensor.shape
+
+        serialized = (all_handles, key_to_type_and_offset_and_size_in_big_tensor)
+
+        return {device_uuid: serialized}
 
     def prepare_for_lp_inference(self):
         self.model = self.move_model(self.model, "cuda", move_grads=False)

From 4cfcd7b8ad2cf402b22e6d5d23824ed8f6a61063 Mon Sep 17 00:00:00 2001
From: Guyue Huang
Date: Fri, 27 Jun 2025 11:05:24 -0700
Subject: [PATCH 2/6] Cleanup code

Signed-off-by: Guyue Huang
---
 nemo_rl/models/generation/vllm_backend.py    | 20 ++---
 .../models/policy/megatron_policy_worker.py  | 76 ++++++++++---------
 2 files changed, 52 insertions(+), 44 deletions(-)

diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py
index 25fd941016..293cf4e8a3 100644
--- a/nemo_rl/models/generation/vllm_backend.py
+++ b/nemo_rl/models/generation/vllm_backend.py
@@ -58,22 +58,24 @@ def update_weights_from_ipc_handles(self, ipc_handles):
         try:
             # Get handles for this device
             device_uuid = self.report_device_id()
-            deserialized = ipc_handles[device_uuid]
+            serialized = ipc_handles[device_uuid]
             device_id = self.device.index
             weights = []
-            all_handles, key_to_type_and_offset_and_size_in_big_tensor = deserialized
-            type_to_packed_big_tensor_size = {}
-            for k, tensor_handle in all_handles:
+            all_handles, tensor_metadata = serialized
+
+            # Extract tensors from IPC handles
+            dtype_to_packed_tensor = {}
+            for dtype, tensor_handle in all_handles:
                 func, args = tensor_handle
                 list_args = list(args)
                 list_args[6] = device_id
                 tensor = func(*list_args)
-                type_to_packed_big_tensor_size[k] = tensor
-
-            for key, shape, type, offset, size in key_to_type_and_offset_and_size_in_big_tensor:
-                assert offset+size <= type_to_packed_big_tensor_size[type].numel()
-                tensor = type_to_packed_big_tensor_size[type][offset:offset+size].clone().reshape(shape)
+                dtype_to_packed_tensor[dtype] = tensor
+
+            # Unpack tensor to weights
+            for key, (shape, dtype, offset, size) in tensor_metadata.items():
+                tensor = dtype_to_packed_tensor[dtype][offset:offset+size].clone().reshape(shape)
                 weights.append((key, tensor))
 
             # Load weights into the model
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 680ceffe78..69947dc589 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -1367,46 +1367,52 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
         from torch.multiprocessing.reductions import reduce_tensor
 
         # Create IPC handles for each parameter
+        # Pack tensors in gathered_hf_params into consolidated tensors by dtype
+        # First calculate total size needed for each dtype
+        type_to_total_size = defaultdict(lambda: 0)
+        tensor_metadata = dict()
 
-        # pack tensors in gathered_hf_params to a big tensor
-        type_to_packed_big_tensor_size = defaultdict(lambda : 0)
-        key_to_type_and_offset_and_size_in_big_tensor = []
         for key, tensor in gathered_hf_params.items():
-            key_to_type_and_offset_and_size_in_big_tensor.append(
-                (
-                    key,
-                    tensor.shape,
-                    tensor.dtype,
-                    type_to_packed_big_tensor_size[tensor.dtype],
-                    tensor.numel()
-                )
+            tensor_metadata[key] = (
+                tensor.shape,  # shape of the tensor
+                tensor.dtype,  # dtype of the tensor
+                type_to_total_size[tensor.dtype],  # offset of the tensor
+                                                   # in packed buffer
+                tensor.numel()  # size of the tensor
             )
+            type_to_total_size[tensor.dtype] += tensor.numel()
+
+        #
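The mechanism after this cleanup is: one flat buffer per dtype plus a per-key record of (shape, dtype, offset, numel). Below is a minimal CPU-only sketch of that round trip with illustrative names (pack/unpack are not functions in this repo); in the real code the per-dtype buffers cross the process boundary as CUDA IPC handles rather than Python references:

    from collections import defaultdict

    import torch

    def pack(params):
        # Per-dtype running totals double as offsets: each tensor's slot in
        # the flat buffer starts where the previous tensor of that dtype ended.
        totals = defaultdict(int)
        metadata = {}
        for key, t in params.items():
            metadata[key] = (t.shape, t.dtype, totals[t.dtype], t.numel())
            totals[t.dtype] += t.numel()
        # One flat buffer per dtype; copy every tensor into its slice.
        packed = {dt: torch.empty(n, dtype=dt) for dt, n in totals.items()}
        for key, t in params.items():
            _, dt, off, n = metadata[key]
            packed[dt][off:off + n].copy_(t.detach().reshape(-1))
        return packed, metadata

    def unpack(packed, metadata):
        # Rebuild each named weight from its slice of the per-dtype buffer.
        return {
            key: packed[dt][off:off + n].view(*shape)
            for key, (shape, dt, off, n) in metadata.items()
        }

    params = {
        "w1": torch.randn(4, 8),
        "w2": torch.randn(16),
        "mask": torch.ones(3, 3, dtype=torch.int64),
    }
    packed, meta = pack(params)
    restored = unpack(packed, meta)
    for key in params:
        assert torch.equal(params[key], restored[key])

The payoff is that the number of IPC handles drops from one per parameter to one per dtype, at the price of one extra device-side copy into the consolidated buffers on the sender.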
 Allocate consolidated tensors for each dtype
+        packed_tensors = {
+            dtype: torch.empty(
+                total_size,
+                device=next(iter(gathered_hf_params.values())).device,
+                dtype=dtype,
+                requires_grad=False
             )
-            type_to_packed_big_tensor_size[tensor.dtype] += tensor.numel()
-
-        type_to_packed_big_tensor_size = {
-            k: torch.empty(v, device=tensor.device, dtype=k, requires_grad=False)
-            for k, v in type_to_packed_big_tensor_size.items()
+            for dtype, total_size in type_to_total_size.items()
         }
-        for i, (key, tensor) in enumerate(gathered_hf_params.items()):
-            k, shape, dtype, offset, size = key_to_type_and_offset_and_size_in_big_tensor[i]
-            assert k == key
-            type_to_packed_big_tensor_size[dtype][offset:offset+size] = tensor.detach().view(-1)
-
-        all_handles = []
-        for dtype, tensor in type_to_packed_big_tensor_size.items():
-            handle = reduce_tensor(tensor.detach())
-            all_handles.append((dtype, handle))
-
-        # Store references to avoid premature garbage collection
-
-        self._held_gather_buffer = type_to_packed_big_tensor_size
-        # shapes = {}
-        # for key, tensor in gathered_hf_params.items():
-        #     shapes[key] = tensor.shape
-
-        serialized = (all_handles, key_to_type_and_offset_and_size_in_big_tensor)
-
-        return {device_uuid: serialized}
+
+        # Copy tensors into consolidated buffers
+        for key, tensor in gathered_hf_params.items():
+            metadata = tensor_metadata[key]
+            _, dtype, offset, size = metadata
+            packed_tensors[dtype][offset:offset + size].copy_(
+                tensor.detach().view(-1)
+            )
+
+        # Create IPC handles for consolidated tensors
+        all_handles = [
+            (dtype, reduce_tensor(tensor.detach()))
+            for dtype, tensor in packed_tensors.items()
+        ]
+
+        # Store reference to prevent garbage collection
+        self._held_gather_buffer = packed_tensors
+
+        serialized = (all_handles, tensor_metadata)
+
+        return {device_uuid: serialized}
 
     def prepare_for_lp_inference(self):
         self.model = self.move_model(self.model, "cuda", move_grads=False)

From 1fff12813c6296b370651cae1456a565f03b74d1 Mon Sep 17 00:00:00 2001
From: Guyue Huang
Date: Fri, 27 Jun 2025 11:55:14 -0700
Subject: [PATCH 3/6] use view instead of copy for unpacking

Signed-off-by: Guyue Huang
---
 nemo_rl/models/generation/vllm_backend.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py
index 293cf4e8a3..c013a027b3 100644
--- a/nemo_rl/models/generation/vllm_backend.py
+++ b/nemo_rl/models/generation/vllm_backend.py
@@ -73,9 +73,10 @@ def update_weights_from_ipc_handles(self, ipc_handles):
                 tensor = func(*list_args)
                 dtype_to_packed_tensor[dtype] = tensor
 
-            # Unpack tensor to weights
+            # Unpack tensor to weights. Here we only return a view of the tensor to avoid
+            # using extra memory.
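The switch from .clone().reshape(shape) to .view(*shape) matters because a view aliases the packed buffer, while a clone materializes a second copy of every weight during refit. A small illustrative sketch (plain CPU tensors, not the worker code):

    import torch

    packed = torch.arange(12, dtype=torch.float32)

    aliased = packed[0:6].view(2, 3)            # shares packed's storage
    copied = packed[0:6].clone().reshape(2, 3)  # allocates a second buffer

    assert aliased.data_ptr() == packed.data_ptr()
    assert copied.data_ptr() != packed.data_ptr()

    # Writes through the view are visible in the packed buffer, which is
    # presumably why the sender keeps self._held_gather_buffer alive: the
    # views are only valid while the packed storage is.
    aliased[0, 0] = 42.0
    assert packed[0].item() == 42.0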
             for key, (shape, dtype, offset, size) in tensor_metadata.items():
-                tensor = dtype_to_packed_tensor[dtype][offset:offset+size].clone().reshape(shape)
+                tensor = dtype_to_packed_tensor[dtype][offset:offset+size].view(*shape)
                 weights.append((key, tensor))
 
             # Load weights into the model

From 7dafae7ccb50721ac3af8bea06e709270984a92c Mon Sep 17 00:00:00 2001
From: Guyue Huang
Date: Fri, 27 Jun 2025 13:08:09 -0700
Subject: [PATCH 4/6] Make it configurable to use packing or not

Signed-off-by: Guyue Huang
---
 nemo_rl/models/generation/vllm_backend.py    | 47 +++++----
 .../models/policy/megatron_policy_worker.py  | 96 +++++++++++--------
 2 files changed, 86 insertions(+), 57 deletions(-)

diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py
index c013a027b3..c2168857e5 100644
--- a/nemo_rl/models/generation/vllm_backend.py
+++ b/nemo_rl/models/generation/vllm_backend.py
@@ -58,26 +58,39 @@ def update_weights_from_ipc_handles(self, ipc_handles):
         try:
             # Get handles for this device
             device_uuid = self.report_device_id()
-            serialized = ipc_handles[device_uuid]
+            handles = ipc_handles[device_uuid]
+            is_tensor_packed = handles[0]
+            if is_tensor_packed:
+                _, all_handles, tensor_metadata = handles
+            else:
+                _, name_and_handle_list = handles
+
             device_id = self.device.index
             weights = []
-            all_handles, tensor_metadata = serialized
-
-            # Extract tensors from IPC handles
-            dtype_to_packed_tensor = {}
-            for dtype, tensor_handle in all_handles:
-                func, args = tensor_handle
-                list_args = list(args)
-                list_args[6] = device_id
-                tensor = func(*list_args)
-                dtype_to_packed_tensor[dtype] = tensor
-
-            # Unpack tensor to weights. Here we only return a view of the tensor to avoid
-            # using extra memory.
-            for key, (shape, dtype, offset, size) in tensor_metadata.items():
-                tensor = dtype_to_packed_tensor[dtype][offset:offset+size].view(*shape)
-                weights.append((key, tensor))
+
+            if is_tensor_packed:
+                # Extract packed tensor from IPC handle
+                dtype_to_packed_tensor = {}
+                for dtype, tensor_handle in all_handles:
+                    func, args = tensor_handle
+                    list_args = list(args)
+                    list_args[6] = device_id
+                    tensor = func(*list_args)
+                    dtype_to_packed_tensor[dtype] = tensor
+
+                # Unpack tensor to weights. Here we only return a view of the tensor to avoid
+                # using extra memory.
+                for key, (shape, dtype, offset, size) in tensor_metadata.items():
+                    tensor = dtype_to_packed_tensor[dtype][offset:offset+size].view(*shape)
+                    weights.append((key, tensor))
+            else:
+                # Process each handle to get the tensor
+                for name, handle in name_and_handle_list:
+                    func, args = handle
+                    list_args = list(args)
+                    list_args[6] = device_id
+                    tensor = func(*list_args)
+                    weights.append((name, tensor))
 
             # Load weights into the model
             self.model_runner.model.load_weights(weights=weights)
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 69947dc589..fd07da9f4b 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -1367,50 +1367,66 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
         from torch.multiprocessing.reductions import reduce_tensor
 
         # Create IPC handles for each parameter
-        # Pack tensors in gathered_hf_params into consolidated tensors by dtype
-        # First calculate total size needed for each dtype
-        type_to_total_size = defaultdict(lambda: 0)
-        tensor_metadata = dict()
-
-        for key, tensor in gathered_hf_params.items():
-            tensor_metadata[key] = (
-                tensor.shape,  # shape of the tensor
-                tensor.dtype,  # dtype of the tensor
-                type_to_total_size[tensor.dtype],  # offset of the tensor
-                                                   # in packed buffer
-                tensor.numel()  # size of the tensor
-            )
-            type_to_total_size[tensor.dtype] += tensor.numel()
-
-        # Allocate consolidated tensors for each dtype
-        packed_tensors = {
-            dtype: torch.empty(
-                total_size,
-                device=next(iter(gathered_hf_params.values())).device,
-                dtype=dtype,
-                requires_grad=False
-            )
-            for dtype, total_size in type_to_total_size.items()
-        }
+        tensor_number_threshold = os.getenv(
+            "NEMO_RL_MEGATRON_IPC_TENSOR_PACKING_THRESHOLD", "32"
+        )  # an arbitrary threshold
+        if len(gathered_hf_params) >= int(tensor_number_threshold):
+            pack_tensor_for_ipc = True
+        else:
+            pack_tensor_for_ipc = False
+
+        if pack_tensor_for_ipc:
+            # Pack tensors in gathered_hf_params into consolidated tensors by dtype
+            # First calculate total size needed for each dtype
+            type_to_total_size = defaultdict(lambda: 0)
+            tensor_metadata = dict()
+
+            for key, tensor in gathered_hf_params.items():
+                tensor_metadata[key] = (
+                    tensor.shape,  # shape of the tensor
+                    tensor.dtype,  # dtype of the tensor
+                    type_to_total_size[tensor.dtype],  # offset of the tensor
+                                                       # in packed buffer
+                    tensor.numel()  # size of the tensor
+                )
+                type_to_total_size[tensor.dtype] += tensor.numel()
+
+            # Allocate consolidated tensors for each dtype
+            packed_tensors = {
+                dtype: torch.empty(
+                    total_size,
+                    device=next(iter(gathered_hf_params.values())).device,
+                    dtype=dtype,
+                    requires_grad=False
+                )
+                for dtype, total_size in type_to_total_size.items()
+            }
 
-        # Copy tensors into consolidated buffers
-        for key, tensor in gathered_hf_params.items():
-            metadata = tensor_metadata[key]
-            _, dtype, offset, size = metadata
-            packed_tensors[dtype][offset:offset + size].copy_(
-                tensor.detach().view(-1)
-            )
+            # Copy tensors into consolidated buffers
+            for key, tensor in gathered_hf_params.items():
+                metadata = tensor_metadata[key]
+                _, dtype, offset, size = metadata
+                packed_tensors[dtype][offset:offset + size].copy_(
+                    tensor.detach().view(-1)
+                )
 
-        # Create IPC handles for consolidated tensors
-        all_handles = [
-            (dtype, reduce_tensor(tensor.detach()))
-            for dtype, tensor in packed_tensors.items()
-        ]
+            # Create IPC handles for consolidated tensors
+            all_handles = [
+                (dtype,
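Whether packing pays off depends on how many tensors are in the update, so the decision becomes an env-var threshold. A hedged sketch of just that gate (the env var and its default come from the patch; the function name is illustrative):

    import os

    def should_pack_for_ipc(num_gathered_params):
        # Packing collapses one IPC handle per parameter into one per dtype,
        # at the cost of an extra device-side copy, so it is only enabled for
        # updates with enough tensors. The default of 32 mirrors the patch.
        threshold = int(
            os.getenv("NEMO_RL_MEGATRON_IPC_TENSOR_PACKING_THRESHOLD", "32")
        )
        return num_gathered_params >= threshold

    # With the default threshold: a 3-tensor update ships plain per-tensor
    # handles, a 100-tensor update ships per-dtype packed buffers.
    print(should_pack_for_ipc(3), should_pack_for_ipc(100))

On the receiving side, the first element of the serialized tuple tells the vLLM worker which of the two layouts it got.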
 reduce_tensor(tensor.detach()))
+                for dtype, tensor in packed_tensors.items()
+            ]
 
-        # Store reference to prevent garbage collection
-        self._held_gather_buffer = packed_tensors
+            # Store reference to prevent garbage collection
+            self._held_gather_buffer = packed_tensors
 
-        serialized = (all_handles, tensor_metadata)
+            serialized = (pack_tensor_for_ipc, all_handles, tensor_metadata)
+        else:
+            all_handles = []
+            for key, tensor in gathered_hf_params.items():
+                handle = reduce_tensor(tensor.detach())
+                all_handles.append((key, handle))
+            self._held_gather_buffer = gathered_hf_params
+            serialized = (False, all_handles)
 
         return {device_uuid: serialized}

From e9953f60bfb8af526e5ea055510b3684b4f77e78 Mon Sep 17 00:00:00 2001
From: Guyue Huang
Date: Mon, 7 Jul 2025 09:52:49 -0700
Subject: [PATCH 5/6] Add two flags for profiling: RAY_PROFILING and
 NEMO_RL_TORCH_PROFILE_REFIT

Signed-off-by: Guyue Huang
---
 nemo_rl/algorithms/grpo.py                   | 13 +++++++++---
 nemo_rl/models/generation/vllm.py            | 21 +++++++++++++++++++
 nemo_rl/models/megatron/refit_utils.py       |  1 -
 .../models/policy/megatron_policy_worker.py  | 21 +++++++++++++++++++
 4 files changed, 52 insertions(+), 4 deletions(-)

diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index ac9335bcd4..fdd60c1d9b 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -20,6 +20,7 @@
 import torch
 from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import PreTrainedTokenizerBase
+from contextlib import nullcontext
 
 from nemo_rl.algorithms.interfaces import LossFunction
 from nemo_rl.algorithms.loss_functions import (
@@ -362,6 +363,9 @@ def setup(
     print(" " * 18 + "SETUP COMPLETE")
     print("=" * 60 + "\n")
 
+    if os.getenv("RAY_PROFILING", None) == "1":
+        ray.timeline(filename=os.getenv("NEMO_RL_RAY_TIMELINE_FILE", "/tmp/ray_timeline.json"))
+
     return (
         policy,
         policy_generation,
@@ -386,6 +390,7 @@ def refit_policy_generation(
     policy_generation: GenerationInterface,
     colocated_inference: bool,
     _refit_buffer_size_gb: Optional[int] = None,
+    timer: Optional[Timer] = None,
 ) -> None:
     """Refit the policy generation interface with the latest policy weights.
 
@@ -410,8 +415,10 @@ def refit_policy_generation(
         print(f"[Refit] Number of splits: {len(grouped_param_keys)}")
         # do update
         for keys in grouped_param_keys:
-            ipc_handles = policy.get_weights_ipc_handles(keys)
-            update_success = policy_generation.update_weights(ipc_handles)
+            with timer.time("prepare_for_generation/get_weights_ipc_handles") if timer else nullcontext():
+                ipc_handles = policy.get_weights_ipc_handles(keys)
+            with timer.time("prepare_for_generation/update_weights") if timer else nullcontext():
+                update_success = policy_generation.update_weights(ipc_handles)
             if not update_success:
                 break
     else:
@@ -528,7 +535,7 @@ def grpo_train(
             with timer.time("prepare_for_generation"):
                 if NEED_REFIT and POLICY_GENERATION_STALE:
                     refit_policy_generation(
-                        policy, policy_generation, colocated_inference
+                        policy, policy_generation, colocated_inference, timer=timer
                     )
                     POLICY_GENERATION_STALE = False
                 else:
diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py
index cbb603f74e..cb37e29934 100644
--- a/nemo_rl/models/generation/vllm.py
+++ b/nemo_rl/models/generation/vllm.py
@@ -348,6 +348,21 @@ def _patch_vllm_init_workers_ray():
         else:
             self.llm = vllm.LLM(**llm_kwargs)
 
+        # torch profiler
+        import socket
+        if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.is_model_owner and (0 in bundle_indices):
+            self.profiler = torch.profiler.profile(
+                activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+                with_stack=True,
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                    "grpo_refit_trace/update_weights_from_ipc_handles",
+                    worker_name=f"{socket.gethostname()}_vllm_worker_{self.rank}",
+                    use_gzip=True,
+                ),
+            )
+        else:
+            self.profiler = None
+
     def init_collective(self, data: int, ip: str, port: int, world_size: int) -> None:
         self.llm.collective_rpc(
             "init_collective",
@@ -879,11 +894,17 @@ def update_weights_from_ipc_handles(self, data: dict[str, Any]) -> bool:
                 "update_weights_from_ipc_handles cannot be used with async_engine=True. Use update_weights_from_ipc_handles_async instead."
             )
 
+        if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
+            self.profiler.start()
+
         result_or_coro = self.llm.collective_rpc(
             "update_weights_from_ipc_handles", args=(data,)
         )
         worker_result = result_or_coro[0]
 
+        if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
+            self.profiler.stop()
+
         if not worker_result:
             print(
                 f"Error: Worker failed to update weights. Result: {worker_result}"
diff --git a/nemo_rl/models/megatron/refit_utils.py b/nemo_rl/models/megatron/refit_utils.py
index e6ca825e0a..07081b4532 100644
--- a/nemo_rl/models/megatron/refit_utils.py
+++ b/nemo_rl/models/megatron/refit_utils.py
@@ -169,5 +169,4 @@ def gather_params(
         if k is not None:
             gathered_params[k] = p
 
-    print(f"Time taken to gather params: {time.perf_counter() - st}")
     return gathered_params
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 8543d07aa9..3969014162 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -669,6 +669,21 @@ def __init__(
             state_dict_info=self.prepare_weights_for_ipc()[0]
         )
 
+        # torch profiler
+        import socket
+        if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and get_rank_safe() == 0:
+            self.profiler = torch.profiler.profile(
+                activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+                with_stack=True,
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                    "grpo_refit_trace/get_weight_ipc_handles",
+                    worker_name=f"{socket.gethostname()}_megatron_policy_worker_0",
+                    use_gzip=True,
+                ),
+            )
+        else:
+            self.profiler = None
+
     def configure_worker(self, num_gpus: int, bundle_indices: Optional[tuple] = None):
         USE_EXPANDABLE_SEGMENTS = False  # Disabling this right now as it seems to cause vLLM refit issues with Ampere
         if USE_EXPANDABLE_SEGMENTS:
@@ -1436,6 +1451,9 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
         Returns:
             Dict mapping device UUID to list of (mapped_key, handle) tuples
         """
+        if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
+            self.profiler.start()
+
         if self._held_gather_buffer is not None:
             del self._held_gather_buffer
             self._held_gather_buffer = None
@@ -1516,6 +1534,9 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
             self._held_gather_buffer = gathered_hf_params
             serialized = (False, all_handles)
 
+        if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
+            self.profiler.stop()
+
         return {device_uuid: serialized}
 
     def prepare_for_lp_inference(self):
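Both workers wrap refit in the same opt-in torch.profiler pattern. A reduced, runnable sketch (CPU activity only; the patch also enables the CUDA activity and additionally gates on rank or model ownership):

    import os
    import socket

    import torch

    def make_refit_profiler(trace_dir, worker_name):
        # Opt in via NEMO_RL_TORCH_PROFILE_REFIT=1, as in the patch. Traces
        # are written as gzipped files that TensorBoard's profiler plugin
        # (or chrome://tracing) can load.
        if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") != "1":
            return None
        return torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            with_stack=True,
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                trace_dir,
                worker_name=f"{socket.gethostname()}_{worker_name}",
                use_gzip=True,
            ),
        )

    profiler = make_refit_profiler("refit_trace_demo", "demo_worker_0")
    if profiler is not None:
        profiler.start()
    x = torch.randn(256, 256)
    y = (x @ x).sum()  # stand-in for the profiled refit work
    if profiler is not None:
        profiler.stop()  # on_trace_ready fires here and writes the trace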
From 114e3b040a22f7141b447e941ab0554679a8af06 Mon Sep 17 00:00:00 2001
From: Guyue Huang
Date: Mon, 7 Jul 2025 11:35:37 -0700
Subject: [PATCH 6/6] Profile refit for one rank, one time

Signed-off-by: Guyue Huang
---
 nemo_rl/models/generation/vllm.py               | 17 +++++++++++++++--
 nemo_rl/models/policy/megatron_policy_worker.py |  8 ++++++--
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py
index cb37e29934..49249eb685 100644
--- a/nemo_rl/models/generation/vllm.py
+++ b/nemo_rl/models/generation/vllm.py
@@ -362,6 +362,7 @@ def _patch_vllm_init_workers_ray():
             )
         else:
             self.profiler = None
+        self.maybe_profile_refit_times = 0
 
     def init_collective(self, data: int, ip: str, port: int, world_size: int) -> None:
         self.llm.collective_rpc(
@@ -895,7 +896,9 @@ def update_weights_from_ipc_handles(self, data: dict[str, Any]) -> bool:
             )
 
         if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
-            self.profiler.start()
+            self.maybe_profile_refit_times += 1
+            if self.maybe_profile_refit_times == 3:
+                self.profiler.start()
 
         result_or_coro = self.llm.collective_rpc(
             "update_weights_from_ipc_handles", args=(data,)
         )
         worker_result = result_or_coro[0]
 
         if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
-            self.profiler.stop()
+            if self.maybe_profile_refit_times == 3:
+                self.profiler.stop()
 
         if not worker_result:
             print(
@@ -937,6 +941,11 @@ async def update_weights_from_ipc_handles_async(self, data: dict[str, Any]) -> b
                 "update_weights_from_ipc_handles_async can only be used with async_engine=True. Use update_weights_from_ipc_handles instead."
             )
 
+        if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
+            self.maybe_profile_refit_times += 1
+            if self.maybe_profile_refit_times == 3:
+                self.profiler.start()
+
         result_or_coro = await self.llm.collective_rpc(
             "update_weights_from_ipc_handles", args=(data,)
         )
@@ -948,6 +957,10 @@ async def update_weights_from_ipc_handles_async(self, data: dict[str, Any]) -> b
 
         worker_result = worker_results[0]
 
+        if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
+            if self.maybe_profile_refit_times == 3:
+                self.profiler.stop()
+
         if not worker_result:
             print(
                 f"Error: Worker failed to update weights. Result: {worker_result}"
             )
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 3969014162..6eea16d018 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -683,6 +683,7 @@ def __init__(
             )
         else:
             self.profiler = None
+        self.maybe_profile_refit_times = 0
 
     def configure_worker(self, num_gpus: int, bundle_indices: Optional[tuple] = None):
         USE_EXPANDABLE_SEGMENTS = False  # Disabling this right now as it seems to cause vLLM refit issues with Ampere
         if USE_EXPANDABLE_SEGMENTS:
@@ -1452,7 +1453,9 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
             Dict mapping device UUID to list of (mapped_key, handle) tuples
         """
         if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
-            self.profiler.start()
+            self.maybe_profile_refit_times += 1
+            if self.maybe_profile_refit_times == 3:
+                self.profiler.start()
 
         if self._held_gather_buffer is not None:
             del self._held_gather_buffer
             self._held_gather_buffer = None
@@ -1535,7 +1538,8 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
             serialized = (False, all_handles)
 
         if os.getenv("NEMO_RL_TORCH_PROFILE_REFIT", "0") == "1" and self.profiler is not None:
-            self.profiler.stop()
+            if self.maybe_profile_refit_times == 3:
+                self.profiler.stop()
 
         return {device_uuid: serialized}
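The final patch arms the profiler on exactly the third refit call, so warm-up iterations are skipped and a finished profiler is never restarted. The same gating, isolated into a self-contained sketch (the counter name comes from the patch; the wrapper class and the profiled work are illustrative):

    import torch

    class RefitProfilerGate:
        """Run a callable repeatedly, profiling only the Nth invocation."""

        def __init__(self, profiler, profile_on_call=3):
            self.profiler = profiler
            self.profile_on_call = profile_on_call
            self.maybe_profile_refit_times = 0

        def run(self, fn):
            self.maybe_profile_refit_times += 1
            hit = (
                self.profiler is not None
                and self.maybe_profile_refit_times == self.profile_on_call
            )
            if hit:
                self.profiler.start()
            try:
                return fn()
            finally:
                if hit:
                    self.profiler.stop()

    gate = RefitProfilerGate(torch.profiler.profile(), profile_on_call=3)
    for _ in range(5):
        gate.run(lambda: torch.randn(64, 64) @ torch.randn(64, 64))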