Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
timeout-minutes: 30
strategy:
matrix:
python-version: ['3.10', '3.11', '3.12']
python-version: ['3.10', '3.11', '3.12', '3.13']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -38,7 +38,7 @@ jobs:
timeout-minutes: 30
strategy:
matrix:
python-version: ['3.12']
python-version: ['3.13']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
Expand Down Expand Up @@ -72,7 +72,7 @@ jobs:
timeout-minutes: 30
strategy:
matrix:
python-version: ['3.12']
python-version: ['3.13']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
Expand Down
16 changes: 2 additions & 14 deletions docs/utils/slurm.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,23 +73,11 @@ reduced = job.reduce(list, slurm_args=slurm_args) # Collect all results to a lis

# To save resources when rendering the docs, no actual optimization is performed.
# Instead, optimize() is replaced by a method returning zeros:
print(reduced)
print(reduced) # doctest: +ELLIPSIS
```

```{testoutput} slurm
[array([[0., ..., 0.],
...,
[0., ..., 0.]]), array([[0., ..., 0.],
...,
[0., ..., 0.]]), array([[0., ..., 0.],
...,
[0., ..., 0.]]), array([[0., ..., 0.],
...,
[0., ..., 0.]]), array([[0., ..., 0.],
...,
[0., ..., 0.]]), array([[0., ..., 0.],
...,
[0., ..., 0.]])]
[array([[0., ..., 0.]], shape=(100, 50)), ...]
```

**Step 4b:** Save all results in one [pickle](https://docs.python.org/3/library/pickle.html) archive
Expand Down
21 changes: 20 additions & 1 deletion engibench/utils/slurm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import subprocess
import sys
import tempfile
from traceback import StackSummary
from traceback import TracebackException
from typing import Any, Generic, TypeVar

Expand All @@ -24,7 +25,7 @@ class JobError(Exception):
- :attr:`origin` - Original exception instance.
- :attr:`context` - Info (string) about which step failed (e.g. map, reduce or save).
- :attr:`job_args` - dict containing the arguments passed to the job callback if the exception occurred during a job.
- :attr:`traceback` - `TracebackException <https://docs.python.org/3/library/traceback.html#traceback.TracebackException>`__ object.
- :attr:`traceback` - `TracebackException <https://docs.python.org/3/library/traceback.html#traceback.TracebackException>`_ object.
"""

def __init__(self, origin: Exception, context: str, job_args: dict[str, Any]) -> None:
Expand All @@ -42,6 +43,24 @@ def __str__(self) -> str:
"""


def dump_with_job_error(obj: Any, path: str) -> None:
    """Write ``obj`` as a pickle to ``path``, using :py:class:`TracebackPickler`.

    Use this for objects that might contain a :py:class:`JobError` instance.

    Args:
        obj: Object to serialize.
        path: Destination file; it is opened in binary write mode.
    """
    with open(path, "wb") as out_file:
        TracebackPickler(out_file).dump(obj)


class TracebackPickler(pickle.Pickler):
    """Pickler which avoids pickling code objects when pickling tracebacks."""

    def reducer_override(self, obj):
        """Reduce a ``StackSummary`` to plain per-frame tuples.

        Each frame is stored as ``(filename, lineno, name, line)`` and the
        summary is rebuilt through ``StackSummary.from_list`` when unpickling.
        """
        if not isinstance(obj, StackSummary):
            # Every other object goes through the regular reduction machinery.
            return NotImplemented
        frames = [(frame.filename, frame.lineno, frame.name, frame.line) for frame in obj]
        return StackSummary.from_list, (frames,)


if sys.version_info < (3, 11):

class ExceptionGroup(Exception): # noqa: N818
Expand Down
21 changes: 10 additions & 11 deletions engibench/utils/slurm/run_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
from typing import Any

from engibench.utils.slurm import dump_with_job_error
from engibench.utils.slurm import JobError
from engibench.utils.slurm import MemorizeModule

Expand Down Expand Up @@ -40,32 +41,31 @@ def map_callback(**_kwargs) -> None:
with open(os.path.join(work_dir, "jobs", f"{index}.pkl"), "rb") as stream:
args = pickle.load(stream)
except Exception as e: # noqa: BLE001
with open(result_path, "wb") as out_stream:
pickle.dump(JobError(e, "Unpickle job array item", {}), out_stream)
dump_with_job_error(JobError(e, "Unpickle job array item", {}), result_path)
continue
try:
result = map_callback(**args)
with open(result_path, "wb") as out_stream:
pickle.dump(MemorizeModule(result), out_stream)
dump_with_job_error(MemorizeModule(result), result_path)
except Exception as e: # noqa: BLE001
with open(result_path, "wb") as out_stream:
pickle.dump(JobError(e, "Run job array item", args), out_stream)
dump_with_job_error(JobError(e, "Run job array item", args), result_path)


def reduce_job_results(work_dir: str, n_jobs: int) -> None:
"""Collect all results or errors from job array jobs, passing to a reduce callback."""
results = [] # prepare empty list for error, occurring before `results` is assigned a value
reduced_pkl = os.path.join(work_dir, "reduced.pkl")
try:
with open(os.path.join(work_dir, "jobs", "reduce.pkl"), "rb") as in_stream:
reduce_callback = pickle.load(in_stream)

results = collect_jobs(work_dir, n_jobs)
reduced = reduce_callback(results)
dump_with_job_error(MemorizeModule(reduced), reduced_pkl)
except Exception as e: # noqa: BLE001
errors = [e] + [err for err in results if isinstance(err, Exception)]
reduced = JobError(ExceptionGroup("", errors), "reduce", {}) if errors else JobError(e, "reduce", {})
with open(os.path.join(work_dir, "reduced.pkl"), "wb") as out_stream:
pickle.dump(MemorizeModule(reduced), out_stream)
dump_with_job_error(
JobError(ExceptionGroup("", errors), "reduce", {}) if errors else JobError(e, "reduce", {}), reduced_pkl
)


def save(work_dir: str, n_jobs: int, out: str) -> None:
Expand All @@ -74,8 +74,7 @@ def save(work_dir: str, n_jobs: int, out: str) -> None:

if not any(isinstance(r, JobError) for r in results):
shutil.rmtree(work_dir)
with open(out, "wb") as out_stream:
pickle.dump(results, out_stream)
dump_with_job_error(results, out)


def collect_jobs(work_dir: str, n_jobs: int) -> list[Any]:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
'Intended Audience :: Science/Research',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
]
dependencies = [
"numpy <=2.0",
"numpy",
"gymnasium >= 1.0.0",
"datasets[vision] >= 3.1.0", # imports modules with image features
"pandas >= 2.2.3",
Expand Down
Loading