From 9c6cd4e23f27b99c0c1c0e94557662f1f2aba253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gerhard=20Br=C3=A4unlich?= Date: Wed, 7 May 2025 12:15:36 +0200 Subject: [PATCH 1/5] tests/test_problem_implementations.py: Remove __future__ import --- engibench/utils/slurm.py | 311 -------------------------- tests/test_problem_implementations.py | 2 - 2 files changed, 313 deletions(-) delete mode 100644 engibench/utils/slurm.py diff --git a/engibench/utils/slurm.py b/engibench/utils/slurm.py deleted file mode 100644 index 45dd5f83..00000000 --- a/engibench/utils/slurm.py +++ /dev/null @@ -1,311 +0,0 @@ -"""Slurm executor for parameter space discovery.""" - -from argparse import ArgumentParser -from collections.abc import Callable, Iterable, Sequence -from dataclasses import asdict -from dataclasses import dataclass -from dataclasses import field -import importlib -import itertools -import os -import pickle -import shutil -import subprocess -import sys -import tempfile -from typing import Any, Generic, TypeVar - -from numpy import typing as npt - -from engibench.core import OptiStep -from engibench.core import Problem - - -@dataclass -class Args: - """Collection of arguments passed to `Problem()`, `Problem.simulate()` and `DesignType()`.""" - - problem_args: dict[str, Any] = field(default_factory=dict) - """Keyword arguments to be passed to :class:`engibench.core.Problem()`.""" - simulate_args: dict[str, Any] = field(default_factory=dict) - """Keyword arguments to be passed to :meth:`engibench.core.Problem.simulate()`.""" - optimize_args: dict[str, Any] = field(default_factory=dict) - """Keyword arguments to be passed to :meth:`engibench.core.Problem.optimize()`.""" - design_args: dict[str, Any] = field(default_factory=dict) - """Keyword arguments to be passed to `DesignType()` or - the `design_factory` argument of :func:`submit`.""" - - -def merge_args(a: Args, b: Args) -> Args: - """Merge arguments from `a` with `b`.""" - return Args( - problem_args={**a.problem_args, **b.problem_args}, - simulate_args={**a.simulate_args, **b.simulate_args}, - design_args={**a.design_args, **b.design_args}, - ) - - -DesignType = TypeVar("DesignType") - - -@dataclass -class Job(Generic[DesignType]): - """Representation of a single slurm job.""" - - job_type: str - problem: Callable[..., Problem[DesignType]] - design_factory: Callable[..., DesignType] | None - args: Args - - def serialize(self) -> dict[str, Any]: - """Serialize a job object for an other python process.""" - return { - "job_type": self.job_type, - "problem": serialize_callable(self.problem), - "args": asdict(self.args), - "design_factory": serialize_callable(self.design_factory) if self.design_factory is not None else None, - } - - @classmethod - def deserialize(cls, serialized_job: dict[str, Any]) -> "Job": - """Deserialize a job object from an other python process.""" - design_factory = serialized_job["design_factory"] - return cls( - job_type=serialized_job["job_type"], - problem=deserialize_callable(serialized_job["problem"]), - args=Args(**serialized_job["args"]), - design_factory=deserialize_callable(design_factory) if design_factory is not None else None, - ) - - def run(self) -> tuple[DesignType, list[OptiStep]] | npt.NDArray[Any] | Any: - """Run the optimization defined by the job.""" - problem = self.problem(config=self.args.problem_args) - design = self.args.design_args.get("design", None) - if self.job_type == "simulate": - return problem.simulate(design=design, config=self.args.simulate_args) # type: ignore # noqa: PGH003 - if self.job_type == "optimize": - return problem.optimize(starting_point=design, config=self.args.optimize_args) # type: ignore # noqa: PGH003 - if self.job_type == "render": - return problem.render(design=design, config=self.args.simulate_args) # type: ignore # noqa: PGH003 - msg = f"Unknown job type: {self.job_type}" - raise ValueError(msg) - - -def design_type(t: type[Problem] | Callable[..., Problem]) -> type[Any]: - """Deduce the design type corresponding to the given `Problem` type.""" - if not isinstance(t, type): - msg = f"Could not deduce the design type corresponding to `{t.__name__}`: The object is not a type" - raise TypeError(msg) from None - if not issubclass(t, Problem): - msg = f"Could not deduce the design type corresponding to `{t.__name__}`: The object is not a Problem type" - raise TypeError(msg) from None - try: - (design_type,) = t.__orig_bases__[0].__args__ # type: ignore[attr-defined] - except AttributeError: - msg = f"Could not deduce the design type corresponding to `{t.__name__}`: The Problem class does not specify its type for its design" - raise ValueError(msg) from None - return design_type - - -SerializedType = tuple[str, str, str] - - -def serialize_callable(t: Callable[..., Any] | type[Any]) -> SerializedType: - """Serialize a callable (problem type supported) so it can be imported by a different python process.""" - top_level_module, _ = t.__module__.split(".", 1) - path = sys.modules[top_level_module].__file__ - if path is None: - msg = "Got a module without path" - raise RuntimeError(msg) - if os.path.basename(path) == "__init__.py": - path = os.path.dirname(path) - path = os.path.dirname(path) - return (path, t.__module__, t.__name__) - - -def deserialize_callable(serialized_type: SerializedType) -> Callable[..., Any] | type[Any]: - """Deserialize information on how to load a callable serialized by a different python process.""" - path, module_name, problem_name = serialized_type - sys.path.append(path) - module = importlib.import_module(module_name) - return getattr(module, problem_name) - - -@dataclass -class SlurmConfig: - """Collection of slurm parameters passed to sbatch.""" - - sbatch_executable: str = "sbatch" - """Path to the sbatch executable if not in PATH""" - log_dir: str | None = None - """Path of the log directory""" - name: str | None = None - """Optional name for the jobs""" - account: str | None = None - """Slurm account to use""" - runtime: str | None = None - """Optional runtime in the format ``hh:mm:ss``. """ - constraint: str | None = None - """Optional constraint""" - mem_per_cpu: str | None = None - """E.g. "4G".""" - mem: str | None = None - """E.g. "4G".""" - nodes: int | None = None - ntasks: int | None = None - cpus_per_task: int | None = None - extra_args: Sequence[str] = () - """Extra arguments passed to sbatch.""" - - -def submit( - job_type: str, - problem: type[Problem], - parameter_space: list[Args], - design_factory: Callable[..., DesignType] | None = None, - config: SlurmConfig | None = None, -) -> None: - """Submit a job array for a parameter discovery to slurm. - - - :attr:`job_type` - The type of the job to be submitted: 'simulate', 'optimize', or 'render'. - - :attr:`problem` - The problem type for which the simulation should be run. - - :attr:`parameter_space` - One :class:`Args` instance per simulation run to be submitted. - - :attr:`design_factory` - If not None, pass `Args.design_args` to `design_factory` instead of `DesignType()`. - - :attr:`design_factory` - Custom arguments passed to `sbatch`. - """ - if config is None: - config = SlurmConfig() - - log_file = os.path.join(config.log_dir, "%j.log") if config.log_dir is not None else None - if config.log_dir is not None: - os.makedirs(config.log_dir, exist_ok=True) - - # Dump parameter space: - param_dir = tempfile.mkdtemp(dir=os.environ.get("SCRATCH")) - for job_no, args in enumerate(parameter_space, start=1): - job = Job(job_type, problem=problem, design_factory=design_factory, args=args) - dump_job(job, param_dir, job_no) - - optional_args = ( - ("--output", log_file), - ("--comment", config.name), - ("--time", config.runtime), - ("--constraint", config.constraint), - ("--mem-per-cpu", config.mem_per_cpu), - ("--mem", config.mem), - ("--nodes", config.nodes), - ("--ntasks", config.ntasks), - ("--cpus-per-task", config.cpus_per_task), - ) - cmd = [ - config.sbatch_executable, - "--parsable", - "--export=ALL", - f"--array=1-{len(parameter_space)}%1000", - *(f"{arg}={value}" for arg, value in optional_args if value is not None), - *config.extra_args, - "--wrap", - f"{sys.executable} {__file__} run {param_dir}", - ] - - job_id = run_sbatch(cmd) - cleanup_cmd = [ - config.sbatch_executable, - "--parsable", - f"--dependency=afterany:{job_id}", - "--export=ALL", - "--wait", - "--wrap", - f"{sys.executable} {__file__} cleanup {param_dir}", - ] - run_sbatch(cleanup_cmd) - - -def dump_job(job: Job, folder: str, index: int) -> None: - """Dump a job object corresponding to the item of a slurm job array with specified index to disk.""" - parameter_file = os.path.join(folder, f"parameter_space_{index}.pkl") - with open(parameter_file, "wb") as stream: - pickle.dump(job.serialize(), stream) - - -def load_job(folder: str, index: int) -> Job: - """Load a job object corresponding to the item of a slurm job array with specified index from disk.""" - parameter_file = os.path.join(folder, f"parameter_space_{index}.pkl") - with open(parameter_file, "rb") as stream: - return Job.deserialize(pickle.load(stream)) - - -def load_job_args(folder: str) -> Iterable[tuple[int, dict[str, Any]]]: - """Load the enumerated argument parts of all jobs of a slurm job array from disk.""" - for index in itertools.count(1): - parameter_file = os.path.join(folder, f"parameter_space_{index}.pkl") - try: - with open(parameter_file, "rb") as stream: - yield index, pickle.load(stream)["args"] - except FileNotFoundError: - break - - -def run_sbatch(cmd: list[str]) -> str: - """Execute sbatch with the given arguments, returning the job id of the submitted job.""" - try: - proc = subprocess.run(cmd, shell=False, check=True, capture_output=True) - except subprocess.CalledProcessError as e: - msg = f"sbatch job submission failed: {e.stderr.decode()}" - raise RuntimeError(msg) from e - return proc.stdout.decode().strip() - - -def slurm_job_entrypoint() -> None: - """Entrypoint of a single slurm job. - - The "run" mode is for the job array items which run the simulation: - ```sh - python slurm.py run - ``` - this mode will read from the environment variable `SLURM_ARRAY_TASK_ID` and will load the corresponding simulation parameters. - The "cleanup" mode combines the results of all simulations to one file. - ```sh - python slurm.py cleanup - ``` - """ - - def run(work_dir: str) -> None: - index = int(os.environ["SLURM_ARRAY_TASK_ID"]) - job = load_job(work_dir, index) - results = job.run() - result_file = os.path.join(work_dir, f"{index}.pkl") - with open(result_file, "wb") as stream: - pickle.dump(results, stream) - - def cleanup(work_dir: str) -> None: - results = [] - for index, result_args in load_job_args(work_dir): - result_file = os.path.join(work_dir, f"{index}.pkl") - if not os.path.exists(result_file): - print(f"Warning: Result file {result_file} does not exist. Skipping.") - continue - try: - with open(result_file, "rb") as stream: - result = pickle.load(stream) - results.append({"results": result, **result_args}) - except Exception as e: # noqa: BLE001 - print(f"Error loading {result_file}: {e}. Skipping.") - continue - - print(os.getcwd()) - with open("results.pkl", "wb") as stream: - pickle.dump(results, stream) - shutil.rmtree(work_dir) - - modes = {f.__name__: f for f in (run, cleanup)} - parser = ArgumentParser() - parser.add_argument("mode", choices=list(modes.keys()), help="either run or cleanup") - parser.add_argument("work_dir", help="Path to the work directory") - args = parser.parse_args() - mode = modes[args.mode] - mode(work_dir=args.work_dir) - - -if __name__ == "__main__": - slurm_job_entrypoint() diff --git a/tests/test_problem_implementations.py b/tests/test_problem_implementations.py index 1488526e..b51c6ab0 100644 --- a/tests/test_problem_implementations.py +++ b/tests/test_problem_implementations.py @@ -1,7 +1,5 @@ """This file contains tests making sure the implemented problems respect the API.""" -from __future__ import annotations - import inspect from typing import get_args, get_origin From c39de05aa9b79e902e7f6bbee79fe3b8440e2861 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gerhard=20Br=C3=A4unlich?= Date: Wed, 7 May 2025 13:10:53 +0200 Subject: [PATCH 2/5] utils.slurm: Rewrite the module --- engibench/utils/slurm/__init__.py | 257 ++++++++++++++++++++++++++++++ engibench/utils/slurm/run_job.py | 112 +++++++++++++ tests/tools/fake_sbatch.py | 24 ++- tests/utils/test_slurm.py | 247 +++++++++++++++++----------- 4 files changed, 539 insertions(+), 101 deletions(-) create mode 100644 engibench/utils/slurm/__init__.py create mode 100644 engibench/utils/slurm/run_job.py diff --git a/engibench/utils/slurm/__init__.py b/engibench/utils/slurm/__init__.py new file mode 100644 index 00000000..5bb7b295 --- /dev/null +++ b/engibench/utils/slurm/__init__.py @@ -0,0 +1,257 @@ +"""Slurm executor for parameter space discovery.""" + +from collections.abc import Callable, Iterable, Sequence +from dataclasses import asdict +from dataclasses import dataclass +from dataclasses import field +import itertools +import os +import pickle +import shutil +import subprocess +import sys +import tempfile +from typing import Any, Generic, TypeVar + + +class JobError(Exception): + """User error happening during execution of a slurm job.""" + + def __init__(self, origin: Exception, context: str, job_args: dict[str, Any]) -> None: + self.origin = origin + self.context = context + self.job_args = job_args + + def __str__(self) -> str: + args = f"\nargs = {self.job_args}" if self.job_args else "" + return f"""💥 JobError({self.context}): +{self.origin}{args} +""" + + +@dataclass +class SlurmConfig: + """Collection of slurm parameters passed to sbatch.""" + + sbatch_executable: str = "sbatch" + """Path to the sbatch executable if not in PATH""" + log_dir: str | None = None + """Path of the log directory""" + name: str | None = None + """Optional name for the jobs""" + account: str | None = None + """Slurm account to use""" + runtime: str | None = None + """Optional runtime in the format ``hh:mm:ss``. """ + constraint: str | None = None + """Optional constraint""" + mem_per_cpu: str | None = None + """E.g. "4G".""" + mem: str | None = None + """E.g. "4G".""" + nodes: int | None = None + ntasks: int | None = None + cpus_per_task: int | None = None + extra_args: Sequence[str] = () + """Extra arguments passed to sbatch.""" + + +R = TypeVar("R") +S = TypeVar("S") + +WORKER = os.path.join(os.path.dirname(__file__), "run_job.py") + + +class SubmittedJobArray(Generic[R]): + """Representation for a submitted slurm job array.""" + + def __init__(self, job_id: str, work_dir: str, n_jobs: int) -> None: + self.job_id = job_id + self.work_dir = work_dir + self.n_jobs = n_jobs + + def reduce( + self, f_reduce: Callable[[list[R]], S], slurm_args: SlurmConfig | None = None, size_limit: int | None = 10000000 + ) -> S: + """Reduce the results of a slurm job array. + + The return values of the callable `f` passed to :function:`sbatch_map` will be collected into a list and passed as + the single argument to `f_reduce`. + After running `f_reduce` as a slurm job, its return value will be passed back and will be the return value of this method. + + To prevent larger workloads running on a login node, this function will raise an exception if the resulting list in pickled + form takes more than `size_limit` bytes (recommendation: 10MB). + Only increase / set to 0 if you want to annoy the HPC team 😈. + - :attr:`f_reduce` - The callable which performs the post processing on the list of return values for each job. + - :attr:`slurm_args` - Arguments passed to `sbatch`. + - :attr:`size_limit` - Upper limit for the allowed size of the post processed data in pickled form. + """ + with open(os.path.join(self.work_dir, "jobs", "reduce.pkl"), "wb") as stream: + pickle.dump(MemorizeModule(f_reduce), stream) + + # Submit reduce job: + cmd = " ".join((sys.executable, WORKER, "reduce", self.work_dir, str(self.n_jobs))) + run_sbatch(cmd, slurm_args=slurm_args or SlurmConfig(), job_dependency=self.job_id, wait=True) + + # Try to load and return the reduced result if it is not too large: + reduced_path = os.path.join(self.work_dir, "reduced.pkl") + if size_limit is not None and os.path.getsize(reduced_path) > size_limit: + raise RuntimeError(f"""Pickled data is too large to be processed by a login node. +Please submit a separate slurm job for postprocessing. +The pickled data is still accessible here: {reduced_path} +""") + with open(reduced_path, "rb") as stream: + result = pickle.load(stream) + shutil.rmtree(self.work_dir) + return result + + def save(self, out: str, slurm_args: SlurmConfig | None = None) -> None: + """Save the collected results of a slurm job array. + + The return values of the callable `f` passed to :function:`sbatch_map` will be collected into a list and saved to disk. + + - :attr:`out` - Path to store the pickle archive. + """ + cmd = " ".join((sys.executable, WORKER, "save", "-o", out, self.work_dir, str(self.n_jobs))) + run_sbatch(cmd, slurm_args=slurm_args, job_dependency=self.job_id, wait=True) + + +def sbatch_map( + f: Callable[..., R], + args: Iterable[dict[str, Any]], + slurm_args: SlurmConfig | None = None, + group_size: int = 1, +) -> SubmittedJobArray: + """Submit a job array for a parameter discovery to slurm. + + The returned :class:`SubmittedJobArray` object can be used to + start a post processing job which will run after all jobs of the array are done. + + - :attr:`f` - The callable which will be applied to each item in `args`. + - :attr:`args` - Array of keyword arguments which will be passed to `f`. + - :attr:`slurm_args` - Arguments passed to `sbatch`. + - :attr:`group_size` - Sequentially process a number of `group_size` jobs in one slurm job. + + Details: The individual jobs of the jobarray will be processed in + individual python instances running the `engibench.utils.slurm.run_job` + standalone script. + """ + # Dump jobs: + work_dir = tempfile.mkdtemp(dir=os.environ.get("SCRATCH")) + os.makedirs(os.path.join(work_dir, "jobs")) + os.makedirs(os.path.join(work_dir, "results")) + n_jobs = 0 + with open(os.path.join(work_dir, "jobs", "map_callback.pkl"), "wb") as stream: + pickle.dump(MemorizeModule(f), stream) + for job_no, arg in enumerate(args): + with open(os.path.join(work_dir, "jobs", f"{job_no}.pkl"), "wb") as stream: + pickle.dump(MemorizeModule(arg), stream) + n_jobs += 1 + + map_cmd = f"{sys.executable} {WORKER} run {work_dir} {n_jobs}" + job_id = run_sbatch( + cmd=map_cmd, + slurm_args=slurm_args or SlurmConfig(), + array_len=n_jobs // group_size + (1 if n_jobs % group_size else 0), + ) + return SubmittedJobArray(job_id, work_dir, n_jobs) + + +def run_sbatch( + cmd: str, + slurm_args: SlurmConfig | None = None, + array_len: int | None = None, + job_dependency: str | None = None, + *, + wait: bool = False, +) -> str: + """Execute sbatch with the given arguments, returning the job id of the submitted job.""" + if slurm_args is None: + slurm_args = SlurmConfig() + log_file = os.path.join(slurm_args.log_dir, "%j.log") if slurm_args.log_dir is not None else None + if slurm_args.log_dir is not None: + os.makedirs(slurm_args.log_dir, exist_ok=True) + + optional_args = ( + ("--output", log_file), + ("--comment", slurm_args.name), + ("--time", slurm_args.runtime), + ("--constraint", slurm_args.constraint), + ("--mem-per-cpu", slurm_args.mem_per_cpu), + ("--mem", slurm_args.mem), + ("--nodes", slurm_args.nodes), + ("--ntasks", slurm_args.ntasks), + ("--cpus-per-task", slurm_args.cpus_per_task), + ("--array", f"1-{array_len}%1000" if array_len is not None else None), + ("--dependency", f"afterany:{job_dependency}" if job_dependency is not None else None), + ) + sbatch_cmd = [ + slurm_args.sbatch_executable, + "--parsable", + "--export=ALL", + *(f"{arg}={value}" for arg, value in optional_args if value is not None), + *slurm_args.extra_args, + *(("--wait",) if wait else ()), + "--wrap", + cmd, + ] + try: + proc = subprocess.run(sbatch_cmd, shell=False, check=True, capture_output=True) + except subprocess.CalledProcessError as e: + msg = f"sbatch job submission failed: {e.stderr.decode()}" + raise RuntimeError(msg) from e + return proc.stdout.decode().strip() + + +def load_results() -> list[Any]: + """Load the pickled results produced by :func:`sbatch_map`.""" + with open("results.pkl", "rb") as stream: + return pickle.load(stream) + + +def collect_jobs(work_dir: str, n_jobs: int) -> list[Any]: + """Collect all results of a slurm job array into a list.""" + + def load_result(path: str) -> Any: + try: + with open(path, "rb") as stream: + return pickle.load(stream) + except Exception as e: # noqa: BLE001 + return JobError(e, "Collect job", {}) + + return [load_result(os.path.join(work_dir, "results", f"{index}.pkl")) for index in range(n_jobs)] + + +class MemorizeModule: + """Wrapper which allows unpickling the wrapped object even when its module is not in PYTHONPATH. + + Use it like `pickle.dumps(MemorizeModule(obj))`. + The resulting pickle archive will directly unpickle to obj. + """ + + def __init__(self, obj: Any) -> None: + self.obj = obj + + @staticmethod + def _reconstruct(reduced_module: str | None, pickled_obj: bytes) -> Any: + if reduced_module is not None and reduced_module not in sys.path: + sys.path.append(reduced_module) + return pickle.loads(pickled_obj) + + def __reduce__(self) -> tuple[Callable[..., Any], tuple[str | None, bytes]]: + pickled_obj = pickle.dumps(self.obj) + return (self._reconstruct, (module_path(self.obj), pickled_obj)) + + +def module_path(obj: Any) -> str | None: + """Return the path of the toplevel module of the module containing `obj`.""" + if not hasattr(obj, "__module__"): + return None + top_level_module, _ = obj.__module__.split(".", 1) + path = sys.modules[top_level_module].__file__ + if path is None: + msg = "Got a module without path" + raise RuntimeError(msg) + if os.path.basename(path) == "__init__.py": + path = os.path.dirname(path) + return os.path.dirname(path) diff --git a/engibench/utils/slurm/run_job.py b/engibench/utils/slurm/run_job.py new file mode 100644 index 00000000..bf544ff4 --- /dev/null +++ b/engibench/utils/slurm/run_job.py @@ -0,0 +1,112 @@ +"""Slurm job worker.""" + +from argparse import ArgumentParser +import os +import pickle +import shutil + +from engibench.utils.slurm import collect_jobs +from engibench.utils.slurm import JobError +from engibench.utils.slurm import MemorizeModule + + +def map_job_group(work_dir: str, n_jobs: int) -> None: + """Process a job or group of job of a slurm job array. + + This is the "map" step of "map - reduce". + """ + start = int(os.environ["SLURM_ARRAY_TASK_MIN"]) + stop = int(os.environ["SLURM_ARRAY_TASK_MAX"]) + current = int(os.environ["SLURM_ARRAY_TASK_ID"]) - start + array_size = stop - start + group_size = n_jobs // array_size + (1 if n_jobs % array_size else 0) + try: + with open(os.path.join(work_dir, "jobs", "map_callback.pkl"), "rb") as in_stream: + map_callback = pickle.load(in_stream) + except Exception as e: # noqa: BLE001 + exception = e + + def map_callback(**_kwargs) -> None: + raise exception + + # Run `group_size` jobs as sub jobs of current job (usecase: many small jobs): + for index in range(current * group_size, max((current + 1) * group_size, n_jobs)): + result_path = os.path.join(work_dir, "results", f"{index}.pkl") + try: + with open(os.path.join(work_dir, "jobs", f"{index}.pkl"), "rb") as stream: + args = pickle.load(stream) + except Exception as e: # noqa: BLE001 + with open(result_path, "wb") as out_stream: + pickle.dump(JobError(e, "Unpickle job array item", {}), out_stream) + continue + try: + result = map_callback(**args) + with open(result_path, "wb") as out_stream: + pickle.dump(MemorizeModule(result), out_stream) + except Exception as e: # noqa: BLE001 + with open(result_path, "wb") as out_stream: + pickle.dump(JobError(e, "Run job array item", args), out_stream) + + +def reduce_job_results(work_dir: str, n_jobs: int) -> None: + """Collect all results or errors from job array jobs, passing to a reduce callback.""" + with open(os.path.join(work_dir, "jobs", "reduce.pkl"), "rb") as in_stream: + reduce_callback = pickle.load(in_stream) + + results = collect_jobs(work_dir, n_jobs) + reduced = reduce_callback(results) + with open(os.path.join(work_dir, "reduced.pkl"), "wb") as out_stream: + pickle.dump(MemorizeModule(reduced), out_stream) + + +def save(work_dir: str, n_jobs: int, out: str) -> None: + """Collect all results or errors from job array jobs and save as a pickled list to disk.""" + results = collect_jobs(work_dir, n_jobs) + + if not any(isinstance(r, JobError) for r in results): + shutil.rmtree(work_dir) + with open(out, "wb") as out_stream: + pickle.dump(results, out_stream) + + +def cli() -> None: + """Entrypoint of a single slurm job. + + The "run" mode is for the job array items which run the simulation: + ```sh + python slurm.py run + ``` + this mode will read from the environment variable `SLURM_ARRAY_TASK_ID` and will load the corresponding simulation parameters. + The "cleanup" mode combines the results of all simulations to one file. + ```sh + python slurm.py reduce + ``` + """ + parser = ArgumentParser() + subparsers = parser.add_subparsers( + dest="subcmd", + title="List of sub-commands", + description="For an overview of action specific parameters, use %(prog)s --help", + help="Sub-command help", + metavar="", + ) + subparser = subparsers.add_parser("run", help=map_job_group.__doc__) + subparser.set_defaults(subcmd=map_job_group) + subparser.add_argument("work_dir", help="Path to the work directory") + subparser.add_argument("n_jobs", type=int, help="Total number of jobs") + subparser = subparsers.add_parser("reduce", help=reduce_job_results.__doc__) + subparser.set_defaults(subcmd=reduce_job_results) + subparser.add_argument("work_dir", help="Path to the work directory") + subparser.add_argument("n_jobs", type=int, help="Total number of jobs") + subparser = subparsers.add_parser("save", help=save.__doc__) + subparser.set_defaults(subcmd=save) + subparser.add_argument("work_dir", help="Path to the work directory") + subparser.add_argument("n_jobs", type=int, help="Total number of jobs") + subparser.add_argument("-o", dest="out", default=None, help="Output path for the pickle archive containing the results") + args = vars(parser.parse_args()) + subcmd = args.pop("subcmd") + subcmd(**args) + + +if __name__ == "__main__": + cli() diff --git a/tests/tools/fake_sbatch.py b/tests/tools/fake_sbatch.py index 5445de06..6fa6d4b8 100755 --- a/tests/tools/fake_sbatch.py +++ b/tests/tools/fake_sbatch.py @@ -5,10 +5,15 @@ import subprocess -def parse_array_range(s: str) -> slice: - """Parse a string like 1-3.""" +def parse_array_range(s: str) -> tuple[slice, int | None]: + """Parse a string like 1-3 or 1-3%1000.""" + if "%" in s: + s, max_jobs_raw = s.split("%", 1) + max_jobs = int(max_jobs_raw) + else: + max_jobs = None start, stop = s.split("-") - return slice(int(start), int(stop)) + return slice(int(start), int(stop)), max_jobs def parse_cmd(s: str) -> list[str]: @@ -22,8 +27,17 @@ def main() -> None: parser.add_argument("--array", type=parse_array_range, default=None) args, _ = parser.parse_known_args() if args.array is not None: - for index in range(args.array.start, args.array.stop + 1): - subprocess.run(args.wrap, check=True, env={"SLURM_ARRAY_TASK_ID": str(index)}) + arr, _max_jobs = args.array + for index in range(arr.start, arr.stop + 1): + subprocess.run( + args.wrap, + check=True, + env={ + "SLURM_ARRAY_TASK_ID": str(index), + "SLURM_ARRAY_TASK_MIN": str(arr.start), + "SLURM_ARRAY_TASK_MAX": str(arr.stop), + }, + ) else: subprocess.run(args.wrap, check=True) # Print a fake slurm job id: diff --git a/tests/utils/test_slurm.py b/tests/utils/test_slurm.py index 0588e5ef..81a1ebd9 100644 --- a/tests/utils/test_slurm.py +++ b/tests/utils/test_slurm.py @@ -1,96 +1,151 @@ -# ruff: noqa: ERA001 -# TODO(https://github.com/IDEALLab/EngiBench/issues/107): Rework slurm utils - -# from __future__ import annotations - -# import os -# import pickle -# import subprocess -# from typing import Any - -# import numpy as np -# import numpy.typing as npt -# import pytest - -# from engibench.core import Problem -# from engibench.utils import slurm - - -# class FakeDesign: -# """It's only a model.""" - -# def __init__(self, design_id: int) -> None: -# self.design_id = design_id - - -# class FakeProblem(Problem[FakeDesign]): -# def __init__(self, problem_id: int, *, some_arg: bool) -> None: -# self.problem_id = problem_id -# self.some_arg = some_arg - -# def simulate(self, design: FakeDesign, config: dict[str, Any] | None = None, **kwargs) -> npt.NDArray: -# offset = (config or {})["offset"] -# return np.array([design.design_id + offset]) - - -# FAKE_SBATCH = os.path.join( -# os.path.dirname(__file__), -# "..", -# "tools", -# "fake_sbatch.py", -# ) - - -# def find_real_sbatch() -> list[str]: -# try: -# if ( -# subprocess.run( -# ["sbatch", "--help"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=False -# ).returncode -# == 0 -# ): -# return ["sbatch"] -# except FileNotFoundError: -# pass -# return [] - - -# @pytest.mark.parametrize("sbatch_exec", [FAKE_SBATCH, *find_real_sbatch()]) -# def test_run_slurm(sbatch_exec: str) -> None: -# """Test if a fake slurm can process FakeProblem.""" - -# static_args = slurm.Args(simulate_args={"config": {"offset": 10}}, problem_args={"some_arg": True}) -# parameter_space = [ -# slurm.Args(problem_args={"problem_id": 1}, design_args={"design_id": -1}), -# slurm.Args(problem_args={"problem_id": 2}, design_args={"design_id": -2}), -# slurm.Args(problem_args={"problem_id": 3}, design_args={"design_id": -3}), -# ] -# slurm.submit( -# problem=FakeProblem, -# static_args=static_args, -# parameter_space=parameter_space, -# config=slurm.SlurmConfig(sbatch_executable=sbatch_exec), -# ) -# with open("results.pkl", "rb") as stream: -# results = pickle.load(stream) -# os.remove("results.pkl") -# assert results == [ -# { -# "problem_args": {"some_arg": True, "problem_id": 1}, -# "simulate_args": {"config": {"offset": 10}}, -# "design_args": {"design_id": -1}, -# "results": np.array([9]), -# }, -# { -# "problem_args": {"some_arg": True, "problem_id": 2}, -# "simulate_args": {"config": {"offset": 10}}, -# "design_args": {"design_id": -2}, -# "results": np.array([8]), -# }, -# { -# "problem_args": {"some_arg": True, "problem_id": 3}, -# "simulate_args": {"config": {"offset": 10}}, -# "design_args": {"design_id": -3}, -# "results": np.array([7]), -# }, -# ] +import os +import pickle +import subprocess +from typing import Any + +import numpy as np +from numpy.typing import NDArray +import pytest + +from engibench.core import OptiStep +from engibench.core import Problem +from engibench.utils import slurm + + +def test_pickle_callable_works_for_a_function() -> None: + serialized = pickle.dumps(slurm.MemorizeModule(a_function)) + deserialized = pickle.loads(serialized) + assert deserialized() + + +def test_pickle_callable_works_for_a_class() -> None: + serialized = pickle.dumps(slurm.MemorizeModule(AClass)) + deserialized = pickle.loads(serialized) + assert deserialized(1.0).x == 1.0 + + +def test_pickle_callable_works_for_a_method() -> None: + a_method = AClass(1.0).a_method + serialized = pickle.dumps(slurm.MemorizeModule(a_method)) + deserialized = pickle.loads(serialized) + assert deserialized() == 1.0 + + +class FakeDesign: + """It's only a model.""" + + def __init__(self, design_id: int) -> None: + self.design_id = design_id + + def __eq__(self, other: object) -> bool: + return isinstance(other, type(self)) and self.design_id == other.design_id + + +class FakePlot: + """It's only a model.""" + + def __init__(self, design: FakeDesign) -> None: + self.design = design + + def __eq__(self, other: object) -> bool: + return isinstance(other, type(self)) and self.design == other.design + + +class FakeProblem(Problem[FakeDesign]): + def __init__(self, problem_id: int, *, some_arg: bool) -> None: + self.problem_id = problem_id + self.some_arg = some_arg + + def simulate(self, design: FakeDesign, config: dict[str, Any] | None = None, **kwargs) -> NDArray[np.float64]: + offset = (config or {})["offset"] + return np.array([design.design_id + offset], dtype=np.float64) + + def optimize( + self, starting_point: FakeDesign, config: dict[str, Any] | None = None + ) -> tuple[FakeDesign, list[OptiStep]]: + return starting_point, [] + + def render(self, design: FakeDesign, *, open_window: bool = False) -> FakePlot: + return FakePlot(design) + + +FAKE_SBATCH = os.path.join( + os.path.dirname(__file__), + "..", + "tools", + "fake_sbatch.py", +) + + +def find_real_sbatch() -> list[str]: + try: + if ( + subprocess.run( + ["sbatch", "--help"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=False + ).returncode + == 0 + ): + return ["sbatch"] + except FileNotFoundError: + pass + return [] + + +def job_render(problem_id: int, design_id: int) -> FakePlot: + p = FakeProblem(problem_id, some_arg=True) + design, _ = p.optimize(FakeDesign(design_id)) + return p.render(design) + + +def job_simulate(problem_id: int, design_id: int) -> NDArray[np.float64]: + p = FakeProblem(problem_id, some_arg=True) + design, _ = p.optimize(FakeDesign(design_id)) + return p.simulate(design, {"offset": 10}) + + +def a_function() -> bool: + return True + + +class AClass: + def __init__(self, x: float) -> None: + self.x = x + + def a_method(self) -> float: + return self.x + + +@pytest.mark.parametrize("sbatch_exec", [FAKE_SBATCH, *find_real_sbatch()]) +def test_sbatch_map_save(sbatch_exec: str) -> None: + """Test if a fake slurm can process FakeProblem.""" + + slurm_args = slurm.SlurmConfig(sbatch_executable=sbatch_exec) + slurm.sbatch_map( + job_render, + args=[{"problem_id": 1, "design_id": -1}, {"problem_id": 2, "design_id": -2}, {"problem_id": 3, "design_id": -3}], + slurm_args=slurm_args, + ).save("results.pkl", slurm_args=slurm_args) + results = slurm.load_results() + os.remove("results.pkl") + for result in results: + if isinstance(result, slurm.JobError): + raise result + assert results == [FakePlot(FakeDesign(-1)), FakePlot(FakeDesign(-2)), FakePlot(FakeDesign(-3))] + + +def f_reduce(results: list[NDArray[np.float64]]) -> float: + return sum(v[0] for v in results) + + +@pytest.mark.parametrize("sbatch_exec", [FAKE_SBATCH, *find_real_sbatch()]) +def test_sbatch_map_reduce(sbatch_exec: str) -> None: + """Test if a fake slurm can process FakeProblem.""" + + slurm_args = slurm.SlurmConfig(sbatch_executable=sbatch_exec) + result = slurm.sbatch_map( + job_simulate, + args=[{"problem_id": 1, "design_id": -1}, {"problem_id": 2, "design_id": -2}, {"problem_id": 3, "design_id": -3}], + slurm_args=slurm_args, + ).reduce(f_reduce, slurm_args=slurm_args) + expected_result = 24.0 + assert result == expected_result From 5bb5622892cb9fa574919bda1cd95a6cea30b728 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gerhard=20Br=C3=A4unlich?= Date: Thu, 10 Jul 2025 10:55:06 +0200 Subject: [PATCH 3/5] tests/utils/test_slurm.py: Apply ruff lint eq-without-hash (PLW1641) --- tests/utils/test_slurm.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/utils/test_slurm.py b/tests/utils/test_slurm.py index 81a1ebd9..b4eeb088 100644 --- a/tests/utils/test_slurm.py +++ b/tests/utils/test_slurm.py @@ -40,6 +40,9 @@ def __init__(self, design_id: int) -> None: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.design_id == other.design_id + def __hash__(self): + return hash(self.name) + class FakePlot: """It's only a model.""" @@ -50,6 +53,9 @@ def __init__(self, design: FakeDesign) -> None: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.design == other.design + def __hash__(self): + return hash(self.name) + class FakeProblem(Problem[FakeDesign]): def __init__(self, problem_id: int, *, some_arg: bool) -> None: From 1ca9792773f92041567f19ec98d948f690c9c8c1 Mon Sep 17 00:00:00 2001 From: fgvangessel-umd Date: Thu, 24 Jul 2025 10:34:48 +0200 Subject: [PATCH 4/5] utils.slurm.run_job: Fix array size and group end index --- engibench/utils/slurm/run_job.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/engibench/utils/slurm/run_job.py b/engibench/utils/slurm/run_job.py index bf544ff4..0fa390de 100644 --- a/engibench/utils/slurm/run_job.py +++ b/engibench/utils/slurm/run_job.py @@ -18,7 +18,7 @@ def map_job_group(work_dir: str, n_jobs: int) -> None: start = int(os.environ["SLURM_ARRAY_TASK_MIN"]) stop = int(os.environ["SLURM_ARRAY_TASK_MAX"]) current = int(os.environ["SLURM_ARRAY_TASK_ID"]) - start - array_size = stop - start + array_size = stop - start + 1 group_size = n_jobs // array_size + (1 if n_jobs % array_size else 0) try: with open(os.path.join(work_dir, "jobs", "map_callback.pkl"), "rb") as in_stream: @@ -30,7 +30,7 @@ def map_callback(**_kwargs) -> None: raise exception # Run `group_size` jobs as sub jobs of current job (usecase: many small jobs): - for index in range(current * group_size, max((current + 1) * group_size, n_jobs)): + for index in range(current * group_size, min((current + 1) * group_size, n_jobs)): result_path = os.path.join(work_dir, "results", f"{index}.pkl") try: with open(os.path.join(work_dir, "jobs", f"{index}.pkl"), "rb") as stream: From 6a8339ea23bdf16d5eb6f86f991ee6ac56a2af60 Mon Sep 17 00:00:00 2001 From: fgvangessel-umd Date: Thu, 24 Jul 2025 10:35:26 +0200 Subject: [PATCH 5/5] utils.container: Support both apptainer and singularity --- engibench/utils/container.py | 67 +++++++++++++++++++----------------- 1 file changed, 35 insertions(+), 32 deletions(-) diff --git a/engibench/utils/container.py b/engibench/utils/container.py index 875f9dce..ef2f3ffb 100644 --- a/engibench/utils/container.py +++ b/engibench/utils/container.py @@ -217,12 +217,28 @@ def is_available(cls) -> bool: return False +DOCKER_PREFIX = "docker://" + + class Singularity(ContainerRuntime): - """Singularity / Apptainer.""" + """Singularity.""" name = "singularity" executable = "singularity" + @classmethod + def sif_filename(cls, image: str) -> str: + """Construct the sif filename from an image specifier.""" + # Extract just the image part if it's a docker URI + image = image.removeprefix(DOCKER_PREFIX) + + # Parse the image name to match Singularity's naming convention + # For "mdolab/public:u22-gcc-ompi-stable", Singularity creates "public_u22-gcc-ompi-stable.sif" + image_name = image.rsplit("/", 1)[-1] if "/" in image else image + + # Replace ":" with "_" in the image name + return image_name.replace(":", "_") + ".sif" + @classmethod def pull(cls, image: str) -> None: """Pull an image. @@ -230,27 +246,15 @@ def pull(cls, image: str) -> None: Args: image: Container image to pull. """ - # Convert to docker URI if needed - if "://" not in image: - docker_uri = "docker://" + image - else: - docker_uri = image - # Extract just the image part if it's already a docker URI - if docker_uri.startswith("docker://"): - image = docker_uri[len("docker://") :] - - # Parse the image name to match Singularity's naming convention - # For "mdolab/public:u22-gcc-ompi-stable", Singularity creates "public_u22-gcc-ompi-stable.sif" - image_name = image.split("/")[-1] if "/" in image else image - - # Replace ":" with "_" in the image name - sif_filename = image_name.replace(":", "_") + ".sif" + # Get sif filename + sif_filename = cls.sif_filename(image) # Check if the image already exists if os.path.exists(sif_filename): print(f"Image file already exists: {sif_filename} - skipping pull") return - + # Convert to docker URI if needed + docker_uri = DOCKER_PREFIX + image if "://" not in image else image # Image doesn't exist, proceed with pull subprocess.run([cls.executable, "pull", docker_uri], check=True) @@ -272,24 +276,16 @@ def run( env: Mapping of environment variable names and values to set inside the container. name: Optional name for the container (not supported by all runtimes). """ - # Create a mutable working copy to add required system mounts - working_mounts = list(mounts) + # Get sif filename + sif_image = cls.sif_filename(image) # HPC/Singularity containers require explicit /tmp mounting to prevent memory issues # and ensure application compatibility. This is container configuration, not insecure temp file creation. - if working_mounts: # Only add /tmp mount if we have existing mounts - # Use the first mount's host path for /tmp (existing logic) - tmp_host_path = working_mounts[0][0] - working_mounts.append((tmp_host_path, "/tmp")) # noqa: S108 - else: - # Handle the empty mounts case - perhaps use a default temp directory - # or skip the /tmp mount altogether - pass - - mount_args = (["--mount", f"type=bind,src={src},target={target}"] for src, target in working_mounts) + + # Reconstruct mount and env args + mount_args = (["--mount", f"type=bind,src={src},target={target}"] for src, target in mounts) env_args = (["--env", f"{var}={value}"] for var, value in (env or {}).items()) - if "://" not in image: - image = "docker://" + image + return subprocess.run( [ cls.executable, @@ -297,13 +293,20 @@ def run( "--compat", *(arg for args in mount_args for arg in args), *(arg for args in env_args for arg in args), - image, + sif_image, *command, ], check=False, ) +class Apptainer(Singularity): + """Apptainer.""" + + name = "apptainer" + executable = "apptainer" + + RUNTIMES = [ rt for rt in globals().values()