Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 60 additions & 31 deletions src/extremeweatherbench/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ def __init__(
)
self.evaluation_objects = evaluation_objects
self.cache_dir = pathlib.Path(cache_dir) if cache_dir else None

# Instantiate cache dir if needed
if self.cache_dir:
if not self.cache_dir.exists():
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.region_subsetter = region_subsetter

# Case operators as a property can be used as a convenience method for a workflow
Expand Down Expand Up @@ -107,40 +112,20 @@ def run(
Ignored if parallel_config is provided.
parallel_config: Optional dictionary of joblib parallel configuration.
If provided, this takes precedence over n_jobs. If not provided and
n_jobs is specified, a default config with threading backend is used.
n_jobs is specified, a default config with loky backend is used.

Returns:
A concatenated dataframe of the evaluation results.
"""
logger.info("Running ExtremeWeatherBench workflow...")

# Determine if running in serial or parallel mode
# Serial: n_jobs=1 or (parallel_config with n_jobs=1)
# Parallel: n_jobs>1 or (parallel_config with n_jobs>1)
is_serial = (
(n_jobs == 1)
or (parallel_config is not None and parallel_config.get("n_jobs") == 1)
or (n_jobs is None and parallel_config is None)
)
logger.debug("Running in %s mode.", "serial" if is_serial else "parallel")

if not is_serial:
# Build parallel_config if not provided
if parallel_config is None and n_jobs is not None:
logger.debug(
"No parallel_config provided, using threading backend and %s jobs.",
n_jobs,
)
parallel_config = {"backend": "threading", "n_jobs": n_jobs}
kwargs["parallel_config"] = parallel_config
else:
# Running in serial mode - instantiate cache dir if needed
if self.cache_dir:
if not self.cache_dir.exists():
self.cache_dir.mkdir(parents=True, exist_ok=True)

# Check for serial or parallel configuration
parallel_config = _parallel_serial_config_check(n_jobs, parallel_config)
run_results = _run_case_operators(
self.case_operators, cache_dir=self.cache_dir, **kwargs
self.case_operators,
cache_dir=self.cache_dir,
parallel_config=parallel_config,
**kwargs,
)

# If there are results, concatenate them and return, else return an empty
Expand All @@ -154,6 +139,7 @@ def run(
def _run_case_operators(
case_operators: list["cases.CaseOperator"],
cache_dir: Optional[pathlib.Path] = None,
parallel_config: Optional[dict] = None,
**kwargs,
) -> list[pd.DataFrame]:
"""Run the case operators in parallel or serial.
Expand All @@ -167,13 +153,15 @@ def _run_case_operators(
List of result DataFrames.
"""
with logging_redirect_tqdm():
# Check if parallel_config is provided
parallel_config = kwargs.get("parallel_config", None)

# Run in parallel if parallel_config exists and n_jobs != 1
if parallel_config is not None:
logger.info("Running case operators in parallel...")
return _run_parallel(case_operators, cache_dir=cache_dir, **kwargs)
return _run_parallel(
case_operators,
cache_dir=cache_dir,
parallel_config=parallel_config,
**kwargs,
)
else:
logger.info("Running case operators in serial...")
return _run_serial(case_operators, cache_dir=cache_dir, **kwargs)
Expand Down Expand Up @@ -910,3 +898,44 @@ def _safe_concat(
return pd.concat(valid_dfs, ignore_index=ignore_index)
else:
return pd.DataFrame(columns=OUTPUT_COLUMNS)


def _parallel_serial_config_check(
n_jobs: Optional[int] = None,
parallel_config: Optional[dict] = None,
) -> Optional[dict]:
"""Check if running in serial or parallel mode.

Args:
n_jobs: The number of jobs to run in parallel. If None, defaults to the
joblib backend default value. If 1, the workflow will run serially.
parallel_config: Optional dictionary of joblib parallel configuration. If
provided, this takes precedence over n_jobs. If not provided and n_jobs is
specified, a default config with loky backend is used.
Returns:
None if running in serial mode, otherwise a dictionary of joblib parallel
configuration.
"""
# Determine if running in serial or parallel mode
# Serial: n_jobs=1 or (parallel_config with n_jobs=1)
# Parallel: n_jobs>1 or (parallel_config with n_jobs>1)
is_serial = (
(n_jobs == 1)
or (parallel_config is not None and parallel_config.get("n_jobs") == 1)
or (n_jobs is None and parallel_config is None)
)
logger.debug("Running in %s mode.", "serial" if is_serial else "parallel")

if not is_serial:
# Build parallel_config if not provided
if parallel_config is None and n_jobs is not None:
logger.debug(
"No parallel_config provided, using loky backend and %s jobs.",
n_jobs,
)
parallel_config = {"backend": "loky", "n_jobs": n_jobs}
# If running in serial mode, set parallel_config to None if not already
else:
parallel_config = None
# Return the maybe updated kwargs
return parallel_config
58 changes: 56 additions & 2 deletions tests/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,11 @@ def test_run_serial(

result = ewb.run(n_jobs=1)

# Serial mode should not pass parallel_config
# Serial mode passes parallel_config=None
mock_run_case_operators.assert_called_once_with(
[sample_case_operator],
cache_dir=None,
parallel_config=None,
)
assert isinstance(result, pd.DataFrame)
assert len(result) == 1
Expand Down Expand Up @@ -407,7 +408,7 @@ def test_run_parallel(
mock_run_case_operators.assert_called_once_with(
[sample_case_operator],
cache_dir=None,
parallel_config={"backend": "threading", "n_jobs": 2},
parallel_config={"backend": "loky", "n_jobs": 2},
)
assert isinstance(result, pd.DataFrame)
assert len(result) == 1
Expand Down Expand Up @@ -2725,5 +2726,58 @@ def test_compute_case_operator_no_output_variables(self, sample_individual_case)
assert all(result["target_variable"] == "MockDerivedVariableWithOutputs")


class TestParallelSerialConfigCheck:
    """Tests for evaluate._parallel_serial_config_check."""

    def test_parallel_serial_config_check_serial(self):
        """Test that the parallel_serial_config_check returns None for serial mode.

        If n_jobs == 1 in any of the arguments, or no arguments are given,
        parallel_config should always be None."""
        # n_jobs=1 alone forces serial mode.
        assert evaluate._parallel_serial_config_check(n_jobs=1) is None
        # A parallel_config requesting a single job forces serial mode.
        assert (
            evaluate._parallel_serial_config_check(parallel_config={"n_jobs": 1})
            is None
        )
        assert (
            evaluate._parallel_serial_config_check(
                n_jobs=None, parallel_config={"n_jobs": 1}
            )
            is None
        )
        # A backend entry does not override a single-job request.
        assert (
            evaluate._parallel_serial_config_check(
                n_jobs=None, parallel_config={"backend": "threading", "n_jobs": 1}
            )
            is None
        )
        # No arguments at all defaults to serial mode.
        assert evaluate._parallel_serial_config_check() is None

    def test_parallel_serial_config_check_parallel(self):
        """Test that the parallel_serial_config_check returns a dictionary for parallel mode."""
        # n_jobs alone builds the default loky config.
        assert evaluate._parallel_serial_config_check(n_jobs=2) == {
            "backend": "loky",
            "n_jobs": 2,
        }
        # An explicit parallel_config passes through unchanged.
        assert evaluate._parallel_serial_config_check(
            parallel_config={"backend": "threading", "n_jobs": 2}
        ) == {"backend": "threading", "n_jobs": 2}
        # parallel_config is returned unchanged when both arguments are given.
        assert evaluate._parallel_serial_config_check(
            n_jobs=2, parallel_config={"backend": "threading", "n_jobs": 2}
        ) == {"backend": "threading", "n_jobs": 2}


# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__])
Loading