From e1d7a52054b0ab953d859f917b3c4ceecc24056b Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Mon, 6 Oct 2025 18:03:49 -0700 Subject: [PATCH 1/5] draft Signed-off-by: Youngeun Kwon --- .../models/generation/vllm/vllm_backend.py | 76 +++++++++++++++---- .../models/policy/dtensor_policy_worker_v2.py | 33 ++++++-- .../models/policy/megatron_policy_worker.py | 29 ++++++- nemo_rl/utils/packed_tensor.py | 60 +++++++++++++++ pyrefly.toml | 1 + 5 files changed, 176 insertions(+), 23 deletions(-) create mode 100644 nemo_rl/utils/packed_tensor.py diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 895506e4b4..937cc3ba86 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math from collections import defaultdict from typing import Any, Optional @@ -18,6 +19,7 @@ from torch.multiprocessing.reductions import rebuild_cuda_tensor from nemo_rl.utils.nsys import wrap_with_nvtx_name +from nemo_rl.utils.packed_tensor import get_target_packed_tensor_size, unpack_tensor try: import vllm # noqa: F401 @@ -186,23 +188,67 @@ def update_weights_from_collective(self) -> bool: "Please call prepare_refit_info when initializing the worker." ) - try: - for name, (shape, dtype) in self.state_dict_info.items(): - weight = torch.empty(shape, dtype=dtype, device="cuda") - self.model_update_group.broadcast(weight, src=0) + def load_model_weights(weights): + """Load model weights. - from nemo_rl.models.generation import fp8 + Args: + weights: List[(name, tensor)] - if fp8.is_fp8_model(self.model_runner.vllm_config): - # the fp8 load_weights additionally casts bf16 weights into fp8 - fp8.load_weights([(name, weight)], self.model_runner) - else: - self.model_runner.model.load_weights(weights=[(name, weight)]) - except Exception as e: - print( - f"Error in VllmInternalWorkerExtension.update_weights_from_collective: {e}" - ) - return False + Returns: + None + """ + from nemo_rl.models.generation import fp8 + + if fp8.is_fp8_model(self.model_runner.vllm_config): + # the fp8 load_weights additionally casts bf16 weights into fp8 + fp8.load_weights(weights, self.model_runner) + else: + self.model_runner.model.load_weights(weights=weights) + + target_packed_tensor_size = get_target_packed_tensor_size() + + hf_params_iterator = iter(self.state_dict_info.items()) + + while True: + # Form a packed tensor + packed_tensor_meta_data = [] + packed_tensor_sizes = [] + offset = 0 + try: + while True: + # Form a packed tensor + name, (shape, dtype) = next(hf_params_iterator) + tensor_size = math.prod(shape) * dtype.itemsize + packed_tensor_meta_data.append( + (name, shape, dtype, offset, tensor_size) + ) + packed_tensor_sizes.append(tensor_size) + offset += tensor_size + if sum(packed_tensor_sizes) > target_packed_tensor_size: + break + # Create a packed tensor and broadcast it + packed_tensor = torch.empty( + sum(packed_tensor_sizes), dtype=torch.uint8, device="cuda" + ) + self.model_update_group.broadcast(packed_tensor, src=0) + # Load the packed tensor into the model + load_model_weights( + unpack_tensor(packed_tensor, packed_tensor_meta_data) + ) + except StopIteration: + break + finally: + if len(packed_tensor_meta_data) > 0: + # do the last broadcast + # Create a packed tensor and broadcast it + packed_tensor = torch.empty( + sum(packed_tensor_sizes), dtype=torch.uint8, device="cuda" + ) + self.model_update_group.broadcast(packed_tensor, src=0) + # Load the packed tensor into the model + load_model_weights( + unpack_tensor(packed_tensor, packed_tensor_meta_data) + ) return True diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index ee5cea0b5d..c826344ba8 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -94,6 +94,7 @@ ) from nemo_rl.utils.checkpoint import CheckpointingConfig from nemo_rl.utils.nsys import wrap_with_nvtx_name +from nemo_rl.utils.packed_tensor import get_target_packed_tensor_size, pack_tensor @ray.remote( @@ -1767,11 +1768,33 @@ def broadcast_weights_for_collective(self) -> None: self.model = self.move_to_cuda(self.model) # Broadcast the weights for collective communication - for _, tensor in self.model.state_dict().items(): - if isinstance(tensor, DTensor): - tensor = tensor.full_tensor() - tensor = tensor.to(self.dtype, non_blocking=True) - self.model_update_group.broadcast(tensor.data, src=0) + target_packed_tensor_size = get_target_packed_tensor_size() + weight_iterator = iter(self.model.state_dict().items()) + + while True: + # Form a packed tensor + packed_tensor_list = [] + packed_tensor_sizes = [] + try: + while True: + name, tensor = next(weight_iterator) + if isinstance(tensor, DTensor): + tensor = tensor.full_tensor() + tensor = tensor.to(self.dtype, non_blocking=True) + packed_tensor_list.append((name, tensor)) + packed_tensor_sizes.append( + tensor.view(torch.uint8).view(-1).numel() + ) + if sum(packed_tensor_sizes) > target_packed_tensor_size: + break + packed_tensor = pack_tensor(packed_tensor_list) + self.model_update_group.broadcast(packed_tensor, src=0) + except StopIteration: + break + finally: + if len(packed_tensor_list) > 0: + packed_tensor = pack_tensor(packed_tensor_list) + self.model_update_group.broadcast(packed_tensor, src=0) # Manually move model to cpu for cpu offload case # cpu offload needs model on CPU before model forward diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 326ae9fe61..7657fa2fec 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -128,6 +128,7 @@ get_runtime_env_for_policy_worker, ) from nemo_rl.utils.nsys import wrap_with_nvtx_name +from nemo_rl.utils.packed_tensor import get_target_packed_tensor_size, pack_tensor TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) @@ -1738,9 +1739,31 @@ def broadcast_weights_for_collective(self) -> None: [self.model], show_progress=False, ) - # broadcast from train rank 0 to all other ranks (training and inference) - for _, tensor in hf_params_generator: - self.model_update_group.broadcast(tensor, src=0) + target_packed_tensor_size = get_target_packed_tensor_size() + + while True: + # Form a packed tensor + packed_tensor_list = [] + packed_tensor_sizes = [] + try: + while True: + name, tensor = next(hf_params_generator) + packed_tensor_list.append((name, tensor)) + packed_tensor_sizes.append( + tensor.view(torch.uint8).view(-1).numel() + ) + if sum(packed_tensor_sizes) > target_packed_tensor_size: + break + # Concatenate the tensors into a single tensor and broadcast it + packed_tensor = pack_tensor(packed_tensor_list) + self.model_update_group.broadcast(packed_tensor, src=0) + except StopIteration: + break + finally: + if len(packed_tensor_list) > 0: + # do the last broadcast + packed_tensor = pack_tensor(packed_tensor_list) + self.model_update_group.broadcast(packed_tensor, src=0) def prepare_for_lp_inference(self): self.model = self.move_model(self.model, "cuda", move_grads=False) diff --git a/nemo_rl/utils/packed_tensor.py b/nemo_rl/utils/packed_tensor.py new file mode 100644 index 0000000000..3f1d4bd3ac --- /dev/null +++ b/nemo_rl/utils/packed_tensor.py @@ -0,0 +1,60 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Any, List, Tuple + +import torch + + +def get_target_packed_tensor_size(): + packed_tensor_bucket_size = os.getenv("NRL_PACKED_TENSOR_SIZE_TARGET_IN_MB", "500") + return int(packed_tensor_bucket_size) * 1024 * 1024 + + +def pack_tensor(packed_tensor_list: list[torch.Tensor]) -> torch.Tensor: + """Pack a list of tensors into a single tensor.""" + # Perform batched concatenation with torch.cat + return torch.cat( + [tensor.view(torch.uint8).view(-1) for _, tensor in packed_tensor_list], dim=0 + ) + + +def unpack_tensor( + packed_tensor: torch.Tensor, meta_data_list: list[Any] +) -> List[Tuple[str, torch.Tensor]]: + """Unpack a single tensor into a list of tensors. + + Args: + packed_tensor: the packed torch.uint8 tensor to unpack + meta_data_list: List[(name, shape, dtype, offset, tensor_size)] + + Returns: + unpacked List[(name, tensor)] + """ + unpacked_list = [] + # Perform batched split with torch.split_with_sizes + packed_tensor_sizes = [tensor_size for _, _, _, _, tensor_size in meta_data_list] + unpacked_tensor = packed_tensor.split_with_sizes(packed_tensor_sizes) + + for i, tensor in enumerate(unpacked_tensor): + # unpacked_list = List[(name, torch.Tensor.view(dtype).view(*shape))] + unpacked_list.append( + ( + meta_data_list[i][0], + tensor.view(meta_data_list[i][2]).view(*meta_data_list[i][1]), + ) + ) + + return unpacked_list diff --git a/pyrefly.toml b/pyrefly.toml index 03cbdc3f3b..9484fff639 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -109,6 +109,7 @@ project-includes = [ "nemo_rl/utils/automodel_checkpoint.py", "nemo_rl/utils/nsys.py", "nemo_rl/utils/nvml.py", + "nemo_rl/utils/packed_tensor.py", "nemo_rl/utils/prefetch_venvs.py", "nemo_rl/utils/timer.py", "nemo_rl/utils/venvs.py", From 085282d075f93d36a19119332cd7ab43dc5f4ba5 Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Wed, 8 Oct 2025 10:55:11 -0700 Subject: [PATCH 2/5] update dtensor v1 path Signed-off-by: Youngeun Kwon --- .../models/policy/dtensor_policy_worker.py | 33 ++++++++++++++++--- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index cfe524be8d..5337477a0c 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -83,6 +83,7 @@ save_checkpoint, ) from nemo_rl.utils.nsys import wrap_with_nvtx_name +from nemo_rl.utils.packed_tensor import get_target_packed_tensor_size, pack_tensor @contextmanager @@ -1806,11 +1807,33 @@ def broadcast_weights_for_collective(self) -> None: self.model = self.move_to_cuda(self.model) # Broadcast the weights for collective communication - for _, tensor in self.model.state_dict().items(): - if isinstance(tensor, DTensor): - tensor = tensor.full_tensor() - tensor = tensor.to(self.dtype, non_blocking=True) - self.model_update_group.broadcast(tensor.data, src=0) + target_packed_tensor_size = get_target_packed_tensor_size() + weight_iterator = iter(self.model.state_dict().items()) + + while True: + # Form a packed tensor + packed_tensor_list = [] + packed_tensor_sizes = [] + try: + while True: + name, tensor = next(weight_iterator) + if isinstance(tensor, DTensor): + tensor = tensor.full_tensor() + tensor = tensor.to(self.dtype, non_blocking=True) + packed_tensor_list.append((name, tensor)) + packed_tensor_sizes.append( + tensor.view(torch.uint8).view(-1).numel() + ) + if sum(packed_tensor_sizes) > target_packed_tensor_size: + break + packed_tensor = pack_tensor(packed_tensor_list) + self.model_update_group.broadcast(packed_tensor, src=0) + except StopIteration: + break + finally: + if len(packed_tensor_list) > 0: + packed_tensor = pack_tensor(packed_tensor_list) + self.model_update_group.broadcast(packed_tensor, src=0) # Manually move model to cpu for cpu offload case # cpu offload needs model on CPU before model forward From e12c0fd57059a8d8e052ef2313d81cabd96e8f99 Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Wed, 8 Oct 2025 15:44:51 -0700 Subject: [PATCH 3/5] code refactoring Signed-off-by: Youngeun Kwon --- .../models/generation/vllm/vllm_backend.py | 70 +++------ .../models/policy/dtensor_policy_worker.py | 45 ++---- .../models/policy/dtensor_policy_worker_v2.py | 45 ++---- .../models/policy/megatron_policy_worker.py | 33 +---- nemo_rl/utils/packed_tensor.py | 137 ++++++++++++++---- 5 files changed, 171 insertions(+), 159 deletions(-) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 937cc3ba86..14db473ae9 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from collections import defaultdict from typing import Any, Optional @@ -19,7 +18,7 @@ from torch.multiprocessing.reductions import rebuild_cuda_tensor from nemo_rl.utils.nsys import wrap_with_nvtx_name -from nemo_rl.utils.packed_tensor import get_target_packed_tensor_size, unpack_tensor +from nemo_rl.utils.packed_tensor import packed_broadcast_consumer try: import vllm # noqa: F401 @@ -188,67 +187,38 @@ def update_weights_from_collective(self) -> bool: "Please call prepare_refit_info when initializing the worker." ) - def load_model_weights(weights): + def _load_model_weights(weights, model_runner): """Load model weights. Args: weights: List[(name, tensor)] + model_runner: vLLM ModelRunner Returns: None """ from nemo_rl.models.generation import fp8 - if fp8.is_fp8_model(self.model_runner.vllm_config): + if fp8.is_fp8_model(model_runner.vllm_config): # the fp8 load_weights additionally casts bf16 weights into fp8 - fp8.load_weights(weights, self.model_runner) + fp8.load_weights(weights, model_runner) else: - self.model_runner.model.load_weights(weights=weights) + model_runner.model.load_weights(weights=weights) - target_packed_tensor_size = get_target_packed_tensor_size() - - hf_params_iterator = iter(self.state_dict_info.items()) - - while True: - # Form a packed tensor - packed_tensor_meta_data = [] - packed_tensor_sizes = [] - offset = 0 - try: - while True: - # Form a packed tensor - name, (shape, dtype) = next(hf_params_iterator) - tensor_size = math.prod(shape) * dtype.itemsize - packed_tensor_meta_data.append( - (name, shape, dtype, offset, tensor_size) - ) - packed_tensor_sizes.append(tensor_size) - offset += tensor_size - if sum(packed_tensor_sizes) > target_packed_tensor_size: - break - # Create a packed tensor and broadcast it - packed_tensor = torch.empty( - sum(packed_tensor_sizes), dtype=torch.uint8, device="cuda" - ) - self.model_update_group.broadcast(packed_tensor, src=0) - # Load the packed tensor into the model - load_model_weights( - unpack_tensor(packed_tensor, packed_tensor_meta_data) - ) - except StopIteration: - break - finally: - if len(packed_tensor_meta_data) > 0: - # do the last broadcast - # Create a packed tensor and broadcast it - packed_tensor = torch.empty( - sum(packed_tensor_sizes), dtype=torch.uint8, device="cuda" - ) - self.model_update_group.broadcast(packed_tensor, src=0) - # Load the packed tensor into the model - load_model_weights( - unpack_tensor(packed_tensor, packed_tensor_meta_data) - ) + load_model_weight_func = lambda x: _load_model_weights(x, self.model_runner) + + try: + packed_broadcast_consumer( + iterator=iter(self.state_dict_info.items()), + group=self.model_update_group, + src=0, + post_unpack_func=load_model_weight_func, + ) + except Exception as e: + print( + f"Error in VllmInternalWorkerExtension.update_weights_from_collective: {e}" + ) + return False return True diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 5337477a0c..78d1d37ad7 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -83,7 +83,7 @@ save_checkpoint, ) from nemo_rl.utils.nsys import wrap_with_nvtx_name -from nemo_rl.utils.packed_tensor import get_target_packed_tensor_size, pack_tensor +from nemo_rl.utils.packed_tensor import packed_broadcast_producer @contextmanager @@ -1806,34 +1806,21 @@ def broadcast_weights_for_collective(self) -> None: ) self.model = self.move_to_cuda(self.model) - # Broadcast the weights for collective communication - target_packed_tensor_size = get_target_packed_tensor_size() - weight_iterator = iter(self.model.state_dict().items()) - - while True: - # Form a packed tensor - packed_tensor_list = [] - packed_tensor_sizes = [] - try: - while True: - name, tensor = next(weight_iterator) - if isinstance(tensor, DTensor): - tensor = tensor.full_tensor() - tensor = tensor.to(self.dtype, non_blocking=True) - packed_tensor_list.append((name, tensor)) - packed_tensor_sizes.append( - tensor.view(torch.uint8).view(-1).numel() - ) - if sum(packed_tensor_sizes) > target_packed_tensor_size: - break - packed_tensor = pack_tensor(packed_tensor_list) - self.model_update_group.broadcast(packed_tensor, src=0) - except StopIteration: - break - finally: - if len(packed_tensor_list) > 0: - packed_tensor = pack_tensor(packed_tensor_list) - self.model_update_group.broadcast(packed_tensor, src=0) + def _dtensor_post_iter_func(tensor, dtype): + if isinstance(tensor, DTensor): + tensor = tensor.full_tensor() + tensor = tensor.to(dtype, non_blocking=True) + return tensor + + # param_iterator will return (name, tensor), we only need tensor + dtensor_post_iter_func = lambda x: _dtensor_post_iter_func(x[1], self.dtype) + + packed_broadcast_producer( + iterator=iter(self.model.state_dict().items()), + group=self.model_update_group, + src=0, + post_iter_func=dtensor_post_iter_func, + ) # Manually move model to cpu for cpu offload case # cpu offload needs model on CPU before model forward diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index c826344ba8..893060a3a2 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -94,7 +94,7 @@ ) from nemo_rl.utils.checkpoint import CheckpointingConfig from nemo_rl.utils.nsys import wrap_with_nvtx_name -from nemo_rl.utils.packed_tensor import get_target_packed_tensor_size, pack_tensor +from nemo_rl.utils.packed_tensor import packed_broadcast_producer @ray.remote( @@ -1767,34 +1767,21 @@ def broadcast_weights_for_collective(self) -> None: ) self.model = self.move_to_cuda(self.model) - # Broadcast the weights for collective communication - target_packed_tensor_size = get_target_packed_tensor_size() - weight_iterator = iter(self.model.state_dict().items()) - - while True: - # Form a packed tensor - packed_tensor_list = [] - packed_tensor_sizes = [] - try: - while True: - name, tensor = next(weight_iterator) - if isinstance(tensor, DTensor): - tensor = tensor.full_tensor() - tensor = tensor.to(self.dtype, non_blocking=True) - packed_tensor_list.append((name, tensor)) - packed_tensor_sizes.append( - tensor.view(torch.uint8).view(-1).numel() - ) - if sum(packed_tensor_sizes) > target_packed_tensor_size: - break - packed_tensor = pack_tensor(packed_tensor_list) - self.model_update_group.broadcast(packed_tensor, src=0) - except StopIteration: - break - finally: - if len(packed_tensor_list) > 0: - packed_tensor = pack_tensor(packed_tensor_list) - self.model_update_group.broadcast(packed_tensor, src=0) + def _dtensor_post_iter_func(tensor, dtype): + if isinstance(tensor, DTensor): + tensor = tensor.full_tensor() + tensor = tensor.to(dtype, non_blocking=True) + return tensor + + # param_iterator will return (name, tensor), we only need tensor + dtensor_post_iter_func = lambda x: _dtensor_post_iter_func(x[1], self.dtype) + + packed_broadcast_producer( + iterator=iter(self.model.state_dict().items()), + group=self.model_update_group, + src=0, + post_iter_func=dtensor_post_iter_func, + ) # Manually move model to cpu for cpu offload case # cpu offload needs model on CPU before model forward diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 7657fa2fec..b7931962ea 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -128,7 +128,7 @@ get_runtime_env_for_policy_worker, ) from nemo_rl.utils.nsys import wrap_with_nvtx_name -from nemo_rl.utils.packed_tensor import get_target_packed_tensor_size, pack_tensor +from nemo_rl.utils.packed_tensor import packed_broadcast_producer TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) @@ -1739,31 +1739,14 @@ def broadcast_weights_for_collective(self) -> None: [self.model], show_progress=False, ) - target_packed_tensor_size = get_target_packed_tensor_size() - while True: - # Form a packed tensor - packed_tensor_list = [] - packed_tensor_sizes = [] - try: - while True: - name, tensor = next(hf_params_generator) - packed_tensor_list.append((name, tensor)) - packed_tensor_sizes.append( - tensor.view(torch.uint8).view(-1).numel() - ) - if sum(packed_tensor_sizes) > target_packed_tensor_size: - break - # Concatenate the tensors into a single tensor and broadcast it - packed_tensor = pack_tensor(packed_tensor_list) - self.model_update_group.broadcast(packed_tensor, src=0) - except StopIteration: - break - finally: - if len(packed_tensor_list) > 0: - # do the last broadcast - packed_tensor = pack_tensor(packed_tensor_list) - self.model_update_group.broadcast(packed_tensor, src=0) + # param_iterator will return (name, tensor), we only need tensor + packed_broadcast_producer( + iterator=hf_params_generator, + group=self.model_update_group, + src=0, + post_iter_func=lambda x: x[1], + ) def prepare_for_lp_inference(self): self.model = self.move_model(self.model, "cuda", move_grads=False) diff --git a/nemo_rl/utils/packed_tensor.py b/nemo_rl/utils/packed_tensor.py index 3f1d4bd3ac..7d79c788ef 100644 --- a/nemo_rl/utils/packed_tensor.py +++ b/nemo_rl/utils/packed_tensor.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import os from typing import Any, List, Tuple @@ -23,38 +24,122 @@ def get_target_packed_tensor_size(): return int(packed_tensor_bucket_size) * 1024 * 1024 -def pack_tensor(packed_tensor_list: list[torch.Tensor]) -> torch.Tensor: - """Pack a list of tensors into a single tensor.""" - # Perform batched concatenation with torch.cat - return torch.cat( - [tensor.view(torch.uint8).view(-1) for _, tensor in packed_tensor_list], dim=0 - ) +def packed_broadcast_producer(iterator, group, src, post_iter_func): + """Broadcast a list of tensors in a packed manner. + Args: + iterator: iterator of model parameters. Returns a tuple of (name, tensor) + group: process group (vllm PyNcclCommunicator) + src: source rank (0 in current implementation) + post_iter_func: function to apply to each tensor before packing, should return a tensor + + Returns: + None + + """ + target_packed_tensor_size = get_target_packed_tensor_size() + + while True: + # Form a packed tensor + packing_tensor_list = [] + packing_tensor_sizes = 0 + try: + while True: + # Apply backend specific post processing and then convert to linearized uint8 tensor + tensor = post_iter_func(next(iterator)).view(torch.uint8).view(-1) + packing_tensor_list.append(tensor) + packing_tensor_sizes += tensor.view(torch.uint8).numel() + if packing_tensor_sizes > target_packed_tensor_size: + break + # Pack the tensors and call broadcast collective + packed_tensor = torch.cat(packing_tensor_list, dim=0) + group.broadcast(packed_tensor, src=src) + except StopIteration: + # do the last broadcast if there are remaining tensors + if len(packing_tensor_list) > 0: + packed_tensor = torch.cat(packing_tensor_list, dim=0) + group.broadcast(packed_tensor, src=src) + break -def unpack_tensor( - packed_tensor: torch.Tensor, meta_data_list: list[Any] -) -> List[Tuple[str, torch.Tensor]]: - """Unpack a single tensor into a list of tensors. + +def packed_broadcast_consumer(iterator, group, src, post_unpack_func): + """Consume a packed tensor and unpack it into a list of tensors. Args: - packed_tensor: the packed torch.uint8 tensor to unpack - meta_data_list: List[(name, shape, dtype, offset, tensor_size)] + iterator: iterator of model parameters. Returns a tuple of (name, tensor) + group: process group (vllm PyNcclCommunicator) + src: source rank (0 in current implementation) + post_unpack_func: function to apply to each tensor after unpacking Returns: - unpacked List[(name, tensor)] + None + """ - unpacked_list = [] - # Perform batched split with torch.split_with_sizes - packed_tensor_sizes = [tensor_size for _, _, _, _, tensor_size in meta_data_list] - unpacked_tensor = packed_tensor.split_with_sizes(packed_tensor_sizes) - - for i, tensor in enumerate(unpacked_tensor): - # unpacked_list = List[(name, torch.Tensor.view(dtype).view(*shape))] - unpacked_list.append( - ( - meta_data_list[i][0], - tensor.view(meta_data_list[i][2]).view(*meta_data_list[i][1]), + + def unpack_tensor( + packed_tensor: torch.Tensor, meta_data_list: list[Any] + ) -> List[Tuple[str, torch.Tensor]]: + """Unpack a single tensor into a list of tensors. + + Args: + packed_tensor: the packed torch.uint8 tensor to unpack + meta_data_list: List[(name, shape, dtype, offset, tensor_size)] + + Returns: + unpacked List[(name, tensor)] + """ + unpacked_list = [] + # Perform batched split with torch.split_with_sizes + packed_tensor_sizes = [ + tensor_size for _, _, _, _, tensor_size in meta_data_list + ] + unpacked_tensor = packed_tensor.split_with_sizes(packed_tensor_sizes) + + for i, tensor in enumerate(unpacked_tensor): + # unpacked_list = List[(name, torch.Tensor.view(dtype).view(*shape))] + unpacked_list.append( + ( + meta_data_list[i][0], + tensor.view(meta_data_list[i][2]).view(*meta_data_list[i][1]), + ) ) - ) - return unpacked_list + return unpacked_list + + target_packed_tensor_size = get_target_packed_tensor_size() + + while True: + # Form a packed tensor + packing_tensor_meta_data = [] + packing_tensor_sizes = 0 + offset = 0 + try: + while True: + # Form a packed tensor + name, (shape, dtype) = next(iterator) + tensor_size = math.prod(shape) * dtype.itemsize + packing_tensor_meta_data.append( + (name, shape, dtype, offset, tensor_size) + ) + packing_tensor_sizes += tensor_size + offset += tensor_size + if packing_tensor_sizes > target_packed_tensor_size: + break + # Create a packed tensor and broadcast it + packed_tensor = torch.empty( + packing_tensor_sizes, dtype=torch.uint8, device="cuda" + ) + group.broadcast(packed_tensor, src=src) + # Load the packed tensor into the model + post_unpack_func(unpack_tensor(packed_tensor, packing_tensor_meta_data)) + except StopIteration: + # do the last broadcast if there are remaining tensors + if len(packing_tensor_meta_data) > 0: + # Create a packed tensor and broadcast it + packed_tensor = torch.empty( + packing_tensor_sizes, dtype=torch.uint8, device="cuda" + ) + group.broadcast(packed_tensor, src=src) + # Load the packed tensor into the model + post_unpack_func(unpack_tensor(packed_tensor, packing_tensor_meta_data)) + break From 1b915218c1f625fff80f51d04809598bf89dee5e Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Thu, 9 Oct 2025 13:31:31 -0700 Subject: [PATCH 4/5] use the same env var Signed-off-by: Youngeun Kwon --- nemo_rl/utils/packed_tensor.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/nemo_rl/utils/packed_tensor.py b/nemo_rl/utils/packed_tensor.py index 7d79c788ef..09ea9d1e50 100644 --- a/nemo_rl/utils/packed_tensor.py +++ b/nemo_rl/utils/packed_tensor.py @@ -14,14 +14,20 @@ import math import os +from functools import lru_cache from typing import Any, List, Tuple import torch +@lru_cache(maxsize=1) def get_target_packed_tensor_size(): - packed_tensor_bucket_size = os.getenv("NRL_PACKED_TENSOR_SIZE_TARGET_IN_MB", "500") - return int(packed_tensor_bucket_size) * 1024 * 1024 + memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.01") + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + total_memory_bytes = props.total_memory + target_size = int(total_memory_bytes * float(memory_ratio)) + return target_size def packed_broadcast_producer(iterator, group, src, post_iter_func): @@ -90,19 +96,17 @@ def unpack_tensor( """ unpacked_list = [] # Perform batched split with torch.split_with_sizes - packed_tensor_sizes = [ - tensor_size for _, _, _, _, tensor_size in meta_data_list - ] + packed_tensor_sizes = list(map(lambda x: x[4], meta_data_list)) unpacked_tensor = packed_tensor.split_with_sizes(packed_tensor_sizes) - for i, tensor in enumerate(unpacked_tensor): - # unpacked_list = List[(name, torch.Tensor.view(dtype).view(*shape))] - unpacked_list.append( - ( - meta_data_list[i][0], - tensor.view(meta_data_list[i][2]).view(*meta_data_list[i][1]), - ) + # unpacked_list = List[(name, torch.Tensor.view(dtype).view(*shape))] + unpacked_list = [ + ( + meta_data_list[i][0], + tensor.view(meta_data_list[i][2]).view(*meta_data_list[i][1]), ) + for i, tensor in enumerate(unpacked_tensor) + ] return unpacked_list From f6e334643f8c27e4392a4201a40ff49db54cec98 Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Fri, 10 Oct 2025 17:53:33 -0700 Subject: [PATCH 5/5] add tests Signed-off-by: Youngeun Kwon --- nemo_rl/utils/packed_tensor.py | 3 +- tests/unit/utils/test_packed_tensor.py | 199 +++++++++++++++++++++++++ 2 files changed, 201 insertions(+), 1 deletion(-) create mode 100644 tests/unit/utils/test_packed_tensor.py diff --git a/nemo_rl/utils/packed_tensor.py b/nemo_rl/utils/packed_tensor.py index 09ea9d1e50..53eb915b64 100644 --- a/nemo_rl/utils/packed_tensor.py +++ b/nemo_rl/utils/packed_tensor.py @@ -26,7 +26,8 @@ def get_target_packed_tensor_size(): device = torch.device("cuda") props = torch.cuda.get_device_properties(device) total_memory_bytes = props.total_memory - target_size = int(total_memory_bytes * float(memory_ratio)) + # max size is 5GB + target_size = min(int(total_memory_bytes * float(memory_ratio)), 5 * 1024**3) return target_size diff --git a/tests/unit/utils/test_packed_tensor.py b/tests/unit/utils/test_packed_tensor.py new file mode 100644 index 0000000000..6d321bd32a --- /dev/null +++ b/tests/unit/utils/test_packed_tensor.py @@ -0,0 +1,199 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +import pytest +import torch + +from nemo_rl.utils.packed_tensor import ( + packed_broadcast_consumer, + packed_broadcast_producer, +) + + +class MockCommunicationGroup: + """Mock communication group for testing broadcast operations.""" + + def __init__(self): + self.broadcasted_tensors = [] + self.broadcast_count = 0 + + def broadcast(self, tensor, src): + """Mock broadcast that stores the tensor for later verification.""" + # Store a copy of the tensor + self.broadcasted_tensors.append(tensor.clone()) + self.broadcast_count += 1 + + +class MockConsumerCommunicationGroup: + """Mock communication group for consumer that returns pre-stored tensors.""" + + def __init__(self, tensors_to_return): + self.tensors_to_return = tensors_to_return + self.current_index = 0 + + def broadcast(self, tensor, src): + """Mock broadcast that fills the tensor with pre-stored data.""" + if self.current_index < len(self.tensors_to_return): + tensor.copy_(self.tensors_to_return[self.current_index]) + self.current_index += 1 + + +def create_mock_model_params(): + """Create mock model parameters for testing.""" + params = [ + ("layer1.weight", torch.randn(10, 20, dtype=torch.float32)), + ("layer1.bias", torch.randn(10, dtype=torch.float32)), + ("layer2.weight", torch.randn(20, 30, dtype=torch.float32)), + ("layer2.bias", torch.randn(20, dtype=torch.float32)), + ("layer3.weight", torch.randn(30, 40, dtype=torch.float16)), + ] + return params + + +def create_mock_state_dict_info(params): + """Create state dict info (name -> (shape, dtype)) from params.""" + return {name: (tensor.shape, tensor.dtype) for name, tensor in params} + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_packed_broadcast_producer_consumer_roundtrip(): + """Test that producer and consumer work together correctly.""" + # Create mock parameters + params = create_mock_model_params() + + # Move params to CUDA + params_cuda = [(name, tensor.cuda()) for name, tensor in params] + + # Create mock communication group for producer + producer_group = MockCommunicationGroup() + + # Mock the target size to force packing + target_size = 2000 + with patch( + "nemo_rl.utils.packed_tensor.get_target_packed_tensor_size", + return_value=target_size, + ): + # Post-iter function that just returns the tensor + post_iter_func = lambda x: x[1] + + # Run producer + packed_broadcast_producer( + iterator=iter(params_cuda), + group=producer_group, + src=0, + post_iter_func=post_iter_func, + ) + + # Now test consumer with the broadcasted tensors + consumer_group = MockConsumerCommunicationGroup( + producer_group.broadcasted_tensors + ) + + # Create state dict info for consumer + state_dict_info = create_mock_state_dict_info(params_cuda) + + # Store unpacked tensors + unpacked_tensors = {} + + def post_unpack_func(tensor_list): + """Store unpacked tensors for verification.""" + for name, tensor in tensor_list: + unpacked_tensors[name] = tensor + + # Run consumer + packed_broadcast_consumer( + iterator=iter(state_dict_info.items()), + group=consumer_group, + src=0, + post_unpack_func=post_unpack_func, + ) + + # Verify all parameters were unpacked + assert len(unpacked_tensors) == len(params) + + # Verify each tensor matches the original + for name, original_tensor in params_cuda: + assert name in unpacked_tensors + unpacked = unpacked_tensors[name] + + # Check shape and dtype + assert unpacked.shape == original_tensor.shape + assert unpacked.dtype == original_tensor.dtype + + # Check values are close (accounting for floating point precision) + assert torch.allclose(unpacked, original_tensor, rtol=1e-5, atol=1e-7) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_packed_broadcast_single_large_tensor(): + """Test with a single tensor larger than target size.""" + # Create a large tensor + large_tensor = torch.randn(1000, 1000, dtype=torch.float32).cuda() + params = [("large_weight", large_tensor)] + + # Create mock communication group + mock_group = MockCommunicationGroup() + + # Small target size to force the tensor to exceed it + with patch( + "nemo_rl.utils.packed_tensor.get_target_packed_tensor_size", return_value=100 + ): + packed_broadcast_producer( + iterator=iter(params), + group=mock_group, + src=0, + post_iter_func=lambda x: x[1], + ) + + # Should still broadcast the tensor + assert mock_group.broadcast_count == 1 + assert len(mock_group.broadcasted_tensors) == 1 + + # Verify the size matches the large tensor + expected_size = large_tensor.numel() * large_tensor.element_size() + assert mock_group.broadcasted_tensors[0].numel() == expected_size + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_packed_broadcast_multiple_batches(): + """Test that tensors are properly batched when exceeding target size.""" + # Create many small tensors + params = [ + (f"weight_{i}", torch.randn(10, 10, dtype=torch.float32).cuda()) + for i in range(20) + ] + + # Create mock communication group + mock_group = MockCommunicationGroup() + + # Small target size to force multiple batches + with patch( + "nemo_rl.utils.packed_tensor.get_target_packed_tensor_size", return_value=2000 + ): + packed_broadcast_producer( + iterator=iter(params), + group=mock_group, + src=0, + post_iter_func=lambda x: x[1], + ) + + # Should have multiple broadcasts + assert mock_group.broadcast_count > 1 + + # Total size should match sum of all tensors + total_broadcasted_size = sum(t.numel() for t in mock_group.broadcasted_tensors) + expected_total_size = sum(t.numel() * t.element_size() for _, t in params) + assert total_broadcasted_size == expected_total_size