From 1b7c693872d379f15e2d848a5859de9cc18e03f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gerhard=20Br=C3=A4unlich?= Date: Fri, 22 Aug 2025 08:55:44 +0200 Subject: [PATCH] utils.slurm: Allow serializing / deserializing from __main__ --- engibench/utils/slurm/__init__.py | 36 +++++++++++++++++++++++++------ engibench/utils/slurm/run_job.py | 1 + 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/engibench/utils/slurm/__init__.py b/engibench/utils/slurm/__init__.py index 6eaa5813..7028d351 100644 --- a/engibench/utils/slurm/__init__.py +++ b/engibench/utils/slurm/__init__.py @@ -4,6 +4,7 @@ from dataclasses import asdict from dataclasses import dataclass from dataclasses import field +import importlib import itertools import os import pickle @@ -14,6 +15,8 @@ from traceback import TracebackException from typing import Any, Generic, TypeVar +import __main__ + class JobError(Exception): """User error happening during execution of a slurm job. @@ -248,17 +251,24 @@ def __init__(self, obj: Any) -> None: self.obj = obj @staticmethod - def _reconstruct(reduced_module: str | None, pickled_obj: bytes) -> Any: - if reduced_module is not None and reduced_module not in sys.path: - sys.path.append(reduced_module) + def _reconstruct(reduced_module: tuple[str, ...] | None, pickled_obj: bytes) -> Any: + mod_path, *modules = reduced_module or (None,) + if mod_path is not None and mod_path not in sys.path: + sys.path.append(mod_path) + if modules: # obj was pickled from __main__ + (mod_name, obj_name) = modules + check_main_guard(mod_path, mod_name) + mod = importlib.import_module(mod_name) + obj = getattr(mod, obj_name) + setattr(__main__, obj_name, obj) return pickle.loads(pickled_obj) - def __reduce__(self) -> tuple[Callable[..., Any], tuple[str | None, bytes]]: + def __reduce__(self) -> tuple[Callable[..., Any], tuple[tuple[str, ...] | None, bytes]]: pickled_obj = pickle.dumps(self.obj) return (self._reconstruct, (module_path(self.obj), pickled_obj)) -def module_path(obj: Any) -> str | None: +def module_path(obj: Any) -> tuple[str, ...] | None: """Return the path of the toplevel module of the module containing `obj`.""" if not hasattr(obj, "__module__"): return None @@ -269,6 +279,20 @@ def module_path(obj: Any) -> str | None: if path is None: msg = "Got a module without path" raise RuntimeError(msg) + if top_level_module == "__main__": + return os.path.dirname(path), os.path.basename(path).removesuffix(".py"), obj.__name__ if os.path.basename(path) == "__init__.py": path = os.path.dirname(path) - return os.path.dirname(path) + return (os.path.dirname(path),) + + +def check_main_guard(path, mod_name): + """Check that a python main script as a __main__ guard. + + If this is not the case and pickle tries to load from that file, + this will run the whole script during unpickling. + """ + with open(os.path.join(path, mod_name) + ".py") as stream: + content = stream.read() + if 'if __name__ == "__main__"' not in content and "if __name__ == '__main__'" not in content: + raise RuntimeError("Main script does not have a __main__ guard. This will lead to infinite recurson") diff --git a/engibench/utils/slurm/run_job.py b/engibench/utils/slurm/run_job.py index 3df2fa33..6d9b00e6 100644 --- a/engibench/utils/slurm/run_job.py +++ b/engibench/utils/slurm/run_job.py @@ -54,6 +54,7 @@ def map_callback(**_kwargs) -> None: def reduce_job_results(work_dir: str, n_jobs: int) -> None: """Collect all results or errors from job array jobs, passing to a reduce callback.""" + results = [] # prepare empty list for error, occurring before `results` is assigned a value try: with open(os.path.join(work_dir, "jobs", "reduce.pkl"), "rb") as in_stream: reduce_callback = pickle.load(in_stream)