Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions engibench/utils/slurm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import asdict
from dataclasses import dataclass
from dataclasses import field
import importlib
import itertools
import os
import pickle
Expand All @@ -14,6 +15,8 @@
from traceback import TracebackException
from typing import Any, Generic, TypeVar

import __main__


class JobError(Exception):
"""User error happening during execution of a slurm job.
Expand Down Expand Up @@ -248,17 +251,24 @@ def __init__(self, obj: Any) -> None:
self.obj = obj

@staticmethod
def _reconstruct(reduced_module: str | None, pickled_obj: bytes) -> Any:
if reduced_module is not None and reduced_module not in sys.path:
sys.path.append(reduced_module)
def _reconstruct(reduced_module: tuple[str, ...] | None, pickled_obj: bytes) -> Any:
    """Rebuild the wrapped object from the arguments emitted by `__reduce__`.

    Args:
        reduced_module: Either `None`, a 1-tuple `(directory,)`, or a 3-tuple
            `(directory, module_name, object_name)`. The 3-tuple form is what
            `module_path` returns for objects whose top-level module is
            `__main__`.
        pickled_obj: The pickled bytes of the wrapped object.

    Returns:
        The object obtained by unpickling `pickled_obj`.

    Raises:
        RuntimeError: Propagated from `check_main_guard` when the original
            main script has no `__main__` guard.
    """
    # `None` is handled like `(None,)`: no directory and no trailing entries.
    mod_path, *modules = reduced_module or (None,)
    # Make the directory importable before anything is imported or unpickled.
    if mod_path is not None and mod_path not in sys.path:
        sys.path.append(mod_path)
    if modules:  # obj was pickled from __main__
        (mod_name, obj_name) = modules
        # Importing a script without a guard would execute the whole script.
        check_main_guard(mod_path, mod_name)
        mod = importlib.import_module(mod_name)
        obj = getattr(mod, obj_name)
        # Publish the object on the current `__main__` module under its
        # original name; presumably so that `pickle.loads` below can resolve
        # the `__main__.<obj_name>` reference stored in the pickle.
        setattr(__main__, obj_name, obj)
    # NOTE(review): `pickle.loads` can execute arbitrary code; this must only
    # ever be fed pickles from a trusted source.
    return pickle.loads(pickled_obj)

def __reduce__(self) -> tuple[Callable[..., Any], tuple[str | None, bytes]]:
def __reduce__(self) -> tuple[Callable[..., Any], tuple[tuple[str, ...] | None, bytes]]:
    """Describe to pickle how this wrapper is rebuilt.

    Returns:
        A pair of the `_reconstruct` callable and its arguments: the module
        location info computed by `module_path` for the wrapped object, and
        the wrapped object's pickled bytes.
    """
    # Serialize first, then look up the module location (same order as before).
    payload = pickle.dumps(self.obj)
    origin = module_path(self.obj)
    rebuild_args = (origin, payload)
    return (self._reconstruct, rebuild_args)


def module_path(obj: Any) -> str | None:
def module_path(obj: Any) -> tuple[str, ...] | None:
"""Return the path of the toplevel module of the module containing `obj`."""
if not hasattr(obj, "__module__"):
return None
Expand All @@ -269,6 +279,20 @@ def module_path(obj: Any) -> str | None:
if path is None:
msg = "Got a module without path"
raise RuntimeError(msg)
if top_level_module == "__main__":
return os.path.dirname(path), os.path.basename(path).removesuffix(".py"), obj.__name__
if os.path.basename(path) == "__init__.py":
path = os.path.dirname(path)
return os.path.dirname(path)
return (os.path.dirname(path),)


def check_main_guard(path: str, mod_name: str) -> None:
    """Check that a python main script has a __main__ guard.

    If this is not the case and pickle tries to load from that file,
    this will run the whole script during unpickling.

    Args:
        path: Directory containing the main script.
        mod_name: Name of the script without the `.py` suffix.

    Raises:
        RuntimeError: If neither the double-quoted nor the single-quoted
            form of the `__main__` guard occurs in the script.
        OSError: If the script cannot be opened.
    """
    script = os.path.join(path, mod_name) + ".py"
    # Decode as UTF-8 (the default source encoding) instead of the locale
    # encoding. Undecodable bytes are replaced rather than raising, since
    # the guard text being searched for is pure ASCII.
    with open(script, encoding="utf-8", errors="replace") as stream:
        content = stream.read()
    guards = ('if __name__ == "__main__"', "if __name__ == '__main__'")
    if not any(guard in content for guard in guards):
        msg = "Main script does not have a __main__ guard. This will lead to infinite recursion"
        raise RuntimeError(msg)
1 change: 1 addition & 0 deletions engibench/utils/slurm/run_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def map_callback(**_kwargs) -> None:

def reduce_job_results(work_dir: str, n_jobs: int) -> None:
"""Collect all results or errors from job array jobs, passing to a reduce callback."""
results = [] # prepare empty list for error, occurring before `results` is assigned a value
try:
with open(os.path.join(work_dir, "jobs", "reduce.pkl"), "rb") as in_stream:
reduce_callback = pickle.load(in_stream)
Expand Down
Loading