Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/torch/reference/cxx/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ to ``libmetatatomic_torch.so`` / ``libmetatatomic_torch.dylib`` /

systems
models
misc
6 changes: 6 additions & 0 deletions docs/src/torch/reference/cxx/misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Miscellaneous
=============

.. doxygenfunction:: metatomic_torch::unit_conversion_factor

.. doxygenfunction:: metatomic_torch::pick_device
2 changes: 0 additions & 2 deletions docs/src/torch/reference/cxx/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ Models

.. doxygenfunction:: metatomic_torch::load_model_extensions

.. doxygenfunction:: metatomic_torch::unit_conversion_factor

.. doxygentypedef:: metatomic_torch::ModelOutput

.. doxygenclass:: metatomic_torch::ModelOutputHolder
Expand Down
1 change: 1 addition & 0 deletions docs/src/torch/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ API reference

systems
models/index
misc
ase
serialization

Expand Down
34 changes: 34 additions & 0 deletions docs/src/torch/reference/misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
Miscellaneous
=============

.. autofunction:: metatomic.torch.pick_device

.. autofunction:: metatomic.torch.unit_conversion_factor

.. _known-quantities-units:

Known quantities and units
--------------------------

The following quantities and units can be used with metatomic models. Adding new
units and quantities is very easy, please contact us if you need something else!
In the mean time, you can create :py:class:`metatomic.torch.ModelOutput` with
quantities that are not in this table. A warning will be issued and no unit
conversion will be performed.

When working with one of the quantities in this table, the unit you use must be
one of the registered unit.

+----------------+---------------------------------------------------------------------------------------------------+
| quantity | units |
+================+===================================================================================================+
| **length** | angstrom (A), Bohr, meter, centimeter (cm), millimeter (mm), micrometer (um, µm), nanometer (nm) |
+----------------+---------------------------------------------------------------------------------------------------+
| **energy** | eV, meV, Hartree, kcal/mol, kJ/mol, Joule (J), Rydberg (Ry) |
+----------------+---------------------------------------------------------------------------------------------------+
| **force** | eV/Angstrom (eV/A, eV/Angstrom) |
+----------------+---------------------------------------------------------------------------------------------------+
| **pressure** | eV/Angstrom^3 (eV/A^3, eV/Angstrom^3) |
+----------------+---------------------------------------------------------------------------------------------------+
| **momentum** | sqrt(eV*u) |
+----------------+---------------------------------------------------------------------------------------------------+
29 changes: 0 additions & 29 deletions docs/src/torch/reference/models/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,3 @@ We also provide a couple of functions to work with the models:

.. autofunction:: metatomic.torch.load_model_extensions

.. autofunction:: metatomic.torch.unit_conversion_factor

.. _known-quantities-units:

Known quantities and units
--------------------------

The following quantities and units can be used with metatomic models. Adding new
units and quantities is very easy, please contact us if you need something else!
In the mean time, you can create :py:class:`metatomic.torch.ModelOutput` with
quantities that are not in this table. A warning will be issued and no unit
conversion will be performed.

When working with one of the quantities in this table, the unit you use must be
one of the registered unit.

+----------------+---------------------------------------------------------------------------------------------------+
| quantity | units |
+================+===================================================================================================+
| **length** | angstrom (A), Bohr, meter, centimeter (cm), millimeter (mm), micrometer (um, µm), nanometer (nm) |
+----------------+---------------------------------------------------------------------------------------------------+
| **energy** | eV, meV, Hartree, kcal/mol, kJ/mol, Joule (J), Rydberg (Ry) |
+----------------+---------------------------------------------------------------------------------------------------+
| **force** | eV/Angstrom (eV/A, eV/Angstrom) |
+----------------+---------------------------------------------------------------------------------------------------+
| **pressure** | eV/Angstrom^3 (eV/A^3, eV/Angstrom^3) |
+----------------+---------------------------------------------------------------------------------------------------+
| **momentum** | sqrt(eV*u) |
+----------------+---------------------------------------------------------------------------------------------------+
10 changes: 10 additions & 0 deletions metatomic-torch/include/metatomic/torch/misc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,23 @@

#include <string>

#include <torch/types.h>

#include "metatomic/torch/exports.h"

namespace metatomic_torch {

/// Get the runtime version of metatensor-torch as a string
METATOMIC_TORCH_EXPORT std::string version();

/// Select the best device according to the list of `model_devices` from a
/// model, the user-provided `desired_device` and what's available on the
/// current machine.
METATOMIC_TORCH_EXPORT std::string pick_device(
std::vector<std::string> model_devices,
torch::optional<std::string> desired_device = torch::nullopt
);

}

#endif
56 changes: 56 additions & 0 deletions metatomic-torch/src/misc.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,62 @@
#include <torch/torch.h>

#include "metatomic/torch/version.h"
#include "metatomic/torch/misc.hpp"

std::string metatomic_torch::version() {
return METATOMIC_TORCH_VERSION;
}

std::string metatomic_torch::pick_device(
std::vector<std::string> model_devices,
torch::optional<std::string> desired_device
) {
auto available_devices = std::vector<std::string>();
std::string selected_device = "cpu";

for (const auto& device: model_devices) {
if (device == "cpu") {
available_devices.emplace_back("cpu");
} else if (device == "cuda") {
if (torch::cuda::is_available()) {
available_devices.emplace_back("cuda");
}
} else if (device == "mps") {
if (torch::mps::is_available()) {
available_devices.emplace_back("mps");
}
} else {
TORCH_WARN("'model_devices' contains an entry for unknown device (" + torch::str(device)
+ "). It will be ignored.");
}
}

if (available_devices.empty()) {
C10_THROW_ERROR(ValueError,
"failed to find a valid device. None of the devices supported by the model ("
+ torch::str(model_devices) + ") where available (" + torch::str(available_devices) + ")."
);
}

if (desired_device == torch::nullopt) {
// no user request, pick the device the model prefers
selected_device = available_devices[0];
} else {
bool found_desired_device = false;
for (const auto& device: available_devices) {
if (device == desired_device) {
selected_device = device;
found_desired_device = true;
break;
}
}

if (!found_desired_device) {
C10_THROW_ERROR(ValueError,
"failed to find requested device (" + torch::str(desired_device.value()) +
"): it is either not supported by this model or not available on this machine"
);
}
}
return selected_device;
}
1 change: 1 addition & 0 deletions metatomic-torch/src/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ TORCH_LIBRARY(metatomic, m) {

// standalone functions
m.def("version() -> str", metatomic_torch::version);
m.def("pick_device(str[] model_devices, str? requested_device = None) -> str", metatomic_torch::pick_device);

m.def("read_model_metadata(str path) -> __torch__.torch.classes.metatomic.ModelMetadata", read_model_metadata);
m.def("unit_conversion_factor(str quantity, str from_unit, str to_unit) -> float", unit_conversion_factor);
Expand Down
48 changes: 48 additions & 0 deletions metatomic-torch/tests/misc.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
#include <torch/torch.h>

#include "metatomic/torch.hpp"

#include <catch.hpp>
using namespace Catch::Matchers;

class TestWarningHandler : public torch::WarningHandler {
public:
std::vector<std::string> messages;

void process(const torch::Warning& warning) override {
messages.push_back(warning.msg());
}
};

TEST_CASE("Version macros") {
CHECK(std::string(METATOMIC_TORCH_VERSION) == metatomic_torch::version());
Expand All @@ -13,3 +24,40 @@ TEST_CASE("Version macros") {
// METATOMIC_TORCH_VERSION should start with `x.y.z`
CHECK(std::string(METATOMIC_TORCH_VERSION).find(version) == 0);
}


TEST_CASE("Pick device") {
// check that first entry in vector is picked, if no desired device is given
std::vector<std::string> supported_devices = {"cpu", "cuda", "mps"};
CHECK(metatomic_torch::pick_device(supported_devices) == "cpu");

// test that desired device is picked if available
std::vector<std::string> desired_devices = {"cpu"};
if (torch::cuda::is_available()) {
desired_devices.emplace_back("cuda");
}
if (torch::mps::is_available()) {
desired_devices.emplace_back("mps");
}

for (const auto& desired_device: desired_devices) {
CHECK(metatomic_torch::pick_device(supported_devices, desired_device) == desired_device);
}

// Check that warning is emitted:
TestWarningHandler handler;
torch::WarningUtils::WarningHandlerGuard guard(&handler);
torch::WarningUtils::set_warnAlways(true);

std::vector<std::string> supported_devices_foo = {"cpu", "fooo"};
CHECK(metatomic_torch::pick_device(supported_devices_foo) == "cpu");
REQUIRE_FALSE(handler.messages.empty());
CHECK(handler.messages[0].find("'model_devices' contains an entry for unknown device") != std::string::npos);

// check exception raised
std::vector<std::string> supported_devices_cuda = {"cuda"};
CHECK_THROWS_WITH(metatomic_torch::pick_device(supported_devices_cuda, "cpu"), StartsWith("failed to find a valid device"));

std::vector<std::string> supported_devices_cpu = {"cpu"};
CHECK_THROWS_WITH(metatomic_torch::pick_device(supported_devices_cpu, "cuda"), StartsWith("failed to find requested device"));
}
2 changes: 2 additions & 0 deletions python/metatomic_torch/metatomic/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
System,
check_atomistic_model,
load_model_extensions,
pick_device,
read_model_metadata,
register_autograd_neighbors,
unit_conversion_factor,
Expand All @@ -38,6 +39,7 @@

register_autograd_neighbors = torch.ops.metatomic.register_autograd_neighbors
unit_conversion_factor = torch.ops.metatomic.unit_conversion_factor
pick_device = torch.ops.metatomic.pick_device

from .io import load_system, save # noqa: F401
from .model import ( # noqa: F401
Expand Down
78 changes: 6 additions & 72 deletions python/metatomic_torch/metatomic/torch/ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ModelOutput,
System,
load_atomistic_model,
pick_device,
register_autograd_neighbors,
)

Expand Down Expand Up @@ -143,35 +144,12 @@ def __init__(
raise TypeError(f"unknown type for model: {type(model)}")

self.parameters["device"] = str(device) if device is not None else None
# check if the model supports the requested device
# get the best device according what the model supports and what's available on
# the current machine
capabilities = model.capabilities()
if device is None:
device = _find_best_device(capabilities.supported_devices)
else:
device = torch.device(device)
device_is_supported = False

for supported in capabilities.supported_devices:
try:
supported = torch.device(supported)
except RuntimeError as e:
warnings.warn(
"the model contains an invalid device in `supported_devices`: "
f"{e}",
stacklevel=2,
)
continue

if supported.type == device.type:
device_is_supported = True
break

if not device_is_supported:
raise ValueError(
f"This model does not support the requested device ({device}), "
"the following devices are supported: "
f"{capabilities.supported_devices}"
)
self._device = torch.device(
pick_device(capabilities.supported_devices, self.parameters["device"])
)

if capabilities.dtype in STR_TO_DTYPE:
self._dtype = STR_TO_DTYPE[capabilities.dtype]
Expand All @@ -193,7 +171,6 @@ def __init__(

self._additional_output_requests = additional_outputs

self._device = device
self._model = model.to(device=self._device)

self._calculate_uncertainty = (
Expand Down Expand Up @@ -716,49 +693,6 @@ def _ase_properties_to_metatensor_outputs(
return metatensor_outputs


def _find_best_device(devices: List[str]) -> torch.device:
"""
Find the best device from the list of ``devices`` that is available to the current
PyTorch installation.
"""

for device in devices:
if device == "cpu":
return torch.device("cpu")
elif device == "cuda":
if torch.cuda.is_available():
return torch.device("cuda")
else:
LOGGER.warning(
"the model suggested to use CUDA devices before CPU, "
"but we are unable to find it"
)
elif device == "mps":
if (
hasattr(torch.backends, "mps")
and torch.backends.mps.is_built()
and torch.backends.mps.is_available()
):
return torch.device("mps")
else:
LOGGER.warning(
"the model suggested to use MPS devices before CPU, "
"but we are unable to find it"
)
else:
warnings.warn(
f"unknown device in the model's `supported_devices`: '{device}'",
stacklevel=2,
)

warnings.warn(
"could not find a valid device in the model's `supported_devices`, "
"falling back to CPU",
stacklevel=2,
)
return torch.device("cpu")


def _compute_ase_neighbors(atoms, options, dtype, device):
# options.strict is ignored by this function, since `ase.neighborlist.neighbor_list`
# only computes strict NL, and these are valid even with `strict=False`
Expand Down
12 changes: 12 additions & 0 deletions python/metatomic_torch/metatomic/torch/documentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,15 @@ def unit_conversion_factor(quantity: str, from_unit: str, to_unit: str):
:param from_unit: current unit of the data
:param to_unit: target unit of the data
"""


def pick_device(model_devices: List[str], desired_device: Optional[str]) -> str:
"""
Select the best device according to the list of ``model_devices`` from a
model, the user-provided ``desired_device`` and what's available on the
current machine.

:param model_devices: list of devices supported by a model in order of preference
:param desired_device: user-provided desired device. If ``None`` or not available,
the first available device from ``model_devices`` will be picked.
"""
Loading