diff --git a/metatomic-torch/CHANGELOG.md b/metatomic-torch/CHANGELOG.md index d4655811..0640b65e 100644 --- a/metatomic-torch/CHANGELOG.md +++ b/metatomic-torch/CHANGELOG.md @@ -29,6 +29,10 @@ a changelog](https://keepachangelog.com/en/1.1.0/) format. This project follows ### Changed +- `pick_device` now returns a `torch::Device` (instead of `torch::DeviceType`) + in C++, and a `torch.device` (instead of a `str`) in Python. Any device index + supplied by the caller (e.g. `"cuda:1"`) is preserved in the returned device. + - 3-argument `unit_conversion_factor(quantity, from_unit, to_unit)` is deprecated; the `quantity` parameter is ignored - `ModelOutput.quantity` field is deprecated, since it is no longer required for diff --git a/metatomic-torch/include/metatomic/torch/misc.hpp b/metatomic-torch/include/metatomic/torch/misc.hpp index ea3621b4..fc18dc4e 100644 --- a/metatomic-torch/include/metatomic/torch/misc.hpp +++ b/metatomic-torch/include/metatomic/torch/misc.hpp @@ -25,13 +25,12 @@ METATOMIC_TORCH_EXPORT std::string version(); /// /// If `desired_device` is provided, it is checked against the `model_devices` /// and the machine availability. If it contains a device index (e.g. "cuda:1"), -/// the base device type ("cuda") is used for these checks. +/// the base device type ("cuda") is used for these checks, and the full device +/// (including the index) is returned if successful. /// -/// This function returns a c10::DeviceType (torch::DeviceType). It does NOT -/// decide a device index — callers that need a full torch::Device should -/// construct one from the returned DeviceType (and choose an index explicitly). -/// Or let it default away to zero via Device(DeviceType) -METATOMIC_TORCH_EXPORT torch::DeviceType pick_device( +/// If `desired_device` is not provided or empty, the first available device +/// from `model_devices` is selected and returned (without a specific index). 
+METATOMIC_TORCH_EXPORT torch::Device pick_device( std::vector<std::string> model_devices, torch::optional<std::string> desired_device = torch::nullopt ); diff --git a/metatomic-torch/src/misc.cpp b/metatomic-torch/src/misc.cpp index 8b2eb177..2f019c1a 100644 --- a/metatomic-torch/src/misc.cpp +++ b/metatomic-torch/src/misc.cpp @@ -105,7 +105,7 @@ static inline bool is_known_device(const std::string &n) { ); } -c10::DeviceType pick_device( +torch::Device pick_device( std::vector<std::string> model_devices, torch::optional<std::string> desired_device ) { @@ -133,21 +133,21 @@ c10::DeviceType pick_device( // if no desired device requested, pick first available if (!desired_device.has_value() || desired_device->empty()) { - return map_to_devicetype(available.front()); + return torch::Device(map_to_devicetype(available.front())); } - // normalize desired and check - std::string wanted_str = lower(desired_device.value()); - torch::DeviceType wanted_type; + // convert desired and check + torch::Device wanted_device(torch::kCPU); try { - wanted_type = torch::Device(wanted_str).type(); + wanted_device = torch::Device(lower(desired_device.value())); } catch (const std::exception &) { C10_THROW_ERROR(ValueError, "invalid device string: " + desired_device.value()); } - for (auto &a : available) { + torch::DeviceType wanted_type = wanted_device.type(); + for (const auto& a : available) { if (map_to_devicetype(a) == wanted_type) { - return wanted_type; + return wanted_device; } } diff --git a/metatomic-torch/src/register.cpp b/metatomic-torch/src/register.cpp index 1de13cba..1dde2930 100644 --- a/metatomic-torch/src/register.cpp +++ b/metatomic-torch/src/register.cpp @@ -8,31 +8,6 @@ using namespace metatomic_torch; -std::string pick_device_pywrapper( - const std::vector<std::string> &model_devices, - const c10::optional<std::string> &requested_device -) { - try { - torch::optional<std::string> desired = torch::nullopt; - if (requested_device.has_value() && !requested_device->empty()) { - desired = requested_device.value(); - } - - c10::DeviceType devtype = 
metatomic_torch::pick_device(model_devices, desired); - - if (desired.has_value()) { - // User requested a specific device (possibly with an index like "cuda:1"). - // We return it normalized (e.g. "CUDA:1" -> "cuda:1"). - return torch::Device(desired.value()).str(); - } else { - // Automatic selection: return the device type name (e.g. "cuda"). - return torch::Device(devtype).str(); - } - - } catch (const std::exception &e) { - throw std::runtime_error(std::string("pick_device failed: ") + e.what()); - } -} /// Small wrapper around unit_conversion_factor to support both the new 2-arg /// form and the deprecated 3-arg form, with flexible positional/keyword @@ -334,12 +309,16 @@ TORCH_LIBRARY(metatomic, m) { // standalone functions m.def("version() -> str", version); - // Expose pick_device to Python. The C++ helper returns a c10::DeviceType; - // build a torch::Device from it (with default index policy) and return its - // string representation to Python (backwards-compatible). m.def( - "pick_device(str[] model_devices, str? requested_device = None) -> str", - &pick_device_pywrapper + "pick_device(str[] model_devices, str? requested_device = None) -> Device", + [](const std::vector<std::string>& model_devices, + const c10::optional<std::string>& requested_device) -> torch::Device { + torch::optional<std::string> desired = torch::nullopt; + if (requested_device.has_value() && !requested_device->empty()) { + desired = requested_device.value(); + } + return metatomic_torch::pick_device(model_devices, desired); + } ); m.def("pick_output(str requested_output, Dict(str, __torch__.torch.classes.metatomic.ModelOutput) outputs, str? 
desired_variant = None) -> str", pick_output); diff --git a/metatomic-torch/tests/misc.cpp b/metatomic-torch/tests/misc.cpp index 3cd8ba7d..2ca8cb86 100644 --- a/metatomic-torch/tests/misc.cpp +++ b/metatomic-torch/tests/misc.cpp @@ -31,23 +31,30 @@ TEST_CASE("Pick device") { // check that first entry in vector is picked, if no desired device is given std::vector<std::string> supported = {"cpu", "cuda", "mps"}; - auto devtype = metatomic_torch::pick_device(supported); - // devtype must be one of the DeviceType enum values; at minimum CPU should be selectable - // Catch doesn't allow operator|| inside assertions in a portable way, so evaluate - // the boolean expression first and assert that. - bool dev_ok = (devtype == c10::DeviceType::CPU) || - (devtype == c10::DeviceType::CUDA) || - (devtype == c10::DeviceType::MPS); + auto device = metatomic_torch::pick_device(supported); + // device must be one of the supported types; at minimum CPU should be selectable + bool dev_ok = (device.type() == c10::DeviceType::CPU) || + (device.type() == c10::DeviceType::CUDA) || + (device.type() == c10::DeviceType::MPS); REQUIRE(dev_ok); + // auto-selection should not set a device index + CHECK_FALSE(device.has_index()); // Test requested device selection when available if (torch::cuda::is_available()) { - auto dt_cuda = metatomic_torch::pick_device(supported, std::string("cuda")); - CHECK(dt_cuda == c10::DeviceType::CUDA); + auto dev_cuda = metatomic_torch::pick_device(supported, std::string("cuda")); + CHECK(dev_cuda.type() == c10::DeviceType::CUDA); + CHECK_FALSE(dev_cuda.has_index()); + + // Device index should be preserved + auto dev_cuda0 = metatomic_torch::pick_device(supported, std::string("cuda:0")); + CHECK(dev_cuda0.type() == c10::DeviceType::CUDA); + CHECK(dev_cuda0.has_index()); + CHECK(dev_cuda0.index() == 0); } if (torch::mps::is_available()) { - auto dt_mps = metatomic_torch::pick_device(supported, std::string("mps")); - CHECK(dt_mps == c10::DeviceType::MPS); + auto dev_mps = 
metatomic_torch::pick_device(supported, std::string("mps")); + CHECK(dev_mps.type() == c10::DeviceType::MPS); } // Check that warning is emitted: @@ -56,8 +63,8 @@ TEST_CASE("Pick device") { torch::WarningUtils::set_warnAlways(true); std::vector<std::string> supported_devices_foo = {"cpu", "fooo"}; - auto tdevtype = torch::Device(metatomic_torch::pick_device(supported_devices_foo)); - CHECK(tdevtype.str() == "cpu"); + auto dev_foo = metatomic_torch::pick_device(supported_devices_foo); + CHECK(dev_foo.str() == "cpu"); REQUIRE_FALSE(handler.messages.empty()); const auto* expected = "ignoring unknown device 'fooo' from `model_devices`"; @@ -74,6 +81,13 @@ TEST_CASE("Pick device errors") { std::vector<std::string> only_cpu = {"cpu"}; CHECK_THROWS_WITH(metatomic_torch::pick_device(only_cpu, std::string("cuda")), StartsWith("failed to find requested device")); } + + // invalid device string should raise + std::vector<std::string> supported = {"cpu"}; + CHECK_THROWS_WITH( + metatomic_torch::pick_device(supported, std::string("cpu:invalid")), + StartsWith("invalid device string") + ); } diff --git a/python/metatomic_ase/src/metatomic_ase/_calculator.py b/python/metatomic_ase/src/metatomic_ase/_calculator.py index 0973c712..28b7883f 100644 --- a/python/metatomic_ase/src/metatomic_ase/_calculator.py +++ b/python/metatomic_ase/src/metatomic_ase/_calculator.py @@ -212,8 +212,8 @@ def __init__( # get the best device according what the model supports and what's available on # the current machine capabilities = model.capabilities() - self._device = torch.device( - pick_device(capabilities.supported_devices, self.parameters["device"]) + self._device = pick_device( + capabilities.supported_devices, self.parameters["device"] ) if capabilities.dtype in STR_TO_DTYPE: diff --git a/python/metatomic_torch/metatomic/torch/documentation.py b/python/metatomic_torch/metatomic/torch/documentation.py index 40013197..1b456745 100644 --- a/python/metatomic_torch/metatomic/torch/documentation.py +++ 
b/python/metatomic_torch/metatomic/torch/documentation.py @@ -663,21 +663,24 @@ def unit_dimension_for_quantity(name: str) -> str: raise THIS_CODE_SHOULD_NOT_RUN -def pick_device(model_devices: List[str], desired_device: Optional[str]) -> str: +def pick_device( + model_devices: List[str], desired_device: Optional[str] +) -> torch.device: """ Select the best device according to the list of ``model_devices`` from a model, the user-provided ``desired_device`` and what's available on the current machine. If ``desired_device`` is provided, it is checked against the ``model_devices`` and the machine availability. If it contains a device index (e.g. ``"cuda:1"``), the - base device type (``"cuda"``) is used for these checks, and the full string is - returned if successful. + base device type (``"cuda"``) is used for these checks, and the full + :py:class:`torch.device` (including the provided index) is returned if successful. If ``desired_device`` is ``None`` or an empty string, the first available device - from ``model_devices`` will be picked and returned. + from ``model_devices`` will be picked and returned without a specific index. :param model_devices: list of devices supported by a model in order of preference - :param desired_device: user-provided desired device. + :param desired_device: user-provided desired device string (e.g. ``"cuda"``, + ``"cuda:1"``, ``"cpu"``), or ``None`` to auto-select. 
""" raise THIS_CODE_SHOULD_NOT_RUN diff --git a/python/metatomic_torch/tests/pick_device.py b/python/metatomic_torch/tests/pick_device.py index bcc11d4d..0782a2ff 100644 --- a/python/metatomic_torch/tests/pick_device.py +++ b/python/metatomic_torch/tests/pick_device.py @@ -5,33 +5,41 @@ def test_pick_device_basic(): - # basic call should return a non-empty string describing a device + # basic call should return a torch.device res = mta.pick_device(["cpu", "cuda", "mps"], None) - assert isinstance(res, str) - assert len(res) > 0 - # sanity: in typical environments we'll at least see "cpu" somewhere - assert "cpu" == res or "cuda" == res or "mps" == res + assert isinstance(res, torch.device) + # sanity: in typical environments we'll at least see "cpu", "cuda" or "mps" + assert res.type in ("cpu", "cuda", "mps") def test_pick_device_requested_if_available(): - # if CUDA is available, requesting it should yield a cuda device string + # if CUDA is available, requesting it should yield a cuda device if torch.cuda.is_available(): res = mta.pick_device(["cpu", "cuda"], "cuda") - assert isinstance(res, str) - assert "cuda" == res - # if MPS is available, requesting it should yield an mps device string + assert isinstance(res, torch.device) + assert res.type == "cuda" + # if MPS is available, requesting it should yield an mps device if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): res = mta.pick_device(["cpu", "mps"], "mps") - assert isinstance(res, str) - assert "mps" == res + assert isinstance(res, torch.device) + assert res.type == "mps" + + +def test_pick_device_no_index_when_auto(): + # When no desired_device is given, the result should have no specific index + res = mta.pick_device(["cpu"], None) + assert isinstance(res, torch.device) + assert res.type == "cpu" + # torch.device("cpu") has index -1 (unset); torch.device("cpu:0") has index 0 + assert res.index is None def test_pick_device_ignores_unrecognized_and_warns(capfd): # Ensure 
unrecognized device names are ignored and a warning is emitted res = mta.pick_device(["cpu", "fooo"], None) - assert isinstance(res, str) + assert isinstance(res, torch.device) # should pick cpu (ignore "fooo") - assert "cpu" == res + assert res.type == "cpu" # at least one warning should have been produced about the unrecognized/ # ignored entry captured = capfd.readouterr() @@ -42,27 +50,33 @@ def test_pick_device_ignores_unrecognized_and_warns(capfd): def test_pick_device_error_on_unavailable_requested(): - # If a device is explicitly requested but isn't available/declared, - # the python wrapper is expected to raise a RuntimeError (propagated from C++). - only_cuda = ["cuda"] - if not torch.cuda.is_available(): - with pytest.raises(RuntimeError): - mta.pick_device(only_cuda, "cuda") + # Test if a device is explicitly requested but isn't available/declared" + if torch.cuda.is_available(): + model_devices = ["cpu"] else: - # If CUDA is available, requesting a non-present device should raise - with pytest.raises(RuntimeError): - mta.pick_device(["cpu"], "cuda") + model_devices = ["cuda"] + + match = ( + "failed to find a valid device. " + "None of the model-supported devices are available." + ) + with pytest.raises(ValueError, match=match): + mta.pick_device(model_devices, "cuda") def test_pick_device_indexed(): # Test that indexed device strings like "cpu:0" or "cuda:1" are accepted - # and preserved. + # and that the index is preserved in the returned torch.device. 
res = mta.pick_device(["cpu", "cuda"], "cpu:0") - assert res == "cpu:0" + assert isinstance(res, torch.device) + assert res.type == "cpu" + assert res.index == 0 if torch.cuda.is_available(): res = mta.pick_device(["cpu", "cuda"], "cuda:0") - assert res == "cuda:0" + assert isinstance(res, torch.device) + assert res.type == "cuda" + assert res.index == 0 - with pytest.raises(RuntimeError, match="invalid device string"): + with pytest.raises(ValueError, match="invalid device string"): mta.pick_device(["cpu"], "cpu:invalid") diff --git a/python/metatomic_torchsim/metatomic_torchsim/_model.py b/python/metatomic_torchsim/metatomic_torchsim/_model.py index 500aa37e..72abc399 100644 --- a/python/metatomic_torchsim/metatomic_torchsim/_model.py +++ b/python/metatomic_torchsim/metatomic_torchsim/_model.py @@ -138,9 +138,7 @@ def __init__( device = torch.device(device) self._device = device else: - self._device = torch.device( - pick_device(capabilities.supported_devices, None) - ) + self._device = pick_device(capabilities.supported_devices, None) # Resolve dtype from model capabilities if capabilities.dtype in STR_TO_DTYPE: