diff --git a/docs/src/torch/reference/cxx/index.rst b/docs/src/torch/reference/cxx/index.rst index 080d254b3..60e92365b 100644 --- a/docs/src/torch/reference/cxx/index.rst +++ b/docs/src/torch/reference/cxx/index.rst @@ -12,3 +12,4 @@ to ``libmetatatomic_torch.so`` / ``libmetatatomic_torch.dylib`` / systems models + misc diff --git a/docs/src/torch/reference/cxx/misc.rst b/docs/src/torch/reference/cxx/misc.rst new file mode 100644 index 000000000..52e7d923c --- /dev/null +++ b/docs/src/torch/reference/cxx/misc.rst @@ -0,0 +1,6 @@ +Miscellaneous +============= + +.. doxygenfunction:: metatomic_torch::unit_conversion_factor + +.. doxygenfunction:: metatomic_torch::pick_device diff --git a/docs/src/torch/reference/cxx/models.rst b/docs/src/torch/reference/cxx/models.rst index 198ef938a..e81dc5657 100644 --- a/docs/src/torch/reference/cxx/models.rst +++ b/docs/src/torch/reference/cxx/models.rst @@ -7,8 +7,6 @@ Models .. doxygenfunction:: metatomic_torch::load_model_extensions -.. doxygenfunction:: metatomic_torch::unit_conversion_factor - .. doxygentypedef:: metatomic_torch::ModelOutput .. doxygenclass:: metatomic_torch::ModelOutputHolder diff --git a/docs/src/torch/reference/index.rst b/docs/src/torch/reference/index.rst index 9cbec1aae..a689e439b 100644 --- a/docs/src/torch/reference/index.rst +++ b/docs/src/torch/reference/index.rst @@ -8,6 +8,7 @@ API reference systems models/index + misc ase serialization diff --git a/docs/src/torch/reference/misc.rst b/docs/src/torch/reference/misc.rst new file mode 100644 index 000000000..8338d9413 --- /dev/null +++ b/docs/src/torch/reference/misc.rst @@ -0,0 +1,34 @@ +Miscellaneous +============= + +.. autofunction:: metatomic.torch.pick_device + +.. autofunction:: metatomic.torch.unit_conversion_factor + +.. _known-quantities-units: + +Known quantities and units +-------------------------- + +The following quantities and units can be used with metatomic models. 
Adding new +units and quantities is very easy, please contact us if you need something else! +In the meantime, you can create :py:class:`metatomic.torch.ModelOutput` with +quantities that are not in this table. A warning will be issued and no unit +conversion will be performed. + +When working with one of the quantities in this table, the unit you use must be +one of the registered units. + ++----------------+---------------------------------------------------------------------------------------------------+ +| quantity | units | ++================+===================================================================================================+ +| **length** | angstrom (A), Bohr, meter, centimeter (cm), millimeter (mm), micrometer (um, µm), nanometer (nm) | ++----------------+---------------------------------------------------------------------------------------------------+ +| **energy** | eV, meV, Hartree, kcal/mol, kJ/mol, Joule (J), Rydberg (Ry) | ++----------------+---------------------------------------------------------------------------------------------------+ +| **force** | eV/Angstrom (eV/A, eV/Angstrom) | ++----------------+---------------------------------------------------------------------------------------------------+ +| **pressure** | eV/Angstrom^3 (eV/A^3, eV/Angstrom^3) | ++----------------+---------------------------------------------------------------------------------------------------+ +| **momentum** | sqrt(eV*u) | ++----------------+---------------------------------------------------------------------------------------------------+ diff --git a/docs/src/torch/reference/models/index.rst b/docs/src/torch/reference/models/index.rst index d4aa6e551..0bb5d29b7 100644 --- a/docs/src/torch/reference/models/index.rst +++ b/docs/src/torch/reference/models/index.rst @@ -21,32 +21,3 @@ We also provide a couple of functions to work with the models: .. autofunction:: metatomic.torch.load_model_extensions -..
autofunction:: metatomic.torch.unit_conversion_factor - -.. _known-quantities-units: - -Known quantities and units --------------------------- - -The following quantities and units can be used with metatomic models. Adding new -units and quantities is very easy, please contact us if you need something else! -In the mean time, you can create :py:class:`metatomic.torch.ModelOutput` with -quantities that are not in this table. A warning will be issued and no unit -conversion will be performed. - -When working with one of the quantities in this table, the unit you use must be -one of the registered unit. - -+----------------+---------------------------------------------------------------------------------------------------+ -| quantity | units | -+================+===================================================================================================+ -| **length** | angstrom (A), Bohr, meter, centimeter (cm), millimeter (mm), micrometer (um, µm), nanometer (nm) | -+----------------+---------------------------------------------------------------------------------------------------+ -| **energy** | eV, meV, Hartree, kcal/mol, kJ/mol, Joule (J), Rydberg (Ry) | -+----------------+---------------------------------------------------------------------------------------------------+ -| **force** | eV/Angstrom (eV/A, eV/Angstrom) | -+----------------+---------------------------------------------------------------------------------------------------+ -| **pressure** | eV/Angstrom^3 (eV/A^3, eV/Angstrom^3) | -+----------------+---------------------------------------------------------------------------------------------------+ -| **momentum** | sqrt(eV*u) | -+----------------+---------------------------------------------------------------------------------------------------+ diff --git a/metatomic-torch/include/metatomic/torch/misc.hpp b/metatomic-torch/include/metatomic/torch/misc.hpp index 91276f3c9..9264285a9 100644 --- 
a/metatomic-torch/include/metatomic/torch/misc.hpp +++ b/metatomic-torch/include/metatomic/torch/misc.hpp @@ -3,6 +3,8 @@ #include +#include + #include "metatomic/torch/exports.h" namespace metatomic_torch { @@ -10,6 +12,14 @@ namespace metatomic_torch { /// Get the runtime version of metatensor-torch as a string METATOMIC_TORCH_EXPORT std::string version(); +/// Select the best device according to the list of `model_devices` from a +/// model, the user-provided `desired_device` and what's available on the +/// current machine. +METATOMIC_TORCH_EXPORT std::string pick_device( + std::vector model_devices, + torch::optional desired_device = torch::nullopt +); + } #endif diff --git a/metatomic-torch/src/misc.cpp b/metatomic-torch/src/misc.cpp index 31f8a75b0..b1164567b 100644 --- a/metatomic-torch/src/misc.cpp +++ b/metatomic-torch/src/misc.cpp @@ -1,6 +1,62 @@ +#include + #include "metatomic/torch/version.h" #include "metatomic/torch/misc.hpp" std::string metatomic_torch::version() { return METATOMIC_TORCH_VERSION; } + +std::string metatomic_torch::pick_device( + std::vector model_devices, + torch::optional desired_device +) { + auto available_devices = std::vector(); + std::string selected_device = "cpu"; + + for (const auto& device: model_devices) { + if (device == "cpu") { + available_devices.emplace_back("cpu"); + } else if (device == "cuda") { + if (torch::cuda::is_available()) { + available_devices.emplace_back("cuda"); + } + } else if (device == "mps") { + if (torch::mps::is_available()) { + available_devices.emplace_back("mps"); + } + } else { + TORCH_WARN("'model_devices' contains an entry for unknown device (" + torch::str(device) + + "). It will be ignored."); + } + } + + if (available_devices.empty()) { + C10_THROW_ERROR(ValueError, + "failed to find a valid device. None of the devices supported by the model (" + + torch::str(model_devices) + ") where available (" + torch::str(available_devices) + ")." 
+ ); + } + + if (desired_device == torch::nullopt) { + // no user request, pick the device the model prefers + selected_device = available_devices[0]; + } else { + bool found_desired_device = false; + for (const auto& device: available_devices) { + if (device == desired_device) { + selected_device = device; + found_desired_device = true; + break; + } + } + + if (!found_desired_device) { + C10_THROW_ERROR(ValueError, + "failed to find requested device (" + torch::str(desired_device.value()) + + "): it is either not supported by this model or not available on this machine" + ); + } + } + return selected_device; +} diff --git a/metatomic-torch/src/register.cpp b/metatomic-torch/src/register.cpp index c80d6ac3e..f62413eae 100644 --- a/metatomic-torch/src/register.cpp +++ b/metatomic-torch/src/register.cpp @@ -207,6 +207,7 @@ TORCH_LIBRARY(metatomic, m) { // standalone functions m.def("version() -> str", metatomic_torch::version); + m.def("pick_device(str[] model_devices, str? requested_device = None) -> str", metatomic_torch::pick_device); m.def("read_model_metadata(str path) -> __torch__.torch.classes.metatomic.ModelMetadata", read_model_metadata); m.def("unit_conversion_factor(str quantity, str from_unit, str to_unit) -> float", unit_conversion_factor); diff --git a/metatomic-torch/tests/misc.cpp b/metatomic-torch/tests/misc.cpp index f8f9a4daa..6301db7d7 100644 --- a/metatomic-torch/tests/misc.cpp +++ b/metatomic-torch/tests/misc.cpp @@ -1,7 +1,18 @@ +#include + #include "metatomic/torch.hpp" #include +using namespace Catch::Matchers; + +class TestWarningHandler : public torch::WarningHandler { +public: + std::vector messages; + void process(const torch::Warning& warning) override { + messages.push_back(warning.msg()); + } +}; TEST_CASE("Version macros") { CHECK(std::string(METATOMIC_TORCH_VERSION) == metatomic_torch::version()); @@ -13,3 +24,40 @@ TEST_CASE("Version macros") { // METATOMIC_TORCH_VERSION should start with `x.y.z` 
// Exercise `metatomic_torch::pick_device`: default selection, explicit
// requests, the warning for unknown devices and both error paths.
TEST_CASE("Pick device") {
    // check that first entry in vector is picked, if no desired device is given
    std::vector<std::string> supported_devices = {"cpu", "cuda", "mps"};
    CHECK(metatomic_torch::pick_device(supported_devices) == "cpu");

    // test that desired device is picked if available; only request the
    // accelerators that actually exist on the machine running the tests
    std::vector<std::string> desired_devices = {"cpu"};
    if (torch::cuda::is_available()) {
        desired_devices.emplace_back("cuda");
    }
    if (torch::mps::is_available()) {
        desired_devices.emplace_back("mps");
    }

    for (const auto& desired_device: desired_devices) {
        CHECK(metatomic_torch::pick_device(supported_devices, desired_device) == desired_device);
    }

    // Check that warning is emitted. Both guards are RAII objects, so the
    // global warning handler and the "warn always" flag are restored when the
    // test case exits, even if a REQUIRE below fails.
    TestWarningHandler handler;
    torch::WarningUtils::WarningHandlerGuard guard(&handler);
    torch::WarningUtils::WarnAlways warn_always(true);

    std::vector<std::string> supported_devices_foo = {"cpu", "fooo"};
    CHECK(metatomic_torch::pick_device(supported_devices_foo) == "cpu");
    REQUIRE_FALSE(handler.messages.empty());
    CHECK(handler.messages[0].find("'model_devices' contains an entry for unknown device") != std::string::npos);

    // check exception raised when nothing the model supports is usable: a list
    // containing only unknown devices is never usable, whatever the hardware
    std::vector<std::string> supported_devices_unknown = {"fooo"};
    CHECK_THROWS_WITH(metatomic_torch::pick_device(supported_devices_unknown), StartsWith("failed to find a valid device"));
    CHECK_THROWS_WITH(metatomic_torch::pick_device(supported_devices_unknown, "cpu"), StartsWith("failed to find a valid device"));

    // for a CUDA-only model the error depends on the machine: with CUDA
    // present the model has a usable device and only the request is at fault
    std::vector<std::string> supported_devices_cuda = {"cuda"};
    if (torch::cuda::is_available()) {
        CHECK_THROWS_WITH(metatomic_torch::pick_device(supported_devices_cuda, "cpu"), StartsWith("failed to find requested device"));
    } else {
        CHECK_THROWS_WITH(metatomic_torch::pick_device(supported_devices_cuda, "cpu"), StartsWith("failed to find a valid device"));
    }

    // check exception raised when the request is not supported by the model
    std::vector<std::string> supported_devices_cpu = {"cpu"};
    CHECK_THROWS_WITH(metatomic_torch::pick_device(supported_devices_cpu, "cuda"), StartsWith("failed to find requested device"));
}
register_autograd_neighbors = torch.ops.metatomic.register_autograd_neighbors unit_conversion_factor = torch.ops.metatomic.unit_conversion_factor + pick_device = torch.ops.metatomic.pick_device from .io import load_system, save # noqa: F401 from .model import ( # noqa: F401 diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 654603c9c..31d4d1e54 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -18,6 +18,7 @@ ModelOutput, System, load_atomistic_model, + pick_device, register_autograd_neighbors, ) @@ -143,35 +144,12 @@ def __init__( raise TypeError(f"unknown type for model: {type(model)}") self.parameters["device"] = str(device) if device is not None else None - # check if the model supports the requested device + # get the best device according what the model supports and what's available on + # the current machine capabilities = model.capabilities() - if device is None: - device = _find_best_device(capabilities.supported_devices) - else: - device = torch.device(device) - device_is_supported = False - - for supported in capabilities.supported_devices: - try: - supported = torch.device(supported) - except RuntimeError as e: - warnings.warn( - "the model contains an invalid device in `supported_devices`: " - f"{e}", - stacklevel=2, - ) - continue - - if supported.type == device.type: - device_is_supported = True - break - - if not device_is_supported: - raise ValueError( - f"This model does not support the requested device ({device}), " - "the following devices are supported: " - f"{capabilities.supported_devices}" - ) + self._device = torch.device( + pick_device(capabilities.supported_devices, self.parameters["device"]) + ) if capabilities.dtype in STR_TO_DTYPE: self._dtype = STR_TO_DTYPE[capabilities.dtype] @@ -193,7 +171,6 @@ def __init__( self._additional_output_requests = additional_outputs - 
def pick_device(model_devices: List[str], desired_device: Optional[str] = None) -> str:
    """
    Select the best device according to the list of ``model_devices`` from a
    model, the user-provided ``desired_device`` and what's available on the
    current machine.

    Entries of ``model_devices`` that are not usable on the current machine are
    skipped; entries that are not one of ``"cpu"``, ``"cuda"`` or ``"mps"`` are
    ignored with a warning.

    :param model_devices: list of devices supported by a model in order of preference
    :param desired_device: user-provided desired device. If ``None``, the first
        device from ``model_devices`` that is available on the current machine
        is picked.

    :return: the name of the selected device

    :raises ValueError: if none of the ``model_devices`` is available on the
        current machine, or if ``desired_device`` is given but is either not
        supported by the model or not available on the current machine
    """
    # NOTE(review): the operator schema registered in ``register.cpp`` names the
    # second argument ``requested_device``, so passing ``desired_device=...`` by
    # keyword may not be accepted by the actual operator -- confirm and align.