From a7df29c90557fcd66bdde87c80cc0d8635e7c4ce Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 12 Nov 2025 17:40:25 +0100 Subject: [PATCH 1/5] feat(torch): Refactor pick_device to return DeviceType Refactors the C++ function to perform more robust device detection and normalization. - Changes the return type from std::string to c10::DeviceType. - Adds helper functions for string normalization, device availability checks (is_known_device, available_device), and mapping to the DeviceType enum. - Adds warnings for unknown or unavailable devices. - Throws an error on fallback instead of defaulting to CPU. --- .../include/metatomic/torch/misc.hpp | 7 +- metatomic-torch/src/misc.cpp | 139 ++++++++++++------ 2 files changed, 101 insertions(+), 45 deletions(-) diff --git a/metatomic-torch/include/metatomic/torch/misc.hpp b/metatomic-torch/include/metatomic/torch/misc.hpp index ebeff887f..4326fd7bb 100644 --- a/metatomic-torch/include/metatomic/torch/misc.hpp +++ b/metatomic-torch/include/metatomic/torch/misc.hpp @@ -22,7 +22,12 @@ METATOMIC_TORCH_EXPORT std::string version(); /// Select the best device according to the list of `model_devices` from a /// model, the user-provided `desired_device` and what's available on the /// current machine. -METATOMIC_TORCH_EXPORT std::string pick_device( +/// +/// This function returns a c10::DeviceType (torch::DeviceType). It does NOT +/// decide a device index — callers that need a full torch::Device should +/// construct one from the returned DeviceType (and choose an index explicitly). 
+/// Alternatively, let the device index default to zero via Device(DeviceType). +METATOMIC_TORCH_EXPORT torch::DeviceType pick_device( std::vector model_devices, torch::optional desired_device = torch::nullopt ); diff --git a/metatomic-torch/src/misc.cpp b/metatomic-torch/src/misc.cpp index 8c3b3e799..d1be00a12 100644 --- a/metatomic-torch/src/misc.cpp +++ b/metatomic-torch/src/misc.cpp @@ -23,58 +23,109 @@ std::string version() { return METATOMIC_TORCH_VERSION; } -std::string pick_device( - std::vector model_devices, - torch::optional desired_device -) { - auto available_devices = std::vector(); - std::string selected_device = "cpu"; - - for (const auto& device: model_devices) { - if (device == "cpu") { - available_devices.emplace_back("cpu"); - } else if (device == "cuda") { - if (torch::cuda::is_available()) { - available_devices.emplace_back("cuda"); - } - } else if (device == "mps") { - if (torch::mps::is_available()) { - available_devices.emplace_back("mps"); - } - } else { - TORCH_WARN("'model_devices' contains an entry for unknown device (" + torch::str(device) + "). It will be ignored."); +namespace { +using namespace std::string_literals; + +static inline std::string lower(std::string s) { + std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + return s; +} + +static inline bool available_device(const std::string &n) { + if (n == "cpu") return true; + if (n == "cuda") return torch::cuda::is_available(); + if (n == "mps") return torch::mps::is_available(); +#ifdef TORCH_ENABLE_HIP + if (n == "hip") return torch::hip::is_available(); +#else + if (n == "hip") return false; +#endif + // For many backends we can't reliably test availability at runtime; + // assume existence (caller may still fail when using the device). 
+ if (n == "xla" || n == "ipu" || n == "xpu" || n == "ve" || n == "opencl" || + n == "opengl" || n == "vulkan" || n == "mkldnn" || n == "ideep" || + n == "mtia" || n == "meta" || n == "hpu") { return true; } + return false; +} + +static inline torch::DeviceType map_to_devicetype(const std::string &n) { + if (n == "cpu") return torch::DeviceType::CPU; + if (n == "cuda") return torch::DeviceType::CUDA; + if (n == "mps") return torch::DeviceType::MPS; + if (n == "hip") return torch::DeviceType::HIP; + if (n == "xla") return torch::DeviceType::XLA; + if (n == "ipu") return torch::DeviceType::IPU; + if (n == "xpu") return torch::DeviceType::XPU; + if (n == "ve") return torch::DeviceType::VE; + if (n == "opencl") return torch::DeviceType::OPENCL; + if (n == "opengl") return torch::DeviceType::OPENGL; + if (n == "vulkan") return torch::DeviceType::Vulkan; + if (n == "mkldnn") return torch::DeviceType::MKLDNN; + if (n == "ideep") return torch::DeviceType::IDEEP; + if (n == "mtia") return torch::DeviceType::MTIA; + if (n == "meta") return torch::DeviceType::Meta; + if (n == "hpu") return torch::DeviceType::HPU; + // Unknown device name: reject it instead of silently falling back to CPU + C10_THROW_ERROR(ValueError, "failed to find a valid device."); +} + +static inline bool is_known_device(const std::string &n) { + if (n == "cpu") return true; + if (n == "cuda") return true; + if (n == "mps") return true; + if (n == "hip") return true; + if (n == "xla" || n == "ipu" || n == "xpu" || n == "ve" || n == "opencl" || + n == "opengl" || n == "vulkan" || n == "mkldnn" || n == "ideep" || + n == "mtia" || n == "meta" || n == "hpu") { return true; } + return false; +} + +} // namespace + +c10::DeviceType pick_device(std::vector model_devices, + torch::optional desired_device) { + // build list of available (normalized) device names in order + std::vector available; + available.reserve(model_devices.size()); + for (auto &d : model_devices) { + std::string n = lower(d); + if 
(!is_known_device(n)) { + TORCH_WARN("'model_devices' contains an entry for unknown device '" + d + "' (" + n + "); ignoring."); + continue; + } + if (!available_device(n)) { + TORCH_WARN("Model requested device '" + d + "' (" + n + + ") but it's not available; ignoring."); + continue; } + available.emplace_back(std::move(n)); } - if (available_devices.empty()) { + if (available.empty()) { C10_THROW_ERROR(ValueError, - "failed to find a valid device. None of the devices supported by the model (" - + torch::str(model_devices) + ") where available (" + torch::str(available_devices) + ")." - ); + "failed to find a valid device. None of the " + "model-supported devices are available."); } - if (desired_device == torch::nullopt) { - // no user request, pick the device the model prefers - selected_device = available_devices[0]; - } else { - bool found_desired_device = false; - for (const auto& device: available_devices) { - if (device == desired_device) { - selected_device = device; - found_desired_device = true; - break; - } - } + // if no desired device requested, pick first available + if (!desired_device.has_value() || desired_device->empty()) { + return map_to_devicetype(available.front()); + } - if (!found_desired_device) { - C10_THROW_ERROR(ValueError, - "failed to find requested device (" + torch::str(desired_device.value()) + - "): it is either not supported by this model or not available on this machine" - ); - } + // normalize desired and check + std::string wanted = lower(desired_device.value()); + for (auto &a : available) { + if (a == wanted) + return map_to_devicetype(a); } - return selected_device; + + C10_THROW_ERROR(ValueError, "failed to find requested device (" + + desired_device.value() + + "): it is either not supported by this " + "model or not available on this machine"); } std::string pick_output( From 1d5d58b3a996a0381489a7e5173228bb667bbbec Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 12 Nov 2025 17:40:26 +0100 Subject: [PATCH 2/5] 
refactor(torch): Clean up pick_device Python bindings Moves the Python binding logic for pick_device out of a complex lambda and into a standalone function. - This wrapper handles the conversion from the C++ c10::DeviceType to the Python-facing string. - It strips the device index (e.g., 'cpu:0' -> 'cpu') for backwards compatibility. - Includes robust try/catch blocks to convert c10::Error into std::runtime_error for Python. --- metatomic-torch/src/register.cpp | 36 +++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/metatomic-torch/src/register.cpp b/metatomic-torch/src/register.cpp index 909162502..04894e011 100644 --- a/metatomic-torch/src/register.cpp +++ b/metatomic-torch/src/register.cpp @@ -6,6 +6,34 @@ using namespace metatomic_torch; +std::string pick_device_pywrapper( + const std::vector &model_devices, + const c10::optional &requested_device +) { + try { + torch::optional desired = torch::nullopt; + if (requested_device.has_value()) { + desired = requested_device.value(); + } + + c10::DeviceType devtype = metatomic_torch::pick_device(model_devices, desired); + + // Convert device type to string, stripping device index + torch::Device dev(devtype); + std::string s = dev.str(); + auto pos = s.find(':'); + if (pos != std::string::npos) { + return s.substr(0, pos); + } + return s; + + } catch (const c10::Error &e) { + throw std::runtime_error(std::string("pick_device failed: ") + e.what()); + } catch (const std::exception &e) { + throw std::runtime_error(std::string("pick_device failed: ") + e.what()); + } +} + TORCH_LIBRARY(metatomic, m) { // There is no way to access the docstrings from Python, so we don't bother // setting them to something useful here. @@ -219,7 +247,13 @@ TORCH_LIBRARY(metatomic, m) { // standalone functions m.def("version() -> str", version); - m.def("pick_device(str[] model_devices, str? requested_device = None) -> str", pick_device); + // Expose pick_device to Python. 
The C++ helper returns a c10::DeviceType; + // build a torch::Device from it (with default index policy) and return its + // string representation to Python (backwards-compatible). + m.def( + "pick_device(str[] model_devices, str? requested_device = None) -> str", + &pick_device_pywrapper + ); m.def("pick_output(str requested_output, Dict(str, __torch__.torch.classes.metatomic.ModelOutput) outputs, str? desired_variant = None) -> str", pick_output); m.def("read_model_metadata(str path) -> __torch__.torch.classes.metatomic.ModelMetadata", read_model_metadata); From 8b4f6a6ed612c7b6c7d92f38f877562da88f34fb Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Fri, 14 Nov 2025 15:22:34 +0100 Subject: [PATCH 3/5] Update metatomic-torch/src/register.cpp Co-authored-by: Guillaume Fraux --- metatomic-torch/src/register.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/metatomic-torch/src/register.cpp b/metatomic-torch/src/register.cpp index 04894e011..cd947b3bc 100644 --- a/metatomic-torch/src/register.cpp +++ b/metatomic-torch/src/register.cpp @@ -27,8 +27,6 @@ std::string pick_device_pywrapper( } return s; - } catch (const c10::Error &e) { - throw std::runtime_error(std::string("pick_device failed: ") + e.what()); } catch (const std::exception &e) { throw std::runtime_error(std::string("pick_device failed: ") + e.what()); } From e38c3e011349f4e5aa63bd45d0dd5296e67f5f14 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 12 Nov 2025 17:40:26 +0100 Subject: [PATCH 4/5] tst(torch): Update and add tests for pick_device - Updates C++ tests in 'misc.cpp' to align with the new c10::DeviceType return value. - Adds a new Python test suite 'test_pick_device.py'. - Tests basic selection, requested device selection, error handling for unavailable devices, and warnings for unrecognized devices. 
--- metatomic-torch/tests/misc.cpp | 47 +++++++++------- .../metatomic_torch/tests/test_pick_device.py | 53 +++++++++++++++++++ 2 files changed, 82 insertions(+), 18 deletions(-) create mode 100644 python/metatomic_torch/tests/test_pick_device.py diff --git a/metatomic-torch/tests/misc.cpp b/metatomic-torch/tests/misc.cpp index 4725bb638..55aae7d1c 100644 --- a/metatomic-torch/tests/misc.cpp +++ b/metatomic-torch/tests/misc.cpp @@ -29,20 +29,25 @@ TEST_CASE("Version macros") { TEST_CASE("Pick device") { // check that first entry in vector is picked, if no desired device is given - std::vector supported_devices = {"cpu", "cuda", "mps"}; - CHECK(metatomic_torch::pick_device(supported_devices) == "cpu"); - - // test that desired device is picked if available - std::vector desired_devices = {"cpu"}; + std::vector supported = {"cpu", "cuda", "mps"}; + + auto devtype = metatomic_torch::pick_device(supported); + // devtype must be one of the DeviceType enum values; at minimum CPU should be selectable + // Catch doesn't allow operator|| inside assertions in a portable way, so evaluate + // the boolean expression first and assert that. 
+ bool dev_ok = (devtype == c10::DeviceType::CPU) || + (devtype == c10::DeviceType::CUDA) || + (devtype == c10::DeviceType::MPS); + REQUIRE(dev_ok); + + // Test requested device selection when available if (torch::cuda::is_available()) { - desired_devices.emplace_back("cuda"); + auto dt_cuda = metatomic_torch::pick_device(supported, std::string("cuda")); + CHECK(dt_cuda == c10::DeviceType::CUDA); } if (torch::mps::is_available()) { - desired_devices.emplace_back("mps"); - } - - for (const auto& desired_device: desired_devices) { - CHECK(metatomic_torch::pick_device(supported_devices, desired_device) == desired_device); + auto dt_mps = metatomic_torch::pick_device(supported, std::string("mps")); + CHECK(dt_mps == c10::DeviceType::MPS); } // Check that warning is emitted: @@ -51,16 +56,22 @@ TEST_CASE("Pick device") { torch::WarningUtils::set_warnAlways(true); std::vector supported_devices_foo = {"cpu", "fooo"}; - CHECK(metatomic_torch::pick_device(supported_devices_foo) == "cpu"); + auto tdevtype = torch::Device(metatomic_torch::pick_device(supported_devices_foo)); + CHECK(tdevtype.str() == "cpu"); REQUIRE_FALSE(handler.messages.empty()); CHECK(handler.messages[0].find("'model_devices' contains an entry for unknown device") != std::string::npos); +} - // check exception raised - std::vector supported_devices_cuda = {"cuda"}; - CHECK_THROWS_WITH(metatomic_torch::pick_device(supported_devices_cuda, "cpu"), StartsWith("failed to find a valid device")); - - std::vector supported_devices_cpu = {"cpu"}; - CHECK_THROWS_WITH(metatomic_torch::pick_device(supported_devices_cpu, "cuda"), StartsWith("failed to find requested device")); +TEST_CASE("Pick device errors") { + // If the model only supports CUDA and CUDA is not available, throw + std::vector only_cuda = {"cuda"}; + if (!torch::cuda::is_available()) { + CHECK_THROWS_WITH(metatomic_torch::pick_device(only_cuda), StartsWith("failed to find a valid device")); + } else { + // If CUDA is available, requesting CUDA when not 
declared should raise + std::vector only_cpu = {"cpu"}; + CHECK_THROWS_WITH(metatomic_torch::pick_device(only_cpu, std::string("cuda")), StartsWith("failed to find requested device")); + } } diff --git a/python/metatomic_torch/tests/test_pick_device.py b/python/metatomic_torch/tests/test_pick_device.py new file mode 100644 index 000000000..fa57cc314 --- /dev/null +++ b/python/metatomic_torch/tests/test_pick_device.py @@ -0,0 +1,53 @@ +import pytest +import torch + +import metatomic.torch as mta + + +def test_pick_device_basic(): + # basic call should return a non-empty string describing a device + res = mta.pick_device(["cpu", "cuda", "mps"], None) + assert isinstance(res, str) + assert len(res) > 0 + # sanity: in typical environments we'll at least see "cpu" somewhere + assert "cpu" == res or "cuda" == res or "mps" == res + + +def test_pick_device_requested_if_available(): + # if CUDA is available, requesting it should yield a cuda device string + if torch.cuda.is_available(): + res = mta.pick_device(["cpu", "cuda"], "cuda") + assert isinstance(res, str) + assert "cuda" == res + # if MPS is available, requesting it should yield an mps device string + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + res = mta.pick_device(["cpu", "mps"], "mps") + assert isinstance(res, str) + assert "mps" == res + + +def test_pick_device_ignores_unrecognized_and_warns(capfd): + # Ensure unrecognized device names are ignored and a warning is emitted + res = mta.pick_device(["cpu", "fooo"], None) + assert isinstance(res, str) + # should pick cpu (ignore "fooo") + assert "cpu" == res + # at least one warning should have been produced about the unrecognized/ + # ignored entry + captured = capfd.readouterr() + err_output = captured.err.lower() + assert "fooo" in err_output + assert "ignoring" in err_output or "unknown" in err_output + + +def test_pick_device_error_on_unavailable_requested(): + # If a device is explicitly requested but isn't available/declared, + 
# the python wrapper is expected to raise a RuntimeError (propagated from C++). + only_cuda = ["cuda"] + if not torch.cuda.is_available(): + with pytest.raises(RuntimeError): + mta.pick_device(only_cuda, "cuda") + else: + # If CUDA is available, requesting a non-present device should raise + with pytest.raises(RuntimeError): + mta.pick_device(["cpu"], "cuda") From b09069cf2c737c920908f6ec5d0124014a466acc Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 12 Nov 2025 17:40:26 +0100 Subject: [PATCH 5/5] chore(CI): Update setuptools dependency in tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 86a3b3e84..a0cd60b96 100644 --- a/tox.ini +++ b/tox.ini @@ -31,7 +31,7 @@ test_options = --import-mode=append packaging_deps = - setuptools >= 75 + setuptools >= 77 packaging >= 23 cmake