diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py
index 895506e4b4..14db473ae9 100644
--- a/nemo_rl/models/generation/vllm/vllm_backend.py
+++ b/nemo_rl/models/generation/vllm/vllm_backend.py
@@ -18,6 +18,7 @@
 from torch.multiprocessing.reductions import rebuild_cuda_tensor
 
 from nemo_rl.utils.nsys import wrap_with_nvtx_name
+from nemo_rl.utils.packed_tensor import packed_broadcast_consumer
 
 try:
     import vllm  # noqa: F401
@@ -186,18 +187,33 @@ def update_weights_from_collective(self) -> bool:
                 "Please call prepare_refit_info when initializing the worker."
             )
 
-        try:
-            for name, (shape, dtype) in self.state_dict_info.items():
-                weight = torch.empty(shape, dtype=dtype, device="cuda")
-                self.model_update_group.broadcast(weight, src=0)
+        def _load_model_weights(weights, model_runner):
+            """Load model weights.
+
+            Args:
+                weights: List[(name, tensor)]
+                model_runner: vLLM ModelRunner
-
-                from nemo_rl.models.generation import fp8
-
+
+            Returns:
+                None
+            """
+            from nemo_rl.models.generation import fp8
+
-                if fp8.is_fp8_model(self.model_runner.vllm_config):
-                    # the fp8 load_weights additionally casts bf16 weights into fp8
-                    fp8.load_weights([(name, weight)], self.model_runner)
-                else:
-                    self.model_runner.model.load_weights(weights=[(name, weight)])
+            if fp8.is_fp8_model(model_runner.vllm_config):
+                # the fp8 load_weights additionally casts bf16 weights into fp8
+                fp8.load_weights(weights, model_runner)
+            else:
+                model_runner.model.load_weights(weights=weights)
+
+        load_model_weight_func = lambda x: _load_model_weights(x, self.model_runner)
+
+        try:
+            packed_broadcast_consumer(
+                iterator=iter(self.state_dict_info.items()),
+                group=self.model_update_group,
+                src=0,
+                post_unpack_func=load_model_weight_func,
+            )
         except Exception as e:
             print(
                 f"Error in VllmInternalWorkerExtension.update_weights_from_collective: {e}"
diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py
index cfe524be8d..78d1d37ad7 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker.py
@@ -83,6 +83,7 @@
     save_checkpoint,
 )
 from nemo_rl.utils.nsys import wrap_with_nvtx_name
+from nemo_rl.utils.packed_tensor import packed_broadcast_producer
 
 
 @contextmanager
@@ -1805,12 +1806,21 @@ def broadcast_weights_for_collective(self) -> None:
             )
             self.model = self.move_to_cuda(self.model)
 
-        # Broadcast the weights for collective communication
-        for _, tensor in self.model.state_dict().items():
+        def _dtensor_post_iter_func(tensor, dtype):
             if isinstance(tensor, DTensor):
                 tensor = tensor.full_tensor()
-            tensor = tensor.to(self.dtype, non_blocking=True)
-            self.model_update_group.broadcast(tensor.data, src=0)
+            tensor = tensor.to(dtype, non_blocking=True)
+            return tensor
+
+        # The state_dict iterator yields (name, tensor); only the tensor is broadcast.
+        dtensor_post_iter_func = lambda x: _dtensor_post_iter_func(x[1], self.dtype)
+
+        packed_broadcast_producer(
+            iterator=iter(self.model.state_dict().items()),
+            group=self.model_update_group,
+            src=0,
+            post_iter_func=dtensor_post_iter_func,
+        )
 
         # Manually move model to cpu for cpu offload case
         # cpu offload needs model on CPU before model forward
diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py
index ee5cea0b5d..893060a3a2 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py
@@ -94,6 +94,7 @@
 )
 from nemo_rl.utils.checkpoint import CheckpointingConfig
 from nemo_rl.utils.nsys import wrap_with_nvtx_name
+from nemo_rl.utils.packed_tensor import packed_broadcast_producer
 
 
 @ray.remote(
@@ -1766,12 +1767,21 @@ def broadcast_weights_for_collective(self) -> None:
             )
             self.model = self.move_to_cuda(self.model)
 
-        # Broadcast the weights for collective communication
-        for _, tensor in self.model.state_dict().items():
+        def _dtensor_post_iter_func(tensor, dtype):
             if isinstance(tensor, DTensor):
                 tensor = tensor.full_tensor()
-            tensor = tensor.to(self.dtype, non_blocking=True)
-            self.model_update_group.broadcast(tensor.data, src=0)
+            tensor = tensor.to(dtype, non_blocking=True)
+            return tensor
+
+        # The state_dict iterator yields (name, tensor); only the tensor is broadcast.
+        dtensor_post_iter_func = lambda x: _dtensor_post_iter_func(x[1], self.dtype)
+
+        packed_broadcast_producer(
+            iterator=iter(self.model.state_dict().items()),
+            group=self.model_update_group,
+            src=0,
+            post_iter_func=dtensor_post_iter_func,
+        )
 
         # Manually move model to cpu for cpu offload case
         # cpu offload needs model on CPU before model forward
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 326ae9fe61..b7931962ea 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -128,6 +128,7 @@
     get_runtime_env_for_policy_worker,
 )
 from nemo_rl.utils.nsys import wrap_with_nvtx_name
+from nemo_rl.utils.packed_tensor import packed_broadcast_producer
 
 TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase)
 
@@ -1738,9 +1739,14 @@ def broadcast_weights_for_collective(self) -> None:
             [self.model],
             show_progress=False,
         )
-        # broadcast from train rank 0 to all other ranks (training and inference)
-        for _, tensor in hf_params_generator:
-            self.model_update_group.broadcast(tensor, src=0)
+
+        # hf_params_generator yields (name, tensor); only the tensor is broadcast from rank 0.
+        packed_broadcast_producer(
+            iterator=hf_params_generator,
+            group=self.model_update_group,
+            src=0,
+            post_iter_func=lambda x: x[1],
+        )
 
     def prepare_for_lp_inference(self):
         self.model = self.move_model(self.model, "cuda", move_grads=False)
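
All three producer hunks above and the vLLM consumer hunk in vllm_backend.py have to agree on how many bytes each parameter occupies on the wire: the producer counts bytes from the tensor it actually broadcasts (after post_iter_func has gathered DTensors and cast to the refit dtype), while the consumer derives the count from the (shape, dtype) pairs recorded in state_dict_info. Below is a minimal CPU-only sketch of that byte-accounting contract; it is not part of the patch, and the parameter shape and dtype are invented for illustration.

# Sketch (not part of the patch): producer and consumer must compute the same
# byte count per parameter, or the packed broadcasts fall out of sync.
import math

import torch

# Producer side: the tensor that is actually broadcast (already gathered and
# cast to the refit dtype by post_iter_func in the hunks above).
weight = torch.randn(10, 20).to(torch.bfloat16)
producer_bytes = weight.view(torch.uint8).view(-1).numel()

# Consumer side: only the (shape, dtype) metadata from state_dict_info is available.
shape, dtype = (10, 20), torch.bfloat16
consumer_bytes = math.prod(shape) * dtype.itemsize

assert producer_bytes == consumer_bytes == 400
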
diff --git a/nemo_rl/utils/packed_tensor.py b/nemo_rl/utils/packed_tensor.py
new file mode 100644
index 0000000000..53eb915b64
--- /dev/null
+++ b/nemo_rl/utils/packed_tensor.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import os
+from functools import lru_cache
+from typing import Any, List, Tuple
+
+import torch
+
+
+@lru_cache(maxsize=1)
+def get_target_packed_tensor_size():
+    memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.01")
+    device = torch.device("cuda")
+    props = torch.cuda.get_device_properties(device)
+    total_memory_bytes = props.total_memory
+    # cap the packed buffer at 5 GiB
+    target_size = min(int(total_memory_bytes * float(memory_ratio)), 5 * 1024**3)
+    return target_size
+
+
+def packed_broadcast_producer(iterator, group, src, post_iter_func):
+    """Broadcast a list of tensors in a packed manner.
+
+    Args:
+        iterator: iterator over model parameters. Yields tuples of (name, tensor)
+        group: process group (vLLM PyNcclCommunicator)
+        src: source rank (0 in the current implementation)
+        post_iter_func: function applied to each yielded item before packing; must return a tensor
+
+    Returns:
+        None
+
+    """
+    target_packed_tensor_size = get_target_packed_tensor_size()
+
+    while True:
+        # Form a packed tensor
+        packing_tensor_list = []
+        packing_tensor_sizes = 0
+        try:
+            while True:
+                # Apply backend-specific post-processing, then flatten to a uint8 view
+                tensor = post_iter_func(next(iterator)).view(torch.uint8).view(-1)
+                packing_tensor_list.append(tensor)
+                packing_tensor_sizes += tensor.view(torch.uint8).numel()
+                if packing_tensor_sizes > target_packed_tensor_size:
+                    break
+            # Pack the tensors and call the broadcast collective
+            packed_tensor = torch.cat(packing_tensor_list, dim=0)
+            group.broadcast(packed_tensor, src=src)
+        except StopIteration:
+            # do the last broadcast if there are remaining tensors
+            if len(packing_tensor_list) > 0:
+                packed_tensor = torch.cat(packing_tensor_list, dim=0)
+                group.broadcast(packed_tensor, src=src)
+            break
+
+
+def packed_broadcast_consumer(iterator, group, src, post_unpack_func):
+    """Consume a packed tensor and unpack it into a list of tensors.
+
+    Args:
+        iterator: iterator over parameter metadata. Yields tuples of (name, (shape, dtype))
+        group: process group (vLLM PyNcclCommunicator)
+        src: source rank (0 in the current implementation)
+        post_unpack_func: function applied to the unpacked List[(name, tensor)]
+
+    Returns:
+        None
+
+    """
+
+    def unpack_tensor(
+        packed_tensor: torch.Tensor, meta_data_list: list[Any]
+    ) -> List[Tuple[str, torch.Tensor]]:
+        """Unpack a single packed tensor into a list of tensors.
+
+        Args:
+            packed_tensor: the packed torch.uint8 tensor to unpack
+            meta_data_list: List[(name, shape, dtype, offset, tensor_size)]
+
+        Returns:
+            unpacked List[(name, tensor)]
+        """
+        unpacked_list = []
+        # Perform a batched split with torch.split_with_sizes
+        packed_tensor_sizes = list(map(lambda x: x[4], meta_data_list))
+        unpacked_tensor = packed_tensor.split_with_sizes(packed_tensor_sizes)
+
+        # unpacked_list = List[(name, torch.Tensor.view(dtype).view(*shape))]
+        unpacked_list = [
+            (
+                meta_data_list[i][0],
+                tensor.view(meta_data_list[i][2]).view(*meta_data_list[i][1]),
+            )
+            for i, tensor in enumerate(unpacked_tensor)
+        ]
+
+        return unpacked_list
+
+    target_packed_tensor_size = get_target_packed_tensor_size()
+
+    while True:
+        # Form a packed tensor
+        packing_tensor_meta_data = []
+        packing_tensor_sizes = 0
+        offset = 0
+        try:
+            while True:
+                # Accumulate metadata for the next packed tensor
+                name, (shape, dtype) = next(iterator)
+                tensor_size = math.prod(shape) * dtype.itemsize
+                packing_tensor_meta_data.append(
+                    (name, shape, dtype, offset, tensor_size)
+                )
+                packing_tensor_sizes += tensor_size
+                offset += tensor_size
+                if packing_tensor_sizes > target_packed_tensor_size:
+                    break
+            # Create a packed tensor and broadcast it
+            packed_tensor = torch.empty(
+                packing_tensor_sizes, dtype=torch.uint8, device="cuda"
+            )
+            group.broadcast(packed_tensor, src=src)
+            # Load the packed tensor into the model
+            post_unpack_func(unpack_tensor(packed_tensor, packing_tensor_meta_data))
+        except StopIteration:
+            # do the last broadcast if there are remaining tensors
+            if len(packing_tensor_meta_data) > 0:
+                # Create a packed tensor and broadcast it
+                packed_tensor = torch.empty(
+                    packing_tensor_sizes, dtype=torch.uint8, device="cuda"
+                )
+                group.broadcast(packed_tensor, src=src)
+                # Load the packed tensor into the model
+                post_unpack_func(unpack_tensor(packed_tensor, packing_tensor_meta_data))
+            break
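
packed_broadcast_consumer allocates its receive buffer with device="cuda", so the unit tests below are skipped without a GPU. For intuition, here is a self-contained CPU-only sketch (not part of the patch) of the same pack/unpack scheme: flatten each tensor to a uint8 view, concatenate, then split the buffer with split_with_sizes and reinterpret each chunk through its recorded (shape, dtype). The parameter names and shapes are invented for illustration.

# CPU-only sketch of the pack/unpack scheme used by nemo_rl/utils/packed_tensor.py.
import math

import torch

params = {
    "layer.weight": torch.randn(4, 8, dtype=torch.float32),
    "layer.bias": torch.randn(4, dtype=torch.bfloat16),
}

# Producer side: flatten every tensor to a uint8 view and concatenate.
packed = torch.cat([t.view(torch.uint8).view(-1) for t in params.values()])

# Consumer side: it only knows (name, shape, dtype); byte sizes follow from that.
meta = [(name, t.shape, t.dtype) for name, t in params.items()]
sizes = [math.prod(shape) * dtype.itemsize for _, shape, dtype in meta]

unpacked = {
    name: chunk.view(dtype).view(*shape)
    for (name, shape, dtype), chunk in zip(meta, packed.split_with_sizes(sizes))
}

for name, original in params.items():
    assert torch.equal(unpacked[name], original)
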
diff --git a/pyrefly.toml b/pyrefly.toml
index 03cbdc3f3b..9484fff639 100644
--- a/pyrefly.toml
+++ b/pyrefly.toml
@@ -109,6 +109,7 @@ project-includes = [
     "nemo_rl/utils/automodel_checkpoint.py",
     "nemo_rl/utils/nsys.py",
     "nemo_rl/utils/nvml.py",
+    "nemo_rl/utils/packed_tensor.py",
     "nemo_rl/utils/prefetch_venvs.py",
     "nemo_rl/utils/timer.py",
     "nemo_rl/utils/venvs.py",
diff --git a/tests/unit/utils/test_packed_tensor.py b/tests/unit/utils/test_packed_tensor.py
new file mode 100644
index 0000000000..6d321bd32a
--- /dev/null
+++ b/tests/unit/utils/test_packed_tensor.py
@@ -0,0 +1,199 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from nemo_rl.utils.packed_tensor import (
+    packed_broadcast_consumer,
+    packed_broadcast_producer,
+)
+
+
+class MockCommunicationGroup:
+    """Mock communication group for testing broadcast operations."""
+
+    def __init__(self):
+        self.broadcasted_tensors = []
+        self.broadcast_count = 0
+
+    def broadcast(self, tensor, src):
+        """Mock broadcast that stores the tensor for later verification."""
+        # Store a copy of the tensor
+        self.broadcasted_tensors.append(tensor.clone())
+        self.broadcast_count += 1
+
+
+class MockConsumerCommunicationGroup:
+    """Mock communication group for the consumer that returns pre-stored tensors."""
+
+    def __init__(self, tensors_to_return):
+        self.tensors_to_return = tensors_to_return
+        self.current_index = 0
+
+    def broadcast(self, tensor, src):
+        """Mock broadcast that fills the tensor with pre-stored data."""
+        if self.current_index < len(self.tensors_to_return):
+            tensor.copy_(self.tensors_to_return[self.current_index])
+            self.current_index += 1
+
+
+def create_mock_model_params():
+    """Create mock model parameters for testing."""
+    params = [
+        ("layer1.weight", torch.randn(10, 20, dtype=torch.float32)),
+        ("layer1.bias", torch.randn(10, dtype=torch.float32)),
+        ("layer2.weight", torch.randn(20, 30, dtype=torch.float32)),
+        ("layer2.bias", torch.randn(20, dtype=torch.float32)),
+        ("layer3.weight", torch.randn(30, 40, dtype=torch.float16)),
+    ]
+    return params
+
+
+def create_mock_state_dict_info(params):
+    """Create state dict info (name -> (shape, dtype)) from params."""
+    return {name: (tensor.shape, tensor.dtype) for name, tensor in params}
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_packed_broadcast_producer_consumer_roundtrip():
+    """Test that producer and consumer work together correctly."""
+    # Create mock parameters
+    params = create_mock_model_params()
+
+    # Move params to CUDA
+    params_cuda = [(name, tensor.cuda()) for name, tensor in params]
+
+    # Create mock communication group for producer
+    producer_group = MockCommunicationGroup()
+
+    # Mock the target size to force packing
+    target_size = 2000
+    with patch(
+        "nemo_rl.utils.packed_tensor.get_target_packed_tensor_size",
+        return_value=target_size,
+    ):
+        # Post-iter function that just returns the tensor
+        post_iter_func = lambda x: x[1]
+
+        # Run producer
+        packed_broadcast_producer(
+            iterator=iter(params_cuda),
+            group=producer_group,
+            src=0,
+            post_iter_func=post_iter_func,
+        )
+
+        # Now test consumer with the broadcasted tensors
+        consumer_group = MockConsumerCommunicationGroup(
+            producer_group.broadcasted_tensors
+        )
+
+        # Create state dict info for consumer
+        state_dict_info = create_mock_state_dict_info(params_cuda)
+
+        # Store unpacked tensors
+        unpacked_tensors = {}
+
+        def post_unpack_func(tensor_list):
+            """Store unpacked tensors for verification."""
+            for name, tensor in tensor_list:
+                unpacked_tensors[name] = tensor
+
+        # Run consumer
+        packed_broadcast_consumer(
+            iterator=iter(state_dict_info.items()),
+            group=consumer_group,
+            src=0,
+            post_unpack_func=post_unpack_func,
+        )
+
+        # Verify all parameters were unpacked
+        assert len(unpacked_tensors) == len(params)
+
+        # Verify each tensor matches the original
+        for name, original_tensor in params_cuda:
+            assert name in unpacked_tensors
+            unpacked = unpacked_tensors[name]
+
+            # Check shape and dtype
+            assert unpacked.shape == original_tensor.shape
+            assert unpacked.dtype == original_tensor.dtype
+
+            # Check values are close (accounting for floating point precision)
+            assert torch.allclose(unpacked, original_tensor, rtol=1e-5, atol=1e-7)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_packed_broadcast_single_large_tensor():
+    """Test with a single tensor larger than the target size."""
+    # Create a large tensor
+    large_tensor = torch.randn(1000, 1000, dtype=torch.float32).cuda()
+    params = [("large_weight", large_tensor)]
+
+    # Create mock communication group
+    mock_group = MockCommunicationGroup()
+
+    # Small target size to force the tensor to exceed it
+    with patch(
+        "nemo_rl.utils.packed_tensor.get_target_packed_tensor_size", return_value=100
+    ):
+        packed_broadcast_producer(
+            iterator=iter(params),
+            group=mock_group,
+            src=0,
+            post_iter_func=lambda x: x[1],
+        )
+
+    # Should still broadcast the tensor
+    assert mock_group.broadcast_count == 1
+    assert len(mock_group.broadcasted_tensors) == 1
+
+    # Verify the size matches the large tensor
+    expected_size = large_tensor.numel() * large_tensor.element_size()
+    assert mock_group.broadcasted_tensors[0].numel() == expected_size
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_packed_broadcast_multiple_batches():
+    """Test that tensors are properly batched when exceeding the target size."""
+    # Create many small tensors
+    params = [
+        (f"weight_{i}", torch.randn(10, 10, dtype=torch.float32).cuda())
+        for i in range(20)
+    ]
+
+    # Create mock communication group
+    mock_group = MockCommunicationGroup()
+
+    # Small target size to force multiple batches
+    with patch(
+        "nemo_rl.utils.packed_tensor.get_target_packed_tensor_size", return_value=2000
+    ):
+        packed_broadcast_producer(
+            iterator=iter(params),
+            group=mock_group,
+            src=0,
+            post_iter_func=lambda x: x[1],
+        )
+
+    # Should have multiple broadcasts
+    assert mock_group.broadcast_count > 1
+
+    # Total size should match the sum of all tensors
+    total_broadcasted_size = sum(t.numel() for t in mock_group.broadcasted_tensors)
+    expected_total_size = sum(t.numel() * t.element_size() for _, t in params)
+    assert total_broadcasted_size == expected_total_size
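
For reference, get_target_packed_tensor_size sizes the staging buffer as NRL_REFIT_BUFFER_MEMORY_RATIO (default 0.01) times the device's total memory, capped at 5 GiB. A small worked example of that formula, assuming an 80 GB device purely for illustration (the real code queries torch.cuda.get_device_properties(...).total_memory):

# Worked example of the buffer-size formula in get_target_packed_tensor_size().
total_memory_bytes = 80 * 1024**3   # hypothetical GPU, for illustration only
memory_ratio = 0.01                 # NRL_REFIT_BUFFER_MEMORY_RATIO default
cap_bytes = 5 * 1024**3             # hard 5 GiB ceiling

target = min(int(total_memory_bytes * memory_ratio), cap_bytes)
print(target / 1024**2)  # ~819 MiB: each packed broadcast stays near this size
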