diff --git a/process/solver.py b/process/solver.py index 5e9bec1813..31551f5398 100644 --- a/process/solver.py +++ b/process/solver.py @@ -6,6 +6,7 @@ from process.evaluators import Evaluators from abc import ABC, abstractmethod from typing import Optional, Union +import importlib from pyvmcon import ( AbstractProblem, Result, @@ -233,6 +234,28 @@ def get_solver(solver_name: str = "vmcon") -> _Solver: if solver_name == "vmcon": solver = Vmcon() else: - raise ValueError(f'Unrecognised solver name argument "{solver_name}"') + try: + solver = load_external_solver(solver_name) + except Exception as e: + raise ValueError( + f'Solver name is not an inbuilt PROCESS solver or recognised package "{solver_name}"' + ) from e return solver + + +def load_external_solver(package: str): + """Attempts to load a package of name `package`. + + If a package of the name is available, return the `__process_solver__` + attribute of that package or raise an `AttributeError`.""" + module = importlib.import_module(package) + + solver = getattr(module, "__process_solver__", None) + + if solver is None: + raise AttributeError( + f"Module {module.__name__} does not have a '__process_solver__' attribute." + ) + + return solver()