From 9650bc5e9b94634e1f6d2991488bf65d293ecbb9 Mon Sep 17 00:00:00 2001 From: taylor Date: Tue, 13 Jan 2026 15:24:39 +0000 Subject: [PATCH 1/4] move function out of run, move cache mkdir to init --- src/extremeweatherbench/evaluate.py | 76 ++++++++++++++++++----------- 1 file changed, 48 insertions(+), 28 deletions(-) diff --git a/src/extremeweatherbench/evaluate.py b/src/extremeweatherbench/evaluate.py index df36591b..739c8894 100644 --- a/src/extremeweatherbench/evaluate.py +++ b/src/extremeweatherbench/evaluate.py @@ -72,6 +72,11 @@ def __init__( ) self.evaluation_objects = evaluation_objects self.cache_dir = pathlib.Path(cache_dir) if cache_dir else None + + # Instantiate cache dir if needed + if self.cache_dir: + if not self.cache_dir.exists(): + self.cache_dir.mkdir(parents=True, exist_ok=True) self.region_subsetter = region_subsetter # Case operators as a property can be used as a convenience method for a workflow @@ -114,31 +119,8 @@ def run( """ logger.info("Running ExtremeWeatherBench workflow...") - # Determine if running in serial or parallel mode - # Serial: n_jobs=1 or (parallel_config with n_jobs=1) - # Parallel: n_jobs>1 or (parallel_config with n_jobs>1) - is_serial = ( - (n_jobs == 1) - or (parallel_config is not None and parallel_config.get("n_jobs") == 1) - or (n_jobs is None and parallel_config is None) - ) - logger.debug("Running in %s mode.", "serial" if is_serial else "parallel") - - if not is_serial: - # Build parallel_config if not provided - if parallel_config is None and n_jobs is not None: - logger.debug( - "No parallel_config provided, using threading backend and %s jobs.", - n_jobs, - ) - parallel_config = {"backend": "threading", "n_jobs": n_jobs} - kwargs["parallel_config"] = parallel_config - else: - # Running in serial mode - instantiate cache dir if needed - if self.cache_dir: - if not self.cache_dir.exists(): - self.cache_dir.mkdir(parents=True, exist_ok=True) - + # Check for serial or parallel configuration + parallel_config = _parallel_serial_config_check(n_jobs, parallel_config) run_results = _run_case_operators( self.case_operators, cache_dir=self.cache_dir, **kwargs ) @@ -154,6 +136,7 @@ def run( def _run_case_operators( case_operators: list["cases.CaseOperator"], cache_dir: Optional[pathlib.Path] = None, + parallel_config: Optional[dict] = None, **kwargs, ) -> list[pd.DataFrame]: """Run the case operators in parallel or serial. @@ -167,9 +150,6 @@ def _run_case_operators( List of result DataFrames. """ with logging_redirect_tqdm(): - # Check if parallel_config is provided - parallel_config = kwargs.get("parallel_config", None) - # Run in parallel if parallel_config exists and n_jobs != 1 if parallel_config is not None: logger.info("Running case operators in parallel...") @@ -910,3 +890,43 @@ def _safe_concat( return pd.concat(valid_dfs, ignore_index=ignore_index) else: return pd.DataFrame(columns=OUTPUT_COLUMNS) + +def _parallel_serial_config_check( + n_jobs: Optional[int] = None, + parallel_config: Optional[dict] = None, +) -> Optional[dict]: + """Check if running in serial or parallel mode. + + Args: + n_jobs: The number of jobs to run in parallel. If None, defaults to the + joblib backend default value. If 1, the workflow will run serially. + parallel_config: Optional dictionary of joblib parallel configuration. If + provided, this takes precedence over n_jobs. If not provided and n_jobs is + specified, a default config with loky backend is used. + Returns: + None if running in serial mode, otherwise a dictionary of joblib parallel + configuration. + """ + # Determine if running in serial or parallel mode + # Serial: n_jobs=1 or (parallel_config with n_jobs=1) + # Parallel: n_jobs>1 or (parallel_config with n_jobs>1) + is_serial = ( + (n_jobs == 1) + or (parallel_config is not None and parallel_config.get("n_jobs") == 1) + or (n_jobs is None and parallel_config is None) + ) + logger.debug("Running in %s mode.", "serial" if is_serial else "parallel") + + if not is_serial: + # Build parallel_config if not provided + if parallel_config is None and n_jobs is not None: + logger.debug( + "No parallel_config provided, using loky backend and %s jobs.", + n_jobs, + ) + parallel_config = {"backend": "loky", "n_jobs": n_jobs} + # If running in serial mode, set parallel_config to None if not already + else: + parallel_config = None + # Return the maybe updated kwargs + return parallel_config \ No newline at end of file From 1fedf89e59621aa2ef63ac1bc6811f2fdc1c5f4a Mon Sep 17 00:00:00 2001 From: taylor Date: Tue, 13 Jan 2026 15:25:15 +0000 Subject: [PATCH 2/4] add tests for new func --- tests/test_evaluate.py | 52 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index eca58b63..3bc18fa0 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -2724,6 +2724,58 @@ def test_compute_case_operator_no_output_variables(self, sample_individual_case) assert len(result) == 10 assert all(result["target_variable"] == "MockDerivedVariableWithOutputs") +class TestParallelSerialConfigCheck: + def test_parallel_serial_config_check_serial(self): + """Test that the parallel_serial_config_check returns None for serial mode. + + If n_jobs == 1 in any of the arguments, parallel_config should always be + None.""" + assert evaluate._parallel_serial_config_check(n_jobs=1) is None + assert ( + evaluate._parallel_serial_config_check(parallel_config={"n_jobs": 1}) + is None + ) + assert ( + evaluate._parallel_serial_config_check( + n_jobs=None, parallel_config={"n_jobs": 1} + ) + is None + ) + assert ( + evaluate._parallel_serial_config_check( + n_jobs=None, parallel_config={"n_jobs": 1} + ) + is None + ) + assert ( + evaluate._parallel_serial_config_check( + n_jobs=None, parallel_config={"backend": "threading", "n_jobs": 1} + ) + is None + ) + + def test_parallel_serial_config_check_parallel(self): + """Test that the parallel_serial_config_check returns a dictionary for parallel mode.""" + assert evaluate._parallel_serial_config_check(n_jobs=2) == { + "backend": "loky", + "n_jobs": 2, + } + assert evaluate._parallel_serial_config_check( + parallel_config={"backend": "threading", "n_jobs": 2} + ) == {"backend": "threading", "n_jobs": 2} + assert evaluate._parallel_serial_config_check( + n_jobs=2, parallel_config={"backend": "threading", "n_jobs": 2} + ) == {"backend": "threading", "n_jobs": 2} + assert evaluate._parallel_serial_config_check( + n_jobs=2, parallel_config={"backend": "threading", "n_jobs": 2} + ) == {"backend": "threading", "n_jobs": 2} + assert evaluate._parallel_serial_config_check( + n_jobs=2, parallel_config={"backend": "threading", "n_jobs": 2} + ) == {"backend": "threading", "n_jobs": 2} + assert evaluate._parallel_serial_config_check( + n_jobs=2, parallel_config={"backend": "threading", "n_jobs": 2} + ) == {"backend": "threading", "n_jobs": 2} + if __name__ == "__main__": pytest.main([__file__]) From 1db58f7d1c70678236338888a760b70b0eb8d192 Mon Sep 17 00:00:00 2001 From: taylor Date: Tue, 13 Jan 2026 15:25:27 +0000 Subject: [PATCH 3/4] ruff --- src/extremeweatherbench/evaluate.py | 5 +++-- tests/test_evaluate.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/extremeweatherbench/evaluate.py b/src/extremeweatherbench/evaluate.py index 739c8894..9b0edfb9 100644 --- a/src/extremeweatherbench/evaluate.py +++ b/src/extremeweatherbench/evaluate.py @@ -72,7 +72,7 @@ def __init__( ) self.evaluation_objects = evaluation_objects self.cache_dir = pathlib.Path(cache_dir) if cache_dir else None - + # Instantiate cache dir if needed if self.cache_dir: if not self.cache_dir.exists(): @@ -891,6 +891,7 @@ def _safe_concat( else: return pd.DataFrame(columns=OUTPUT_COLUMNS) + def _parallel_serial_config_check( n_jobs: Optional[int] = None, parallel_config: Optional[dict] = None, @@ -929,4 +930,4 @@ def _parallel_serial_config_check( else: parallel_config = None # Return the maybe updated kwargs - return parallel_config \ No newline at end of file + return parallel_config diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 3bc18fa0..0c674cc4 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -2724,6 +2724,7 @@ def test_compute_case_operator_no_output_variables(self, sample_individual_case) assert len(result) == 10 assert all(result["target_variable"] == "MockDerivedVariableWithOutputs") + class TestParallelSerialConfigCheck: def test_parallel_serial_config_check_serial(self): """Test that the parallel_serial_config_check returns None for serial mode. From 2e1a0738c6a3b343e56e363027dd0aa40f1dae03 Mon Sep 17 00:00:00 2001 From: taylor Date: Tue, 13 Jan 2026 16:41:58 +0000 Subject: [PATCH 4/4] update parallel_config passthrough and tests --- src/extremeweatherbench/evaluate.py | 14 +++++++++++--- tests/test_evaluate.py | 5 +++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/extremeweatherbench/evaluate.py b/src/extremeweatherbench/evaluate.py index 9b0edfb9..ccc88004 100644 --- a/src/extremeweatherbench/evaluate.py +++ b/src/extremeweatherbench/evaluate.py @@ -112,7 +112,7 @@ def run( Ignored if parallel_config is provided. parallel_config: Optional dictionary of joblib parallel configuration. If provided, this takes precedence over n_jobs. If not provided and - n_jobs is specified, a default config with threading backend is used. + n_jobs is specified, a default config with loky backend is used. Returns: A concatenated dataframe of the evaluation results. @@ -122,7 +122,10 @@ def run( # Check for serial or parallel configuration parallel_config = _parallel_serial_config_check(n_jobs, parallel_config) run_results = _run_case_operators( - self.case_operators, cache_dir=self.cache_dir, **kwargs + self.case_operators, + cache_dir=self.cache_dir, + parallel_config=parallel_config, + **kwargs, ) # If there are results, concatenate them and return, else return an empty @@ -153,7 +156,12 @@ def _run_case_operators( # Run in parallel if parallel_config exists and n_jobs != 1 if parallel_config is not None: logger.info("Running case operators in parallel...") - return _run_parallel(case_operators, cache_dir=cache_dir, **kwargs) + return _run_parallel( + case_operators, + cache_dir=cache_dir, + parallel_config=parallel_config, + **kwargs, + ) else: logger.info("Running case operators in serial...") return _run_serial(case_operators, cache_dir=cache_dir, **kwargs) diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 0c674cc4..fe23d95c 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -366,10 +366,11 @@ def test_run_serial( result = ewb.run(n_jobs=1) - # Serial mode should not pass parallel_config + # Serial mode passes parallel_config=None mock_run_case_operators.assert_called_once_with( [sample_case_operator], cache_dir=None, + parallel_config=None, ) assert isinstance(result, pd.DataFrame) assert len(result) == 1 @@ -407,7 +408,7 @@ def test_run_parallel( mock_run_case_operators.assert_called_once_with( [sample_case_operator], cache_dir=None, - parallel_config={"backend": "threading", "n_jobs": 2}, + parallel_config={"backend": "loky", "n_jobs": 2}, ) assert isinstance(result, pd.DataFrame) assert len(result) == 1