Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions nemo_rl/models/generation/vllm/vllm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torch.multiprocessing.reductions import rebuild_cuda_tensor

from nemo_rl.utils.nsys import wrap_with_nvtx_name
from nemo_rl.utils.packed_tensor import packed_broadcast_consumer

try:
import vllm # noqa: F401
Expand Down Expand Up @@ -186,18 +187,33 @@ def update_weights_from_collective(self) -> bool:
"Please call prepare_refit_info when initializing the worker."
)

try:
for name, (shape, dtype) in self.state_dict_info.items():
weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(weight, src=0)
def _load_model_weights(weights, model_runner):
"""Load model weights.

Args:
weights: List[(name, tensor)]
model_runner: vLLM ModelRunner

from nemo_rl.models.generation import fp8
Returns:
None
"""
from nemo_rl.models.generation import fp8

if fp8.is_fp8_model(self.model_runner.vllm_config):
# the fp8 load_weights additionally casts bf16 weights into fp8
fp8.load_weights([(name, weight)], self.model_runner)
else:
self.model_runner.model.load_weights(weights=[(name, weight)])
if fp8.is_fp8_model(model_runner.vllm_config):
# the fp8 load_weights additionally casts bf16 weights into fp8
fp8.load_weights(weights, model_runner)
else:
model_runner.model.load_weights(weights=weights)

load_model_weight_func = lambda x: _load_model_weights(x, self.model_runner)

try:
packed_broadcast_consumer(
iterator=iter(self.state_dict_info.items()),
group=self.model_update_group,
src=0,
post_unpack_func=load_model_weight_func,
)
except Exception as e:
print(
f"Error in VllmInternalWorkerExtension.update_weights_from_collective: {e}"
Expand Down
18 changes: 14 additions & 4 deletions nemo_rl/models/policy/dtensor_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
save_checkpoint,
)
from nemo_rl.utils.nsys import wrap_with_nvtx_name
from nemo_rl.utils.packed_tensor import packed_broadcast_producer


@contextmanager
Expand Down Expand Up @@ -1805,12 +1806,21 @@ def broadcast_weights_for_collective(self) -> None:
)
self.model = self.move_to_cuda(self.model)

# Broadcast the weights for collective communication
for _, tensor in self.model.state_dict().items():
def _dtensor_post_iter_func(tensor, dtype):
if isinstance(tensor, DTensor):
tensor = tensor.full_tensor()
tensor = tensor.to(self.dtype, non_blocking=True)
self.model_update_group.broadcast(tensor.data, src=0)
tensor = tensor.to(dtype, non_blocking=True)
return tensor

# param_iterator will return (name, tensor), we only need tensor
dtensor_post_iter_func = lambda x: _dtensor_post_iter_func(x[1], self.dtype)

packed_broadcast_producer(
iterator=iter(self.model.state_dict().items()),
group=self.model_update_group,
src=0,
post_iter_func=dtensor_post_iter_func,
)

# Manually move model to cpu for cpu offload case
# cpu offload needs model on CPU before model forward
Expand Down
18 changes: 14 additions & 4 deletions nemo_rl/models/policy/dtensor_policy_worker_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
)
from nemo_rl.utils.checkpoint import CheckpointingConfig
from nemo_rl.utils.nsys import wrap_with_nvtx_name
from nemo_rl.utils.packed_tensor import packed_broadcast_producer


@ray.remote(
Expand Down Expand Up @@ -1766,12 +1767,21 @@ def broadcast_weights_for_collective(self) -> None:
)
self.model = self.move_to_cuda(self.model)

# Broadcast the weights for collective communication
for _, tensor in self.model.state_dict().items():
def _dtensor_post_iter_func(tensor, dtype):
if isinstance(tensor, DTensor):
tensor = tensor.full_tensor()
tensor = tensor.to(self.dtype, non_blocking=True)
self.model_update_group.broadcast(tensor.data, src=0)
tensor = tensor.to(dtype, non_blocking=True)
return tensor

# param_iterator will return (name, tensor), we only need tensor
dtensor_post_iter_func = lambda x: _dtensor_post_iter_func(x[1], self.dtype)

packed_broadcast_producer(
iterator=iter(self.model.state_dict().items()),
group=self.model_update_group,
src=0,
post_iter_func=dtensor_post_iter_func,
)

# Manually move model to cpu for cpu offload case
# cpu offload needs model on CPU before model forward
Expand Down
12 changes: 9 additions & 3 deletions nemo_rl/models/policy/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
get_runtime_env_for_policy_worker,
)
from nemo_rl.utils.nsys import wrap_with_nvtx_name
from nemo_rl.utils.packed_tensor import packed_broadcast_producer

TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase)

Expand Down Expand Up @@ -1738,9 +1739,14 @@ def broadcast_weights_for_collective(self) -> None:
[self.model],
show_progress=False,
)
# broadcast from train rank 0 to all other ranks (training and inference)
for _, tensor in hf_params_generator:
self.model_update_group.broadcast(tensor, src=0)

# param_iterator will return (name, tensor), we only need tensor
packed_broadcast_producer(
iterator=hf_params_generator,
group=self.model_update_group,
src=0,
post_iter_func=lambda x: x[1],
)

def prepare_for_lp_inference(self):
self.model = self.move_model(self.model, "cuda", move_grads=False)
Expand Down
150 changes: 150 additions & 0 deletions nemo_rl/utils/packed_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
from functools import lru_cache
from typing import Any, List, Tuple

import torch


@lru_cache(maxsize=1)
def get_target_packed_tensor_size():
    """Return the target byte size of one packed refit broadcast buffer.

    The size is NRL_REFIT_BUFFER_MEMORY_RATIO (default "0.01") times the
    current CUDA device's total memory, hard-capped at 5 GiB. The result is
    cached for the process lifetime, so the environment variable is only
    read once.
    """
    ratio = float(os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.01"))
    props = torch.cuda.get_device_properties(torch.device("cuda"))
    # NOTE(review): different ranks with different GPU sizes would compute
    # different targets and desync producer/consumer packing — presumably all
    # ranks run identical hardware; verify against deployment assumptions.
    hard_cap = 5 * 1024**3  # 5 GiB ceiling
    return min(int(props.total_memory * ratio), hard_cap)


def packed_broadcast_producer(iterator, group, src, post_iter_func):
    """Broadcast a stream of tensors as a small number of packed buffers.

    Each tensor pulled from ``iterator`` is flattened to uint8 and appended
    to the current buffer; once the buffer's byte count passes
    ``get_target_packed_tensor_size()`` it is concatenated and broadcast as
    one collective call. The consumer (``packed_broadcast_consumer``) must
    iterate the same parameters in the same order so buffer boundaries
    line up on both sides.

    Args:
        iterator: iterator of model parameters, yielding (name, tensor) tuples
        group: process group (vllm PyNcclCommunicator)
        src: source rank (0 in current implementation)
        post_iter_func: applied to each yielded item before packing; must
            return a tensor

    Returns:
        None
    """
    capacity = get_target_packed_tensor_size()
    exhausted = False
    while not exhausted:
        chunks = []
        chunk_bytes = 0
        try:
            # Fill one buffer. The tensor that crosses the capacity
            # threshold is still included, so a buffer may exceed
            # `capacity` by up to one tensor's size.
            while chunk_bytes <= capacity:
                # Backend-specific post-processing, then reinterpret as a
                # flat uint8 view (requires a contiguous tensor).
                flat = post_iter_func(next(iterator)).view(torch.uint8).view(-1)
                chunks.append(flat)
                chunk_bytes += flat.numel()
        except StopIteration:
            # Source drained: flush whatever is left, then stop.
            exhausted = True
        if chunks:
            group.broadcast(torch.cat(chunks, dim=0), src=src)


def packed_broadcast_consumer(iterator, group, src, post_unpack_func):
    """Receive packed broadcasts and unpack them into individual tensors.

    Mirror image of ``packed_broadcast_producer``: metadata from
    ``iterator`` predicts the exact byte size of each packed broadcast, so
    both sides must iterate parameters in the same order and agree on
    ``get_target_packed_tensor_size()``.

    Args:
        iterator: iterator of parameter metadata, yielding
            (name, (shape, dtype)) tuples
        group: process group (vllm PyNcclCommunicator)
        src: source rank (0 in current implementation)
        post_unpack_func: called with List[(name, tensor)] after each
            buffer is unpacked (e.g. to load weights into the model)

    Returns:
        None
    """

    def _recv_and_unpack(meta_data_list, total_bytes):
        """Receive one packed uint8 buffer and deliver the unpacked views.

        meta_data_list: List[(name, shape, dtype, tensor_size_bytes)]
        """
        packed_tensor = torch.empty(total_bytes, dtype=torch.uint8, device="cuda")
        group.broadcast(packed_tensor, src=src)
        # Batched split, then reinterpret each piece with its real dtype/shape.
        # The unpacked tensors are views into packed_tensor, not copies.
        sizes = [meta[3] for meta in meta_data_list]
        pieces = packed_tensor.split_with_sizes(sizes)
        post_unpack_func(
            [
                (meta[0], piece.view(meta[2]).view(*meta[1]))
                for meta, piece in zip(meta_data_list, pieces)
            ]
        )

    target_packed_tensor_size = get_target_packed_tensor_size()

    while True:
        packing_tensor_meta_data = []
        packing_tensor_sizes = 0
        try:
            while True:
                # Accumulate metadata until the predicted buffer size passes
                # the target (the entry that crosses it is still included,
                # matching the producer's packing rule).
                name, (shape, dtype) = next(iterator)
                tensor_size = math.prod(shape) * dtype.itemsize
                packing_tensor_meta_data.append((name, shape, dtype, tensor_size))
                packing_tensor_sizes += tensor_size
                if packing_tensor_sizes > target_packed_tensor_size:
                    break
            _recv_and_unpack(packing_tensor_meta_data, packing_tensor_sizes)
        except StopIteration:
            # Metadata drained: receive the final partial buffer, if any.
            if packing_tensor_meta_data:
                _recv_and_unpack(packing_tensor_meta_data, packing_tensor_sizes)
            break
1 change: 1 addition & 0 deletions pyrefly.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ project-includes = [
"nemo_rl/utils/automodel_checkpoint.py",
"nemo_rl/utils/nsys.py",
"nemo_rl/utils/nvml.py",
"nemo_rl/utils/packed_tensor.py",
"nemo_rl/utils/prefetch_venvs.py",
"nemo_rl/utils/timer.py",
"nemo_rl/utils/venvs.py",
Expand Down
Loading
Loading