Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions metatomic-torch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ a changelog](https://keepachangelog.com/en/1.1.0/) format. This project follows

### Changed

- `pick_device` now returns a `torch::Device` (instead of `torch::DeviceType`)
in C++, and a `torch.device` (instead of a `str`) in Python. Any device index
supplied by the caller (e.g. `"cuda:1"`) is preserved in the returned device.

- 3-argument `unit_conversion_factor(quantity, from_unit, to_unit)` is
deprecated; the `quantity` parameter is ignored
- `ModelOutput.quantity` field is deprecated, since it is no longer required for
Expand Down
11 changes: 5 additions & 6 deletions metatomic-torch/include/metatomic/torch/misc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@ METATOMIC_TORCH_EXPORT std::string version();
///
/// If `desired_device` is provided, it is checked against the `model_devices`
/// and the machine availability. If it contains a device index (e.g. "cuda:1"),
/// the base device type ("cuda") is used for these checks.
/// the base device type ("cuda") is used for these checks, and the full device
/// (including the index) is returned if successful.
///
/// This function returns a c10::DeviceType (torch::DeviceType). It does NOT
/// decide a device index — callers that need a full torch::Device should
/// construct one from the returned DeviceType (and choose an index explicitly).
/// Or let it default away to zero via Device(DeviceType)
METATOMIC_TORCH_EXPORT torch::DeviceType pick_device(
/// If `desired_device` is not provided or empty, the first available device
/// from `model_devices` is selected and returned (without a specific index).
METATOMIC_TORCH_EXPORT torch::Device pick_device(
std::vector<std::string> model_devices,
torch::optional<std::string> desired_device = torch::nullopt
);
Expand Down
16 changes: 8 additions & 8 deletions metatomic-torch/src/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ static inline bool is_known_device(const std::string &n) {
);
}

c10::DeviceType pick_device(
torch::Device pick_device(
std::vector<std::string> model_devices,
torch::optional<std::string> desired_device
) {
Expand Down Expand Up @@ -133,21 +133,21 @@ c10::DeviceType pick_device(

// if no desired device requested, pick first available
if (!desired_device.has_value() || desired_device->empty()) {
return map_to_devicetype(available.front());
return torch::Device(map_to_devicetype(available.front()));
}

// normalize desired and check
std::string wanted_str = lower(desired_device.value());
torch::DeviceType wanted_type;
// convert desired and check
torch::Device wanted_device(torch::kCPU);
try {
wanted_type = torch::Device(wanted_str).type();
wanted_device = torch::Device(lower(desired_device.value()));
} catch (const std::exception &) {
C10_THROW_ERROR(ValueError, "invalid device string: " + desired_device.value());
}

for (auto &a : available) {
torch::DeviceType wanted_type = wanted_device.type();
for (const auto& a : available) {
if (map_to_devicetype(a) == wanted_type) {
return wanted_type;
return wanted_device;
}
}

Expand Down
39 changes: 9 additions & 30 deletions metatomic-torch/src/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,6 @@

using namespace metatomic_torch;

std::string pick_device_pywrapper(
const std::vector<std::string> &model_devices,
const c10::optional<std::string> &requested_device
) {
try {
torch::optional<std::string> desired = torch::nullopt;
if (requested_device.has_value() && !requested_device->empty()) {
desired = requested_device.value();
}

c10::DeviceType devtype = metatomic_torch::pick_device(model_devices, desired);

if (desired.has_value()) {
// User requested a specific device (possibly with an index like "cuda:1").
// We return it normalized (e.g. "CUDA:1" -> "cuda:1").
return torch::Device(desired.value()).str();
} else {
// Automatic selection: return the device type name (e.g. "cuda").
return torch::Device(devtype).str();
}

} catch (const std::exception &e) {
throw std::runtime_error(std::string("pick_device failed: ") + e.what());
}
}

/// Small wrapper around unit_conversion_factor to support both the new 2-arg
/// form and the deprecated 3-arg form, with flexible positional/keyword
Expand Down Expand Up @@ -334,12 +309,16 @@ TORCH_LIBRARY(metatomic, m) {

// standalone functions
m.def("version() -> str", version);
// Expose pick_device to Python. The C++ helper returns a c10::DeviceType;
// build a torch::Device from it (with default index policy) and return its
// string representation to Python (backwards-compatible).
m.def(
"pick_device(str[] model_devices, str? requested_device = None) -> str",
&pick_device_pywrapper
"pick_device(str[] model_devices, str? requested_device = None) -> Device",
[](const std::vector<std::string>& model_devices,
const c10::optional<std::string>& requested_device) -> torch::Device {
torch::optional<std::string> desired = torch::nullopt;
if (requested_device.has_value() && !requested_device->empty()) {
desired = requested_device.value();
}
return metatomic_torch::pick_device(model_devices, desired);
}
);
m.def("pick_output(str requested_output, Dict(str, __torch__.torch.classes.metatomic.ModelOutput) outputs, str? desired_variant = None) -> str", pick_output);

Expand Down
40 changes: 27 additions & 13 deletions metatomic-torch/tests/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,30 @@ TEST_CASE("Pick device") {
// check that first entry in vector is picked, if no desired device is given
std::vector<std::string> supported = {"cpu", "cuda", "mps"};

auto devtype = metatomic_torch::pick_device(supported);
// devtype must be one of the DeviceType enum values; at minimum CPU should be selectable
// Catch doesn't allow operator|| inside assertions in a portable way, so evaluate
// the boolean expression first and assert that.
bool dev_ok = (devtype == c10::DeviceType::CPU) ||
(devtype == c10::DeviceType::CUDA) ||
(devtype == c10::DeviceType::MPS);
auto device = metatomic_torch::pick_device(supported);
// device must be one of the supported types; at minimum CPU should be selectable
bool dev_ok = (device.type() == c10::DeviceType::CPU) ||
(device.type() == c10::DeviceType::CUDA) ||
(device.type() == c10::DeviceType::MPS);
REQUIRE(dev_ok);
// auto-selection should not set a device index
CHECK_FALSE(device.has_index());

// Test requested device selection when available
if (torch::cuda::is_available()) {
auto dt_cuda = metatomic_torch::pick_device(supported, std::string("cuda"));
CHECK(dt_cuda == c10::DeviceType::CUDA);
auto dev_cuda = metatomic_torch::pick_device(supported, std::string("cuda"));
CHECK(dev_cuda.type() == c10::DeviceType::CUDA);
CHECK_FALSE(dev_cuda.has_index());

// Device index should be preserved
auto dev_cuda0 = metatomic_torch::pick_device(supported, std::string("cuda:0"));
CHECK(dev_cuda0.type() == c10::DeviceType::CUDA);
CHECK(dev_cuda0.has_index());
CHECK(dev_cuda0.index() == 0);
}
if (torch::mps::is_available()) {
auto dt_mps = metatomic_torch::pick_device(supported, std::string("mps"));
CHECK(dt_mps == c10::DeviceType::MPS);
auto dev_mps = metatomic_torch::pick_device(supported, std::string("mps"));
CHECK(dev_mps.type() == c10::DeviceType::MPS);
}

// Check that warning is emitted:
Expand All @@ -56,8 +63,8 @@ TEST_CASE("Pick device") {
torch::WarningUtils::set_warnAlways(true);

std::vector<std::string> supported_devices_foo = {"cpu", "fooo"};
auto tdevtype = torch::Device(metatomic_torch::pick_device(supported_devices_foo));
CHECK(tdevtype.str() == "cpu");
auto dev_foo = metatomic_torch::pick_device(supported_devices_foo);
CHECK(dev_foo.str() == "cpu");
REQUIRE_FALSE(handler.messages.empty());

const auto* expected = "ignoring unknown device 'fooo' from `model_devices`";
Expand All @@ -74,6 +81,13 @@ TEST_CASE("Pick device errors") {
std::vector<std::string> only_cpu = {"cpu"};
CHECK_THROWS_WITH(metatomic_torch::pick_device(only_cpu, std::string("cuda")), StartsWith("failed to find requested device"));
}

// invalid device string should raise
std::vector<std::string> supported = {"cpu"};
CHECK_THROWS_WITH(
metatomic_torch::pick_device(supported, std::string("cpu:invalid")),
StartsWith("invalid device string")
);
}


Expand Down
4 changes: 2 additions & 2 deletions python/metatomic_ase/src/metatomic_ase/_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ def __init__(
# get the best device according what the model supports and what's available on
# the current machine
capabilities = model.capabilities()
self._device = torch.device(
pick_device(capabilities.supported_devices, self.parameters["device"])
self._device = pick_device(
capabilities.supported_devices, self.parameters["device"]
)

if capabilities.dtype in STR_TO_DTYPE:
Expand Down
13 changes: 8 additions & 5 deletions python/metatomic_torch/metatomic/torch/documentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,21 +663,24 @@ def unit_dimension_for_quantity(name: str) -> str:
raise THIS_CODE_SHOULD_NOT_RUN


def pick_device(model_devices: List[str], desired_device: Optional[str]) -> str:
def pick_device(
model_devices: List[str], desired_device: Optional[str]
) -> torch.device:
"""
Select the best device according to the list of ``model_devices`` from a model, the
user-provided ``desired_device`` and what's available on the current machine.

If ``desired_device`` is provided, it is checked against the ``model_devices`` and
the machine availability. If it contains a device index (e.g. ``"cuda:1"``), the
base device type (``"cuda"``) is used for these checks, and the full string is
returned if successful.
base device type (``"cuda"``) is used for these checks, and the full
:py:class:`torch.device` (including the provided index) is returned if successful.

If ``desired_device`` is ``None`` or an empty string, the first available device
from ``model_devices`` will be picked and returned.
from ``model_devices`` will be picked and returned without a specific index.

:param model_devices: list of devices supported by a model in order of preference
:param desired_device: user-provided desired device.
:param desired_device: user-provided desired device string (e.g. ``"cuda"``,
``"cuda:1"``, ``"cpu"``), or ``None`` to auto-select.
"""
raise THIS_CODE_SHOULD_NOT_RUN

Expand Down
66 changes: 40 additions & 26 deletions python/metatomic_torch/tests/pick_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,41 @@


def test_pick_device_basic():
# basic call should return a non-empty string describing a device
# basic call should return a torch.device
res = mta.pick_device(["cpu", "cuda", "mps"], None)
assert isinstance(res, str)
assert len(res) > 0
# sanity: in typical environments we'll at least see "cpu" somewhere
assert "cpu" == res or "cuda" == res or "mps" == res
assert isinstance(res, torch.device)
# sanity: in typical environments we'll at least see "cpu", "cuda" or "mps"
assert res.type in ("cpu", "cuda", "mps")


def test_pick_device_requested_if_available():
# if CUDA is available, requesting it should yield a cuda device string
# if CUDA is available, requesting it should yield a cuda device
if torch.cuda.is_available():
res = mta.pick_device(["cpu", "cuda"], "cuda")
assert isinstance(res, str)
assert "cuda" == res
# if MPS is available, requesting it should yield an mps device string
assert isinstance(res, torch.device)
assert res.type == "cuda"
# if MPS is available, requesting it should yield an mps device
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
res = mta.pick_device(["cpu", "mps"], "mps")
assert isinstance(res, str)
assert "mps" == res
assert isinstance(res, torch.device)
assert res.type == "mps"


def test_pick_device_no_index_when_auto():
# When no desired_device is given, the result should have no specific index
res = mta.pick_device(["cpu"], None)
assert isinstance(res, torch.device)
assert res.type == "cpu"
    # torch.device("cpu") has index None (unset); torch.device("cpu:0") has index 0
assert res.index is None


def test_pick_device_ignores_unrecognized_and_warns(capfd):
# Ensure unrecognized device names are ignored and a warning is emitted
res = mta.pick_device(["cpu", "fooo"], None)
assert isinstance(res, str)
assert isinstance(res, torch.device)
# should pick cpu (ignore "fooo")
assert "cpu" == res
assert res.type == "cpu"
# at least one warning should have been produced about the unrecognized/
# ignored entry
captured = capfd.readouterr()
Expand All @@ -42,27 +50,33 @@ def test_pick_device_ignores_unrecognized_and_warns(capfd):


def test_pick_device_error_on_unavailable_requested():
# If a device is explicitly requested but isn't available/declared,
# the python wrapper is expected to raise a RuntimeError (propagated from C++).
only_cuda = ["cuda"]
if not torch.cuda.is_available():
with pytest.raises(RuntimeError):
mta.pick_device(only_cuda, "cuda")
    # Test that explicitly requesting a device which isn't available/declared raises
if torch.cuda.is_available():
model_devices = ["cpu"]
else:
# If CUDA is available, requesting a non-present device should raise
with pytest.raises(RuntimeError):
mta.pick_device(["cpu"], "cuda")
model_devices = ["cuda"]

match = (
"failed to find a valid device. "
"None of the model-supported devices are available."
)
with pytest.raises(ValueError, match=match):
mta.pick_device(model_devices, "cuda")


def test_pick_device_indexed():
# Test that indexed device strings like "cpu:0" or "cuda:1" are accepted
# and preserved.
# and that the index is preserved in the returned torch.device.
res = mta.pick_device(["cpu", "cuda"], "cpu:0")
assert res == "cpu:0"
assert isinstance(res, torch.device)
assert res.type == "cpu"
assert res.index == 0

if torch.cuda.is_available():
res = mta.pick_device(["cpu", "cuda"], "cuda:0")
assert res == "cuda:0"
assert isinstance(res, torch.device)
assert res.type == "cuda"
assert res.index == 0

with pytest.raises(RuntimeError, match="invalid device string"):
with pytest.raises(ValueError, match="invalid device string"):
mta.pick_device(["cpu"], "cpu:invalid")
4 changes: 1 addition & 3 deletions python/metatomic_torchsim/metatomic_torchsim/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,7 @@ def __init__(
device = torch.device(device)
self._device = device
else:
self._device = torch.device(
pick_device(capabilities.supported_devices, None)
)
self._device = pick_device(capabilities.supported_devices, None)

# Resolve dtype from model capabilities
if capabilities.dtype in STR_TO_DTYPE:
Expand Down
Loading