Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions nemo_reinforcer/models/generation/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def configure_worker(

# Force vllm to use v0 runtime (will be enabled by default in #51)
env_vars["VLLM_USE_V1"] = "0"

return resources, env_vars, init_kwargs

def __init__(
Expand Down Expand Up @@ -379,9 +380,6 @@ def shutdown(self):
return False

def report_device_id(self) -> str:
# from vllm.platforms import current_platform
# self.device_uuid = current_platform.get_device_uuid(self.rank)
# return self.device_uuid
return self.llm.collective_rpc("report_device_id", args=tuple())[0]

def update_weights_from_ipc_handles(self, ipc_handles):
Expand Down
5 changes: 2 additions & 3 deletions nemo_reinforcer/models/generation/vllm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@

class UpdatableVllmInternalWorker(Worker):
def report_device_id(self) -> str:
from vllm.platforms import current_platform
from nemo_reinforcer.utils.nvml import get_device_uuid

self.device_uuid = current_platform.get_device_uuid(self.device.index)
return self.device_uuid
return get_device_uuid(self.device.index)

def update_weights_from_ipc_handles(self, ipc_handles):
"""Update weights from IPC handles.
Expand Down
17 changes: 12 additions & 5 deletions nemo_reinforcer/models/policy/hf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,10 +686,17 @@ def zero_out_weights(self):
torch.cuda.synchronize()

def report_device_id(self) -> str:
from vllm.platforms import current_platform
"""Report the UUID of the current CUDA device using NVML.

self.device_uuid = current_platform.get_device_uuid(torch.cuda.current_device())
return self.device_uuid
Returns:
str: UUID of the device in the format "GPU-xxxxx"
"""
from nemo_reinforcer.utils.nvml import get_device_uuid

# Get current device index from torch
device_idx = torch.cuda.current_device()
# Get device UUID using NVML
return get_device_uuid(device_idx)

@torch.no_grad()
def get_weight_ipc_handles(self, offload_model=True):
Expand All @@ -708,15 +715,15 @@ def get_weight_ipc_handles(self, offload_model=True):
params = dtype_params
self._held_reference_model_params = params
data = {}
self.device_uuid = self.report_device_id()
device_uuid = self.report_device_id()
for name, p in params.items():
data[name] = reduce_tensor(p.detach())

if offload_model:
self.model = self.move_to_cpu(self.model)
gc.collect()
torch.cuda.empty_cache()
return {self.device_uuid: data}
return {device_uuid: data}

def prepare_for_lp_inference(self):
self.model.to("cuda")
Expand Down
66 changes: 66 additions & 0 deletions nemo_reinforcer/utils/nvml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import pynvml


@contextlib.contextmanager
def nvml_context():
    """Context manager that initializes NVML on entry and shuts it down on exit.

    Yields:
        None: NVML is initialized for the duration of the ``with`` block.

    Raises:
        RuntimeError: If NVML initialization fails.
    """
    try:
        pynvml.nvmlInit()
    except pynvml.NVMLError as e:
        # Chain the original NVML error so the root cause is preserved.
        # Note: only init failures are wrapped here; errors raised from the
        # body of the `with` block propagate unchanged (the old code wrongly
        # relabeled them as init failures).
        raise RuntimeError(f"Failed to initialize NVML: {e}") from e
    try:
        yield
    finally:
        # Best-effort shutdown; ignore NVML errors (e.g. double shutdown)
        # but never swallow unrelated exceptions like KeyboardInterrupt.
        with contextlib.suppress(pynvml.NVMLError):
            pynvml.nvmlShutdown()


def device_id_to_physical_device_id(device_id: int) -> int:
    """Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES.

    When ``CUDA_VISIBLE_DEVICES`` is set, logical device ``i`` maps to the
    ``i``-th entry of that comma-separated list; otherwise the logical and
    physical IDs coincide.

    Args:
        device_id: Logical device index as seen by this process.

    Returns:
        int: The physical device index on the host.

    Raises:
        RuntimeError: If ``CUDA_VISIBLE_DEVICES`` is set but the entry for
            ``device_id`` is missing or is not an integer (e.g. a GPU UUID).
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return device_id
    device_ids = visible_devices.split(",")
    try:
        # IndexError (device_id out of range) is wrapped alongside ValueError
        # so callers always see the same RuntimeError for a bad mapping.
        return int(device_ids[device_id])
    except (ValueError, IndexError) as e:
        raise RuntimeError(
            f"Failed to convert logical device ID {device_id} to physical device ID. Available devices are: {device_ids}."
        ) from e


def get_device_uuid(device_idx: int) -> str:
    """Get the UUID of a CUDA device using NVML.

    Args:
        device_idx: Logical CUDA device index for this process.

    Returns:
        str: The device UUID reported by NVML (e.g. ``"GPU-xxxxx"``).
            NOTE(review): older pynvml releases return ``bytes`` here —
            confirm against the pinned nvidia-ml-py version.

    Raises:
        RuntimeError: If NVML cannot be initialized, the logical index cannot
            be mapped, or the UUID lookup fails.
    """
    # Map the logical index through CUDA_VISIBLE_DEVICES to the host index.
    global_device_idx = device_id_to_physical_device_id(device_idx)

    # Get the device handle and UUID inside an init/shutdown-scoped session.
    with nvml_context():
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx)
            return pynvml.nvmlDeviceGetUUID(handle)
        except pynvml.NVMLError as e:
            # Chain the NVML error so the underlying failure is preserved.
            raise RuntimeError(
                f"Failed to get device UUID for device {device_idx} (global index: {global_device_idx}): {e}"
            ) from e
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies = [
"omegaconf",
"torchdata",
"vllm==0.8.0",
"nvidia-ml-py",
]

[tool.setuptools]
Expand Down
61 changes: 61 additions & 0 deletions tests/unit/utils/test_pynvml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest.mock import patch

from nemo_reinforcer.utils.nvml import (
nvml_context,
device_id_to_physical_device_id,
get_device_uuid,
)


@patch("nemo_reinforcer.utils.nvml.pynvml")
def test_nvml_context(pynvml_mock):
    """nvml_context must pair nvmlInit with nvmlShutdown around the body."""
    # Enter and leave the context without doing any work inside it.
    with nvml_context():
        pass

    # Exactly one init and one shutdown call should have occurred.
    pynvml_mock.nvmlInit.assert_called_once()
    pynvml_mock.nvmlShutdown.assert_called_once()


def test_device_id_conversion():
    """Logical-to-physical device ID mapping honors CUDA_VISIBLE_DEVICES."""
    # Without CUDA_VISIBLE_DEVICES the mapping is the identity.
    with patch.dict(os.environ, {}, clear=True):
        assert device_id_to_physical_device_id(0) == 0

    # With CUDA_VISIBLE_DEVICES set, logical i maps to the i-th entry.
    with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "2,3"}):
        for logical, physical in ((0, 2), (1, 3)):
            assert device_id_to_physical_device_id(logical) == physical


@patch("nemo_reinforcer.utils.nvml.device_id_to_physical_device_id")
@patch("nemo_reinforcer.utils.nvml.pynvml")
def test_get_device_uuid(mock_pynvml, mock_convert_id):
    """get_device_uuid maps the logical index and queries NVML for the UUID."""

    # Arrange: logical device 0 maps to physical device 1.
    mock_convert_id.return_value = 1
    mock_handle = mock_pynvml.nvmlDeviceGetHandleByIndex.return_value
    mock_pynvml.nvmlDeviceGetUUID.return_value = b"GPU-12345"

    # Act
    uuid = get_device_uuid(0)

    # Assert: the UUID comes straight from NVML for the physical index.
    assert uuid == b"GPU-12345"
    mock_convert_id.assert_called_once_with(0)
    mock_pynvml.nvmlDeviceGetHandleByIndex.assert_called_once_with(1)
    # The UUID must be fetched for the handle of that physical device —
    # the original test never checked this link.
    mock_pynvml.nvmlDeviceGetUUID.assert_called_once_with(mock_handle)