Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion metatomic-torch/include/metatomic/torch/misc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ METATOMIC_TORCH_EXPORT std::string version();
/// Select the best device according to the list of `model_devices` from a
/// model, the user-provided `desired_device` and what's available on the
/// current machine.
METATOMIC_TORCH_EXPORT std::string pick_device(
///
/// This function returns a c10::DeviceType (torch::DeviceType). It does NOT
/// decide a device index — callers that need a full torch::Device should
/// construct one from the returned DeviceType (and choose an index explicitly).
/// or rely on the torch::Device(DeviceType) constructor, which defaults the index to zero.
METATOMIC_TORCH_EXPORT torch::DeviceType pick_device(
std::vector<std::string> model_devices,
torch::optional<std::string> desired_device = torch::nullopt
);
Expand Down
139 changes: 95 additions & 44 deletions metatomic-torch/src/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,58 +23,109 @@ std::string version() {
return METATOMIC_TORCH_VERSION;
}

std::string pick_device(
std::vector<std::string> model_devices,
torch::optional<std::string> desired_device
) {
auto available_devices = std::vector<std::string>();
std::string selected_device = "cpu";

for (const auto& device: model_devices) {
if (device == "cpu") {
available_devices.emplace_back("cpu");
} else if (device == "cuda") {
if (torch::cuda::is_available()) {
available_devices.emplace_back("cuda");
}
} else if (device == "mps") {
if (torch::mps::is_available()) {
available_devices.emplace_back("mps");
}
} else {
TORCH_WARN("'model_devices' contains an entry for unknown device (" + torch::str(device)
+ "). It will be ignored.");
namespace {
using namespace std::string_literals;

// Return a copy of `s` with every character converted to ASCII lower case.
static inline std::string lower(std::string s) {
    for (auto &ch : s) {
        ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
    }
    return s;
}

// Report whether the backend named `n` (already normalized to lower-case)
// can actually be used on this machine.
//
// Only CPU/CUDA/MPS (and HIP when compiled in) can be probed at runtime;
// the remaining known backends are optimistically reported as available,
// and using them may still fail later.
static inline bool available_device(const std::string &n) {
    if (n == "cpu") { return true; }
    if (n == "cuda") { return torch::cuda::is_available(); }
    if (n == "mps") { return torch::mps::is_available(); }
#ifdef TORCH_ENABLE_HIP
    if (n == "hip") { return torch::hip::is_available(); }
#else
    if (n == "hip") { return false; }
#endif
    // For many backends we can't reliably test availability at runtime;
    // assume existence (caller may still fail when using the device).
    // NOTE: "mtia" was previously listed twice here; the duplicate is removed.
    if (n == "xla" || n == "ipu" || n == "xpu" || n == "ve" || n == "opencl" ||
        n == "opengl" || n == "vulkan" || n == "mkldnn" || n == "ideep" ||
        n == "mtia" || n == "meta" || n == "hpu") { return true; }
    return false;
}

// Map a normalized (lower-case) device name to the corresponding
// torch::DeviceType.
//
// Throws a c10 ValueError for unknown names; callers are expected to
// pre-filter with `is_known_device`, so the throw is normally unreachable.
static inline torch::DeviceType map_to_devicetype(const std::string &n) {
    if (n == "cpu") return torch::DeviceType::CPU;
    if (n == "cuda") return torch::DeviceType::CUDA;
    if (n == "mps") return torch::DeviceType::MPS;
    if (n == "hip") return torch::DeviceType::HIP;
    if (n == "xla") return torch::DeviceType::XLA;
    if (n == "ipu") return torch::DeviceType::IPU;
    if (n == "xpu") return torch::DeviceType::XPU;
    if (n == "ve") return torch::DeviceType::VE;
    if (n == "opencl") return torch::DeviceType::OPENCL;
    if (n == "opengl") return torch::DeviceType::OPENGL;
    if (n == "vulkan") return torch::DeviceType::Vulkan;
    if (n == "mkldnn") return torch::DeviceType::MKLDNN;
    if (n == "ideep") return torch::DeviceType::IDEEP;
    if (n == "mtia") return torch::DeviceType::MTIA;
    if (n == "meta") return torch::DeviceType::Meta;
    if (n == "hpu") return torch::DeviceType::HPU;
    // unknown name — there is no safe fallback, so fail loudly (the previous
    // "Fallback to CPU" comment was misleading; this path throws)
    C10_THROW_ERROR(ValueError, "failed to find a valid device.");
}

// True when `n` (normalized, lower-case) names a device backend that this
// file knows how to map to a torch::DeviceType.
static inline bool is_known_device(const std::string &n) {
    static const char *const known_backends[] = {
        "cpu", "cuda", "mps", "hip", "xla", "ipu", "xpu", "ve",
        "opencl", "opengl", "vulkan", "mkldnn", "ideep", "mtia",
        "meta", "hpu",
    };
    for (const auto *backend : known_backends) {
        if (n == backend) {
            return true;
        }
    }
    return false;
}

} // namespace

// Select the best device for a model.
//
// `model_devices` lists the devices the model supports, in the model's own
// preference order; `desired_device` is an optional user request.
//
// Unknown or unavailable entries are skipped (with a warning). If nothing
// usable remains, or the requested device is not in the usable set, a c10
// ValueError is thrown. Returns only a DeviceType — callers pick the index.
c10::DeviceType pick_device(
    std::vector<std::string> model_devices,
    torch::optional<std::string> desired_device
) {
    // build list of available (normalized) device names, preserving the
    // model's preference order
    std::vector<std::string> available;
    available.reserve(model_devices.size());
    for (auto &d : model_devices) {
        std::string n = lower(d);
        if (!is_known_device(n)) {
            TORCH_WARN("'model_devices' contains an entry for unknown device '" + d + "' (" + n + "); ignoring.");
            continue;
        }
        if (!available_device(n)) {
            TORCH_WARN("Model requested device '" + d + "' (" + n +
                       ") but it's not available; ignoring.");
            continue;
        }
        available.emplace_back(std::move(n));
    }

    if (available.empty()) {
        C10_THROW_ERROR(ValueError,
                        "failed to find a valid device. None of the "
                        "model-supported devices are available.");
    }

    // if no desired device requested, pick the first available one (i.e. the
    // model's most-preferred usable device)
    if (!desired_device.has_value() || desired_device->empty()) {
        return map_to_devicetype(available.front());
    }

    // normalize the request and check it against the usable set
    std::string wanted = lower(desired_device.value());
    for (auto &a : available) {
        if (a == wanted) {
            return map_to_devicetype(a);
        }
    }

    C10_THROW_ERROR(ValueError, "failed to find requested device (" +
                                desired_device.value() +
                                "): it is either not supported by this "
                                "model or not available on this machine");
}

std::string pick_output(
Expand Down
34 changes: 33 additions & 1 deletion metatomic-torch/src/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,32 @@

using namespace metatomic_torch;

// Python-facing wrapper around metatomic_torch::pick_device().
//
// Keeps the historical string return type: the selected c10::DeviceType is
// converted back to its bare device name with any ":index" suffix removed.
// Any C++ error is surfaced to Python as a RuntimeError.
std::string pick_device_pywrapper(
    const std::vector<std::string> &model_devices,
    const c10::optional<std::string> &requested_device
) {
    try {
        torch::optional<std::string> desired = torch::nullopt;
        if (requested_device.has_value()) {
            desired = requested_device.value();
        }

        const c10::DeviceType devtype =
            metatomic_torch::pick_device(model_devices, desired);

        // torch::Device::str() may include an index (e.g. "cuda:0"); keep
        // only the device-type part for backwards compatibility
        const std::string name = torch::Device(devtype).str();
        const auto colon = name.find(':');
        return (colon == std::string::npos) ? name : name.substr(0, colon);

    } catch (const std::exception &e) {
        throw std::runtime_error(std::string("pick_device failed: ") + e.what());
    }
}

TORCH_LIBRARY(metatomic, m) {
// There is no way to access the docstrings from Python, so we don't bother
// setting them to something useful here.
Expand Down Expand Up @@ -219,7 +245,13 @@ TORCH_LIBRARY(metatomic, m) {

// standalone functions
m.def("version() -> str", version);
m.def("pick_device(str[] model_devices, str? requested_device = None) -> str", pick_device);
// Expose pick_device to Python. The C++ helper returns a c10::DeviceType;
// build a torch::Device from it (with default index policy) and return its
// string representation to Python (backwards-compatible).
m.def(
"pick_device(str[] model_devices, str? requested_device = None) -> str",
&pick_device_pywrapper
);
m.def("pick_output(str requested_output, Dict(str, __torch__.torch.classes.metatomic.ModelOutput) outputs, str? desired_variant = None) -> str", pick_output);

m.def("read_model_metadata(str path) -> __torch__.torch.classes.metatomic.ModelMetadata", read_model_metadata);
Expand Down
47 changes: 29 additions & 18 deletions metatomic-torch/tests/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,25 @@ TEST_CASE("Version macros") {

TEST_CASE("Pick device") {
// check that first entry in vector is picked, if no desired device is given
std::vector<std::string> supported_devices = {"cpu", "cuda", "mps"};
CHECK(metatomic_torch::pick_device(supported_devices) == "cpu");

// test that desired device is picked if available
std::vector<std::string> desired_devices = {"cpu"};
std::vector<std::string> supported = {"cpu", "cuda", "mps"};

auto devtype = metatomic_torch::pick_device(supported);
// devtype must be one of the DeviceType enum values; at minimum CPU should be selectable
// Catch doesn't allow operator|| inside assertions in a portable way, so evaluate
// the boolean expression first and assert that.
bool dev_ok = (devtype == c10::DeviceType::CPU) ||
(devtype == c10::DeviceType::CUDA) ||
(devtype == c10::DeviceType::MPS);
REQUIRE(dev_ok);

// Test requested device selection when available
if (torch::cuda::is_available()) {
desired_devices.emplace_back("cuda");
auto dt_cuda = metatomic_torch::pick_device(supported, std::string("cuda"));
CHECK(dt_cuda == c10::DeviceType::CUDA);
}
if (torch::mps::is_available()) {
desired_devices.emplace_back("mps");
}

for (const auto& desired_device: desired_devices) {
CHECK(metatomic_torch::pick_device(supported_devices, desired_device) == desired_device);
auto dt_mps = metatomic_torch::pick_device(supported, std::string("mps"));
CHECK(dt_mps == c10::DeviceType::MPS);
}

// Check that warning is emitted:
Expand All @@ -51,16 +56,22 @@ TEST_CASE("Pick device") {
torch::WarningUtils::set_warnAlways(true);

std::vector<std::string> supported_devices_foo = {"cpu", "fooo"};
CHECK(metatomic_torch::pick_device(supported_devices_foo) == "cpu");
auto tdevtype = torch::Device(metatomic_torch::pick_device(supported_devices_foo));
CHECK(tdevtype.str() == "cpu");
REQUIRE_FALSE(handler.messages.empty());
CHECK(handler.messages[0].find("'model_devices' contains an entry for unknown device") != std::string::npos);
}

// check exception raised
std::vector<std::string> supported_devices_cuda = {"cuda"};
CHECK_THROWS_WITH(metatomic_torch::pick_device(supported_devices_cuda, "cpu"), StartsWith("failed to find a valid device"));

std::vector<std::string> supported_devices_cpu = {"cpu"};
CHECK_THROWS_WITH(metatomic_torch::pick_device(supported_devices_cpu, "cuda"), StartsWith("failed to find requested device"));
// Error paths of pick_device: both failure messages are pinned by prefix.
TEST_CASE("Pick device errors") {
// If the model only supports CUDA and CUDA is not available, no usable
// device remains and pick_device must throw "failed to find a valid device"
std::vector<std::string> only_cuda = {"cuda"};
if (!torch::cuda::is_available()) {
CHECK_THROWS_WITH(metatomic_torch::pick_device(only_cuda), StartsWith("failed to find a valid device"));
} else {
// If CUDA is available, explicitly requesting "cuda" on a model that only
// declares "cpu" must throw "failed to find requested device"
std::vector<std::string> only_cpu = {"cpu"};
CHECK_THROWS_WITH(metatomic_torch::pick_device(only_cpu, std::string("cuda")), StartsWith("failed to find requested device"));
}
}


Expand Down
53 changes: 53 additions & 0 deletions python/metatomic_torch/tests/test_pick_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
import torch

import metatomic.torch as mta


def test_pick_device_basic():
    """Calling pick_device with no request returns a non-empty device name."""
    picked = mta.pick_device(["cpu", "cuda", "mps"], None)
    assert isinstance(picked, str)
    assert picked != ""
    # whichever device was selected, it must be one of the declared names
    assert picked in ("cpu", "cuda", "mps")


def test_pick_device_requested_if_available():
    """An explicitly requested device is honored whenever it is available."""
    # if CUDA is available, requesting it should yield "cuda"
    if torch.cuda.is_available():
        picked = mta.pick_device(["cpu", "cuda"], "cuda")
        assert isinstance(picked, str)
        assert picked == "cuda"
    # if MPS is available, requesting it should yield "mps"
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        picked = mta.pick_device(["cpu", "mps"], "mps")
        assert isinstance(picked, str)
        assert picked == "mps"


def test_pick_device_ignores_unrecognized_and_warns(capfd):
    """Unknown device names are skipped and a warning goes to stderr."""
    picked = mta.pick_device(["cpu", "fooo"], None)
    assert isinstance(picked, str)
    assert picked == "cpu"  # "fooo" must be ignored in favor of cpu
    # TORCH_WARN writes to stderr; the message must mention the bad entry
    # and say it was ignored/unknown
    stderr_text = capfd.readouterr().err.lower()
    assert "fooo" in stderr_text
    assert "ignoring" in stderr_text or "unknown" in stderr_text


def test_pick_device_error_on_unavailable_requested():
    """Requesting an unavailable/undeclared device raises RuntimeError."""
    if torch.cuda.is_available():
        # CUDA exists on this machine, but the model only declares CPU:
        # the explicit request must still fail
        with pytest.raises(RuntimeError):
            mta.pick_device(["cpu"], "cuda")
    else:
        # CUDA was requested but is not available on this machine
        with pytest.raises(RuntimeError):
            mta.pick_device(["cuda"], "cuda")
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ test_options =
--import-mode=append

packaging_deps =
setuptools >= 75
setuptools >= 77
packaging >= 23
cmake

Expand Down
Loading