From e780431382529f3e4f5f95712f69922be653a085 Mon Sep 17 00:00:00 2001 From: taylor Date: Mon, 12 Jan 2026 22:11:54 +0000 Subject: [PATCH 01/13] move cache dir creation to init, rename funcs, add parallel/serial check function, update test names --- src/extremeweatherbench/evaluate.py | 117 +++++++----- tests/test_evaluate.py | 282 +++++++++++++++------------- 2 files changed, 223 insertions(+), 176 deletions(-) diff --git a/src/extremeweatherbench/evaluate.py b/src/extremeweatherbench/evaluate.py index df36591b..8bd98ef4 100644 --- a/src/extremeweatherbench/evaluate.py +++ b/src/extremeweatherbench/evaluate.py @@ -72,6 +72,11 @@ def __init__( ) self.evaluation_objects = evaluation_objects self.cache_dir = pathlib.Path(cache_dir) if cache_dir else None + + # Instantiate cache dir if needed + if self.cache_dir: + if not self.cache_dir.exists(): + self.cache_dir.mkdir(parents=True, exist_ok=True) self.region_subsetter = region_subsetter # Case operators as a property can be used as a convenience method for a workflow @@ -89,17 +94,17 @@ def case_operators(self) -> list["cases.CaseOperator"]: subset_collection = self.case_metadata return cases.build_case_operators(subset_collection, self.evaluation_objects) - def run( + def run_evaluation( self, n_jobs: Optional[int] = None, parallel_config: Optional[dict] = None, **kwargs, ) -> pd.DataFrame: - """Runs the ExtremeWeatherBench workflow. + """Runs the ExtremeWeatherBench evaluation workflow. - This method will run the workflow in the order of the case operators, optionally - caching the mid-flight outputs of the workflow if cache_dir was provided for - serial runs. + This method will run the evaluation workflow in the order of the case operators, + optionally caching the mid-flight outputs of the workflow if cache_dir was + provided for serial runs. Args: n_jobs: The number of jobs to run in parallel. If None, defaults to the @@ -112,35 +117,13 @@ def run( Returns: A concatenated dataframe of the evaluation results. 
""" - logger.info("Running ExtremeWeatherBench workflow...") - - # Determine if running in serial or parallel mode - # Serial: n_jobs=1 or (parallel_config with n_jobs=1) - # Parallel: n_jobs>1 or (parallel_config with n_jobs>1) - is_serial = ( - (n_jobs == 1) - or (parallel_config is not None and parallel_config.get("n_jobs") == 1) - or (n_jobs is None and parallel_config is None) - ) - logger.debug("Running in %s mode.", "serial" if is_serial else "parallel") - - if not is_serial: - # Build parallel_config if not provided - if parallel_config is None and n_jobs is not None: - logger.debug( - "No parallel_config provided, using threading backend and %s jobs.", - n_jobs, - ) - parallel_config = {"backend": "threading", "n_jobs": n_jobs} - kwargs["parallel_config"] = parallel_config - else: - # Running in serial mode - instantiate cache dir if needed - if self.cache_dir: - if not self.cache_dir.exists(): - self.cache_dir.mkdir(parents=True, exist_ok=True) + logger.info("Running ExtremeWeatherBench evaluations...") + + # Check for serial or parallel configuration + run_config_kwargs = _parallel_config_check(n_jobs, parallel_config, **kwargs) - run_results = _run_case_operators( - self.case_operators, cache_dir=self.cache_dir, **kwargs + run_results = _run_evaluation( + self.case_operators, cache_dir=self.cache_dir, **run_config_kwargs ) # If there are results, concatenate them and return, else return an empty @@ -151,7 +134,54 @@ def run( return pd.DataFrame(columns=OUTPUT_COLUMNS) -def _run_case_operators( +def _parallel_config_check( + n_jobs: Optional[int] = None, + parallel_config: Optional[dict] = None, + **kwargs, +) -> dict: + """Build the run configuration. + + Builds the run configuration for EWB workflows depending on the configuration + provided via arguments. + + Args: + n_jobs: The number of jobs to run in parallel. If None, defaults to the + joblib backend default value. If 1, the workflow will run serially. + Ignored if parallel_config is provided. 
+ parallel_config: Optional dictionary of joblib parallel configuration. + If provided, this takes precedence over n_jobs. If not provided and + n_jobs is specified, a default config with loky backend is used. + cache_dir: Optional directory for caching (serial mode only). + **kwargs: Additional arguments, may include 'parallel_config' dict. + + Returns: + Maybe updated kwargs if n_jobs was provided instead of parallel_config. + """ + # Determine if running in serial or parallel mode + # Serial: n_jobs=1 or (parallel_config with n_jobs=1) + # Parallel: n_jobs>1 or (parallel_config with n_jobs>1) + is_serial = ( + (n_jobs == 1) + or (parallel_config is not None and parallel_config.get("n_jobs") == 1) + or (n_jobs is None and parallel_config is None) + ) + logger.debug("Running in %s mode.", "serial" if is_serial else "parallel") + + if not is_serial: + # Build parallel_config if not provided + if parallel_config is None and n_jobs is not None: + logger.debug( + "No parallel_config provided, using loky backend and %s jobs.", + n_jobs, + ) + parallel_config = {"backend": "loky", "n_jobs": n_jobs} + kwargs["parallel_config"] = parallel_config + + # Return the maybe updated kwargs + return kwargs + + +def _run_evaluation( case_operators: list["cases.CaseOperator"], cache_dir: Optional[pathlib.Path] = None, **kwargs, @@ -173,13 +203,18 @@ def _run_case_operators( # Run in parallel if parallel_config exists and n_jobs != 1 if parallel_config is not None: logger.info("Running case operators in parallel...") - return _run_parallel(case_operators, cache_dir=cache_dir, **kwargs) + return _run_parallel_evaluation( + case_operators, + cache_dir=cache_dir, + parallel_config=parallel_config, + **kwargs, + ) else: logger.info("Running case operators in serial...") - return _run_serial(case_operators, cache_dir=cache_dir, **kwargs) + return _run_serial_evaluation(case_operators, cache_dir=cache_dir, **kwargs) -def _run_serial( +def _run_serial_evaluation( case_operators: 
list["cases.CaseOperator"], cache_dir: Optional[pathlib.Path] = None, **kwargs, @@ -193,8 +228,9 @@ def _run_serial( return run_results -def _run_parallel( +def _run_parallel_evaluation( case_operators: list["cases.CaseOperator"], + parallel_config: dict, cache_dir: Optional[pathlib.Path] = None, **kwargs, ) -> list[pd.DataFrame]: @@ -207,11 +243,6 @@ def _run_parallel( Returns: List of result DataFrames. """ - parallel_config = kwargs.pop("parallel_config", None) - - if parallel_config is None: - raise ValueError("parallel_config must be provided to _run_parallel") - if parallel_config.get("n_jobs") is None: logger.warning("No number of jobs provided, using joblib backend default.") diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index eca58b63..656b29d0 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -334,10 +334,10 @@ def test_case_operators_property( # Check that the result is what the mock returned assert result == [sample_case_operator] - @mock.patch("extremeweatherbench.evaluate._run_case_operators") - def test_run_serial( + @mock.patch("extremeweatherbench.evaluate._run_evaluation") + def test_run_serial_evaluation( self, - mock_run_case_operators, + mock_run_evaluation, sample_cases_dict, sample_evaluation_object, sample_case_operator, @@ -347,7 +347,7 @@ def test_run_serial( with mock.patch.object( evaluate.ExtremeWeatherBench, "case_operators", new=[sample_case_operator] ): - # Mock _run_case_operators to return a list of DataFrames + # Mock _run_evaluation to return a list of DataFrames mock_result = [ pd.DataFrame( { @@ -357,7 +357,7 @@ def test_run_serial( } ) ] - mock_run_case_operators.return_value = mock_result + mock_run_evaluation.return_value = mock_result ewb = evaluate.ExtremeWeatherBench( case_metadata=sample_cases_dict, @@ -367,17 +367,17 @@ def test_run_serial( result = ewb.run(n_jobs=1) # Serial mode should not pass parallel_config - mock_run_case_operators.assert_called_once_with( + 
mock_run_evaluation.assert_called_once_with( [sample_case_operator], cache_dir=None, ) assert isinstance(result, pd.DataFrame) assert len(result) == 1 - @mock.patch("extremeweatherbench.evaluate._run_case_operators") - def test_run_parallel( + @mock.patch("extremeweatherbench.evaluate._run_evaluation") + def test_run_parallel_evaluation( self, - mock_run_case_operators, + mock_run_evaluation, sample_cases_dict, sample_evaluation_object, sample_case_operator, @@ -395,7 +395,7 @@ def test_run_parallel( } ) ] - mock_run_case_operators.return_value = mock_result + mock_run_evaluation.return_value = mock_result ewb = evaluate.ExtremeWeatherBench( case_metadata=sample_cases_dict, @@ -404,7 +404,7 @@ def test_run_parallel( result = ewb.run(n_jobs=2) - mock_run_case_operators.assert_called_once_with( + mock_run_evaluation.assert_called_once_with( [sample_case_operator], cache_dir=None, parallel_config={"backend": "threading", "n_jobs": 2}, @@ -412,10 +412,10 @@ def test_run_parallel( assert isinstance(result, pd.DataFrame) assert len(result) == 1 - @mock.patch("extremeweatherbench.evaluate._run_case_operators") + @mock.patch("extremeweatherbench.evaluate._run_evaluation") def test_run_with_kwargs( self, - mock_run_case_operators, + mock_run_evaluation, sample_cases_dict, sample_evaluation_object, sample_case_operator, @@ -425,7 +425,7 @@ def test_run_with_kwargs( evaluate.ExtremeWeatherBench, "case_operators", new=[sample_case_operator] ): mock_result = [pd.DataFrame({"value": [1.0]})] - mock_run_case_operators.return_value = mock_result + mock_run_evaluation.return_value = mock_result ewb = evaluate.ExtremeWeatherBench( case_metadata=sample_cases_dict, @@ -435,20 +435,20 @@ def test_run_with_kwargs( result = ewb.run(n_jobs=1, threshold=0.5) # Check that kwargs were passed through - call_args = mock_run_case_operators.call_args + call_args = mock_run_evaluation.call_args assert call_args[1]["threshold"] == 0.5 assert isinstance(result, pd.DataFrame) - 
@mock.patch("extremeweatherbench.evaluate._run_case_operators") + @mock.patch("extremeweatherbench.evaluate._run_evaluation") def test_run_empty_results( self, - mock_run_case_operators, + mock_run_evaluation, sample_cases_dict, sample_evaluation_object, ): """Test the run method handles empty results.""" with mock.patch.object(evaluate.ExtremeWeatherBench, "case_operators", new=[]): - mock_run_case_operators.return_value = [] + mock_run_evaluation.return_value = [] ewb = evaluate.ExtremeWeatherBench( case_metadata=sample_cases_dict, @@ -547,107 +547,115 @@ def test_run_multiple_cases( class TestRunCaseOperators: - """Test the _run_case_operators function.""" + """Test the _run_evaluation function.""" - @mock.patch("extremeweatherbench.evaluate._run_serial") - def test_run_case_operators_serial(self, mock_run_serial, sample_case_operator): - """Test _run_case_operators routes to serial execution.""" + @mock.patch("extremeweatherbench.evaluate._run_serial_evaluation") + def test_run_evaluation_serial( + self, mock_run_serial_evaluation, sample_case_operator + ): + """Test _run_evaluation routes to serial execution.""" mock_results = [pd.DataFrame({"value": [1.0]})] - mock_run_serial.return_value = mock_results + mock_run_serial_evaluation.return_value = mock_results # Serial mode: don't pass parallel_config - result = evaluate._run_case_operators([sample_case_operator], cache_dir=None) + result = evaluate._run_evaluation([sample_case_operator], cache_dir=None) - mock_run_serial.assert_called_once_with([sample_case_operator], cache_dir=None) + mock_run_serial_evaluation.assert_called_once_with( + [sample_case_operator], cache_dir=None + ) assert result == mock_results - @mock.patch("extremeweatherbench.evaluate._run_parallel") - def test_run_case_operators_parallel(self, mock_run_parallel, sample_case_operator): - """Test _run_case_operators routes to parallel execution.""" + @mock.patch("extremeweatherbench.evaluate._run_parallel_evaluation") + def 
test_run_evaluation_parallel( + self, mock_run_parallel_evaluation, sample_case_operator + ): + """Test _run_evaluation routes to parallel execution.""" mock_results = [pd.DataFrame({"value": [1.0]})] - mock_run_parallel.return_value = mock_results + mock_run_parallel_evaluation.return_value = mock_results - result = evaluate._run_case_operators( + result = evaluate._run_evaluation( [sample_case_operator], cache_dir=None, parallel_config={"backend": "threading", "n_jobs": 4}, ) - mock_run_parallel.assert_called_once_with( + mock_run_parallel_evaluation.assert_called_once_with( [sample_case_operator], cache_dir=None, parallel_config={"backend": "threading", "n_jobs": 4}, ) assert result == mock_results - @mock.patch("extremeweatherbench.evaluate._run_serial") - def test_run_case_operators_with_kwargs( - self, mock_run_serial, sample_case_operator + @mock.patch("extremeweatherbench.evaluate._run_serial_evaluation") + def test_run_evaluation_with_kwargs( + self, mock_run_serial_evaluation, sample_case_operator ): - """Test _run_case_operators passes kwargs correctly.""" + """Test _run_evaluation passes kwargs correctly.""" mock_results = [pd.DataFrame({"value": [1.0]})] - mock_run_serial.return_value = mock_results + mock_run_serial_evaluation.return_value = mock_results # Serial mode: don't pass parallel_config - result = evaluate._run_case_operators( + result = evaluate._run_evaluation( [sample_case_operator], cache_dir=None, threshold=0.5, ) - call_args = mock_run_serial.call_args + call_args = mock_run_serial_evaluation.call_args assert call_args[0][0] == [sample_case_operator] assert call_args[1]["cache_dir"] is None assert call_args[1]["threshold"] == 0.5 assert isinstance(result, list) - @mock.patch("extremeweatherbench.evaluate._run_parallel") - def test_run_case_operators_parallel_with_kwargs( - self, mock_run_parallel, sample_case_operator + @mock.patch("extremeweatherbench.evaluate._run_parallel_evaluation") + def test_run_evaluation_parallel_with_kwargs( + 
self, mock_run_parallel_evaluation, sample_case_operator ): - """Test _run_case_operators passes kwargs to parallel execution.""" + """Test _run_evaluation passes kwargs to parallel execution.""" mock_results = [pd.DataFrame({"value": [1.0]})] - mock_run_parallel.return_value = mock_results + mock_run_parallel_evaluation.return_value = mock_results - result = evaluate._run_case_operators( + result = evaluate._run_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, custom_param="test_value", ) - call_args = mock_run_parallel.call_args + call_args = mock_run_parallel_evaluation.call_args assert call_args[0][0] == [sample_case_operator] assert call_args[1]["parallel_config"] == {"backend": "threading", "n_jobs": 2} assert call_args[1]["custom_param"] == "test_value" assert isinstance(result, list) - def test_run_case_operators_empty_list(self): - """Test _run_case_operators with empty case operator list.""" - with mock.patch("extremeweatherbench.evaluate._run_serial") as mock_serial: + def test_run_evaluation_empty_list(self): + """Test _run_evaluation with empty case operator list.""" + with mock.patch( + "extremeweatherbench.evaluate._run_serial_evaluation" + ) as mock_serial: mock_serial.return_value = [] # Serial mode: don't pass parallel_config - result = evaluate._run_case_operators([], cache_dir=None) + result = evaluate._run_evaluation([], cache_dir=None) mock_serial.assert_called_once_with([], cache_dir=None) assert result == [] class TestRunSerial: - """Test the _run_serial function.""" + """Test the _run_serial_evaluation function.""" @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_run_serial_basic( + def test_run_serial_evaluation_basic( self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test basic _run_serial functionality.""" + """Test basic _run_serial_evaluation functionality.""" # Setup mocks mock_tqdm.return_value = 
[sample_case_operator] # tqdm returns iterable mock_result = pd.DataFrame({"value": [1.0], "case_id_number": [1]}) mock_compute_case_operator.return_value = mock_result - result = evaluate._run_serial([sample_case_operator]) + result = evaluate._run_serial_evaluation([sample_case_operator]) mock_compute_case_operator.assert_called_once_with(sample_case_operator, None) assert len(result) == 1 @@ -655,8 +663,10 @@ def test_run_serial_basic( @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_run_serial_multiple_cases(self, mock_tqdm, mock_compute_case_operator): - """Test _run_serial with multiple case operators.""" + def test_run_serial_evaluation_multiple_cases( + self, mock_tqdm, mock_compute_case_operator + ): + """Test _run_serial_evaluation with multiple case operators.""" case_op_1 = mock.Mock() case_op_2 = mock.Mock() case_operators = [case_op_1, case_op_2] @@ -667,7 +677,7 @@ def test_run_serial_multiple_cases(self, mock_tqdm, mock_compute_case_operator): pd.DataFrame({"value": [2.0], "case_id_number": [2]}), ] - result = evaluate._run_serial(case_operators) + result = evaluate._run_serial_evaluation(case_operators) assert mock_compute_case_operator.call_count == 2 assert len(result) == 2 @@ -676,15 +686,15 @@ def test_run_serial_multiple_cases(self, mock_tqdm, mock_compute_case_operator): @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_run_serial_with_kwargs( + def test_run_serial_evaluation_with_kwargs( self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test _run_serial passes kwargs to compute_case_operator.""" + """Test _run_serial_evaluation passes kwargs to compute_case_operator.""" mock_tqdm.return_value = [sample_case_operator] mock_result = pd.DataFrame({"value": [1.0]}) mock_compute_case_operator.return_value = mock_result - result = evaluate._run_serial( + result = evaluate._run_serial_evaluation( 
[sample_case_operator], threshold=0.7, custom_param="test" ) @@ -694,22 +704,22 @@ def test_run_serial_with_kwargs( assert call_args[1]["custom_param"] == "test" assert isinstance(result, list) - def test_run_serial_empty_list(self): - """Test _run_serial with empty case operator list.""" - result = evaluate._run_serial([]) + def test_run_serial_evaluation_empty_list(self): + """Test _run_serial_evaluation with empty case operator list.""" + result = evaluate._run_serial_evaluation([]) assert result == [] class TestRunParallel: - """Test the _run_parallel function.""" + """Test the _run_parallel_evaluation function.""" @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_basic( + def test_run_parallel_evaluation_basic( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test basic _run_parallel functionality.""" + """Test basic _run_parallel_evaluation functionality.""" # Setup mocks mock_tqdm.return_value = [sample_case_operator] mock_delayed_func = mock.Mock() @@ -720,7 +730,7 @@ def test_run_parallel_basic( mock_result = [pd.DataFrame({"value": [1.0], "case_id_number": [1]})] mock_parallel_instance.return_value = mock_result - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, ) @@ -736,10 +746,10 @@ def test_run_parallel_basic( @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_with_none_n_jobs( + def test_run_parallel_evaluation_with_none_n_jobs( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test _run_parallel with n_jobs=None (should use all CPUs).""" + """Test _run_parallel_evaluation with n_jobs=None (should use all CPUs).""" mock_tqdm.return_value = [sample_case_operator] mock_delayed_func = mock.Mock() 
mock_delayed.return_value = mock_delayed_func @@ -750,7 +760,7 @@ def test_run_parallel_with_none_n_jobs( mock_parallel_instance.return_value = mock_result with mock.patch("extremeweatherbench.evaluate.logger.warning") as mock_warning: - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": None}, ) @@ -766,7 +776,7 @@ def test_run_parallel_with_none_n_jobs( @mock.patch("joblib.parallel_config") @mock.patch("extremeweatherbench.utils.ParallelTqdm") - def test_run_parallel_n_jobs_in_config( + def test_run_parallel_evaluation_n_jobs_in_config( self, mock_parallel_class, mock_parallel_config ): """Test that n_jobs is passed through parallel_config, not directly.""" @@ -783,7 +793,7 @@ def test_run_parallel_n_jobs_in_config( ) mock_parallel_config.return_value.__exit__ = mock.Mock(return_value=False) - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 4}, ) @@ -801,10 +811,10 @@ def test_run_parallel_n_jobs_in_config( @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_multiple_cases( + def test_run_parallel_evaluation_multiple_cases( self, mock_tqdm, mock_delayed, mock_parallel_class ): - """Test _run_parallel with multiple case operators.""" + """Test _run_parallel_evaluation with multiple case operators.""" case_op_1 = mock.Mock() case_op_2 = mock.Mock() case_operators = [case_op_1, case_op_2] @@ -821,7 +831,7 @@ def test_run_parallel_multiple_cases( ] mock_parallel_instance.return_value = mock_result - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( case_operators, parallel_config={"backend": "threading", "n_jobs": 4} ) @@ -832,10 +842,10 @@ def test_run_parallel_multiple_cases( @mock.patch("extremeweatherbench.utils.ParallelTqdm") 
@mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_with_kwargs( + def test_run_parallel_evaluation_with_kwargs( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test _run_parallel passes kwargs correctly.""" + """Test _run_parallel_evaluation passes kwargs correctly.""" mock_tqdm.return_value = [sample_case_operator] mock_delayed_func = mock.Mock() mock_delayed.return_value = mock_delayed_func @@ -845,7 +855,7 @@ def test_run_parallel_with_kwargs( mock_result = [pd.DataFrame({"value": [1.0]})] mock_parallel_instance.return_value = mock_result - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, threshold=0.8, @@ -862,8 +872,8 @@ def test_run_parallel_with_kwargs( assert len(delayed_calls) == 1 assert isinstance(result, list) - def test_run_parallel_empty_list(self): - """Test _run_parallel with empty case operator list.""" + def test_run_parallel_evaluation_empty_list(self): + """Test _run_parallel_evaluation with empty case operator list.""" with mock.patch( "extremeweatherbench.utils.ParallelTqdm" ) as mock_parallel_class: @@ -873,7 +883,7 @@ def test_run_parallel_empty_list(self): mock_parallel_class.return_value = mock_parallel_instance mock_parallel_instance.return_value = [] - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [], parallel_config={"backend": "threading", "n_jobs": 2} ) @@ -884,10 +894,10 @@ def test_run_parallel_empty_list(self): ) @mock.patch("dask.distributed.Client") @mock.patch("dask.distributed.LocalCluster") - def test_run_parallel_dask_backend_auto_client( + def test_run_parallel_evaluation_dask_backend_auto_client( self, mock_local_cluster, mock_client_class, sample_case_operator ): - """Test _run_parallel with dask backend automatically creates client.""" + """Test _run_parallel_evaluation with dask backend automatically creates 
client.""" # Mock Client.current() to raise ValueError (no existing client) mock_client_class.current.side_effect = ValueError("No client found") @@ -906,7 +916,7 @@ def test_run_parallel_dask_backend_auto_client( mock_parallel_instance.return_value = [pd.DataFrame({"test": [1]})] with mock.patch("joblib.parallel_config"): - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "dask", "n_jobs": 2}, ) @@ -920,10 +930,10 @@ def test_run_parallel_dask_backend_auto_client( not HAS_DASK_DISTRIBUTED, reason="dask.distributed not installed" ) @mock.patch("dask.distributed.Client") - def test_run_parallel_dask_backend_existing_client( + def test_run_parallel_evaluation_dask_backend_existing_client( self, mock_client_class, sample_case_operator ): - """Test _run_parallel with dask backend uses existing client.""" + """Test _run_parallel_evaluation with dask backend uses existing client.""" # Mock existing client mock_existing_client = mock.Mock() mock_client_class.current.return_value = mock_existing_client @@ -935,7 +945,7 @@ def test_run_parallel_dask_backend_existing_client( mock_parallel_instance.return_value = [pd.DataFrame({"test": [1]})] with mock.patch("joblib.parallel_config"): - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "dask", "n_jobs": 2}, ) @@ -1648,49 +1658,51 @@ def test_evaluate_metric_computation_failure( case_operator=sample_case_operator, ) - @mock.patch("extremeweatherbench.evaluate._run_serial") - def test_run_case_operators_serial_exception( - self, mock_run_serial, sample_case_operator + @mock.patch("extremeweatherbench.evaluate._run_serial_evaluation") + def test_run_evaluation_serial_exception( + self, mock_run_serial_evaluation, sample_case_operator ): - """Test _run_case_operators handles exceptions in serial execution.""" - mock_run_serial.side_effect = Exception("Serial execution 
failed") + """Test _run_evaluation handles exceptions in serial execution.""" + mock_run_serial_evaluation.side_effect = Exception("Serial execution failed") with pytest.raises(Exception, match="Serial execution failed"): # Serial mode: don't pass parallel_config - evaluate._run_case_operators([sample_case_operator], None) + evaluate._run_evaluation([sample_case_operator], None) - @mock.patch("extremeweatherbench.evaluate._run_parallel") - def test_run_case_operators_parallel_exception( - self, mock_run_parallel, sample_case_operator + @mock.patch("extremeweatherbench.evaluate._run_parallel_evaluation") + def test_run_evaluation_parallel_exception( + self, mock_run_parallel_evaluation, sample_case_operator ): - """Test _run_case_operators handles exceptions in parallel execution.""" - mock_run_parallel.side_effect = Exception("Parallel execution failed") + """Test _run_evaluation handles exceptions in parallel execution.""" + mock_run_parallel_evaluation.side_effect = Exception( + "Parallel execution failed" + ) with pytest.raises(Exception, match="Parallel execution failed"): - evaluate._run_case_operators( + evaluate._run_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, ) @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_run_serial_case_operator_exception( + def test_run_serial_evaluation_case_operator_exception( self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test _run_serial handles exceptions from individual case operators.""" + """Test _run_serial_evaluation handles exceptions from individual case operators.""" mock_tqdm.return_value = [sample_case_operator] mock_compute_case_operator.side_effect = Exception("Case operator failed") with pytest.raises(Exception, match="Case operator failed"): - evaluate._run_serial([sample_case_operator]) + evaluate._run_serial_evaluation([sample_case_operator]) 
@mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_joblib_exception( + def test_run_parallel_evaluation_joblib_exception( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test _run_parallel handles joblib Parallel exceptions.""" + """Test _run_parallel_evaluation handles joblib Parallel exceptions.""" mock_tqdm.return_value = [sample_case_operator] mock_delayed_func = mock.Mock() mock_delayed.return_value = mock_delayed_func @@ -1700,7 +1712,7 @@ def test_run_parallel_joblib_exception( mock_parallel_instance.side_effect = Exception("Joblib parallel failed") with pytest.raises(Exception, match="Joblib parallel failed"): - evaluate._run_parallel( + evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, ) @@ -1708,10 +1720,10 @@ def test_run_parallel_joblib_exception( @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_delayed_function_exception( + def test_run_parallel_evaluation_delayed_function_exception( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test _run_parallel handles exceptions in delayed functions.""" + """Test _run_parallel_evaluation handles exceptions in delayed functions.""" mock_tqdm.return_value = [sample_case_operator] # Mock delayed to raise an exception @@ -1727,17 +1739,17 @@ def consume_generator(generator): mock_parallel_instance.side_effect = consume_generator with pytest.raises(Exception, match="Delayed function creation failed"): - evaluate._run_parallel( + evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, ) - @mock.patch("extremeweatherbench.evaluate._run_case_operators") + @mock.patch("extremeweatherbench.evaluate._run_evaluation") def test_run_method_exception_propagation( - self, 
mock_run_case_operators, sample_cases_dict, sample_evaluation_object + self, mock_run_evaluation, sample_cases_dict, sample_evaluation_object ): """Test that ExtremeWeatherBench.run() propagates exceptions correctly.""" - mock_run_case_operators.side_effect = Exception("Execution failed") + mock_run_evaluation.side_effect = Exception("Execution failed") ewb = evaluate.ExtremeWeatherBench( case_metadata=sample_cases_dict, @@ -1749,8 +1761,10 @@ def test_run_method_exception_propagation( @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_run_serial_partial_failure(self, mock_tqdm, mock_compute_case_operator): - """Test _run_serial behavior when some case operators fail.""" + def test_run_serial_evaluation_partial_failure( + self, mock_tqdm, mock_compute_case_operator + ): + """Test _run_serial_evaluation behavior when some case operators fail.""" case_op_1 = mock.Mock() case_op_2 = mock.Mock() case_op_3 = mock.Mock() @@ -1767,7 +1781,7 @@ def test_run_serial_partial_failure(self, mock_tqdm, mock_compute_case_operator) # Should fail on the second case operator with pytest.raises(Exception, match="Case operator 2 failed"): - evaluate._run_serial(case_operators) + evaluate._run_serial_evaluation(case_operators) # Should have tried only the first two assert mock_compute_case_operator.call_count == 2 @@ -1775,10 +1789,10 @@ def test_run_serial_partial_failure(self, mock_tqdm, mock_compute_case_operator) @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_invalid_n_jobs( + def test_run_parallel_evaluation_invalid_n_jobs( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test _run_parallel with invalid n_jobs parameter.""" + """Test _run_parallel_evaluation with invalid n_jobs parameter.""" mock_tqdm.return_value = [sample_case_operator] mock_delayed_func = mock.Mock() mock_delayed.return_value = 
mock_delayed_func @@ -1787,7 +1801,7 @@ def test_run_parallel_invalid_n_jobs( mock_parallel_class.side_effect = ValueError("Invalid n_jobs parameter") with pytest.raises(ValueError, match="Invalid n_jobs parameter"): - evaluate._run_parallel( + evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": -5}, ) @@ -2054,13 +2068,13 @@ def test_execution_method_performance_comparison(self, mock_compute_case_operato for i in range(10) ] - # Test serial execution timing - call _run_serial directly + # Test serial execution timing - call _run_serial_evaluation directly mock_compute_case_operator.side_effect = mock_results start_time = time.time() - serial_result = evaluate._run_serial(case_operators) + serial_result = evaluate._run_serial_evaluation(case_operators) serial_time = time.time() - start_time - # Test parallel execution timing - call _run_parallel directly with mocked + # Test parallel execution timing - call _run_parallel_evaluation directly with mocked # Parallel serial_call_count = mock_compute_case_operator.call_count mock_compute_case_operator.side_effect = mock_results @@ -2073,7 +2087,7 @@ def test_execution_method_performance_comparison(self, mock_compute_case_operato mock_parallel_instance.return_value = mock_results start_time = time.time() - parallel_result = evaluate._run_parallel( + parallel_result = evaluate._run_parallel_evaluation( case_operators, parallel_config={"backend": "threading", "n_jobs": 2} ) parallel_time = time.time() - start_time @@ -2115,7 +2129,7 @@ def test_mixed_execution_parameters(self, mock_compute_case_operator): mock_compute_case_operator.side_effect = mock_results if config["method"] == "serial": - result = evaluate._run_serial(*config["args"]) + result = evaluate._run_serial_evaluation(*config["args"]) # All configurations should produce valid results assert isinstance(result, list) assert len(result) == 2 @@ -2138,7 +2152,9 @@ def test_mixed_execution_parameters(self, 
mock_compute_case_operator): "n_jobs": n_jobs, } - result = evaluate._run_parallel(*config["args"], **kwargs) + result = evaluate._run_parallel_evaluation( + *config["args"], **kwargs + ) # All configurations should produce valid results assert isinstance(result, list) @@ -2164,7 +2180,7 @@ def mock_compute_with_kwargs(case_op, cache_dir, **kwargs): side_effect=mock_compute_with_kwargs, ): # Test serial kwargs propagation - result = evaluate._run_serial( + result = evaluate._run_serial_evaluation( [case_operator], custom_param="serial_test", threshold=0.9 ) @@ -2188,7 +2204,7 @@ def mock_compute_with_kwargs(case_op, cache_dir, **kwargs): # Reset captured kwargs mock_compute_with_kwargs.captured_kwargs = {} - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, custom_param="parallel_test", @@ -2201,20 +2217,20 @@ def mock_compute_with_kwargs(case_op, cache_dir, **kwargs): def test_empty_case_operators_all_methods(self): """Test that all execution methods handle empty case operator lists.""" - # Test _run_case_operators - result = evaluate._run_case_operators([], parallel_config={"n_jobs": 1}) + # Test _run_evaluation + result = evaluate._run_evaluation([], parallel_config={"n_jobs": 1}) assert result == [] - result = evaluate._run_case_operators( + result = evaluate._run_evaluation( [], parallel_config={"backend": "threading", "n_jobs": 2} ) assert result == [] - # Test _run_serial - result = evaluate._run_serial([]) + # Test _run_serial_evaluation + result = evaluate._run_serial_evaluation([]) assert result == [] - # Test _run_parallel + # Test _run_parallel_evaluation with mock.patch( "extremeweatherbench.utils.ParallelTqdm" ) as mock_parallel_class: @@ -2222,7 +2238,7 @@ def test_empty_case_operators_all_methods(self): mock_parallel_class.return_value = mock_parallel_instance mock_parallel_instance.return_value = [] - result = evaluate._run_parallel( + result = 
evaluate._run_parallel_evaluation( [], parallel_config={"backend": "threading", "n_jobs": 2} ) assert result == [] @@ -2244,7 +2260,7 @@ def test_large_case_operator_list_handling(self, mock_compute_case_operator): # Test serial execution mock_compute_case_operator.side_effect = mock_results - serial_results = evaluate._run_serial(case_operators) + serial_results = evaluate._run_serial_evaluation(case_operators) assert len(serial_results) == num_cases assert mock_compute_case_operator.call_count == num_cases @@ -2260,7 +2276,7 @@ def test_large_case_operator_list_handling(self, mock_compute_case_operator): mock_parallel_class.return_value = mock_parallel_instance mock_parallel_instance.return_value = mock_results - parallel_results = evaluate._run_parallel( + parallel_results = evaluate._run_parallel_evaluation( case_operators, parallel_config={"backend": "threading", "n_jobs": 4} ) From 7451ac3d51838d97a5acb1d5828dcea7d01f90d6 Mon Sep 17 00:00:00 2001 From: taylor Date: Tue, 13 Jan 2026 01:51:34 +0000 Subject: [PATCH 02/13] update naming --- docs/examples/applied_ar.py | 2 +- docs/examples/applied_freeze.py | 4 +- docs/examples/applied_heatwave.py | 2 +- docs/examples/applied_severe.py | 3 +- docs/examples/applied_tc.py | 2 +- docs/parallelism.md | 2 +- docs/usage.md | 2 +- scripts/brightband_evaluation.py | 2 +- src/extremeweatherbench/evaluate.py | 68 ++++++++++--------------- src/extremeweatherbench/evaluate_cli.py | 2 +- tests/test_evaluate.py | 24 ++++----- tests/test_integration.py | 4 +- 12 files changed, 50 insertions(+), 67 deletions(-) diff --git a/docs/examples/applied_ar.py b/docs/examples/applied_ar.py index d07a3edf..08c4ad59 100644 --- a/docs/examples/applied_ar.py +++ b/docs/examples/applied_ar.py @@ -133,7 +133,7 @@ def _preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: ) # Run the workflow using 3 jobs - outputs = ar_ewb.run(parallel_config={"backend": "loky", "n_jobs": 3}) + outputs = 
ar_ewb.run_evaluation(parallel_config={"backend": "loky", "n_jobs": 3}) # Save the evaluation outputs to a csv file outputs.to_csv("ar_signal_outputs.csv") diff --git a/docs/examples/applied_freeze.py b/docs/examples/applied_freeze.py index 70cf00d9..a430bd39 100644 --- a/docs/examples/applied_freeze.py +++ b/docs/examples/applied_freeze.py @@ -1,7 +1,7 @@ import logging import operator -from extremeweatherbench import cases, evaluate, inputs, metrics, defaults +from extremeweatherbench import cases, defaults, evaluate, inputs, metrics # Set the logger level to INFO logger = logging.getLogger("extremeweatherbench") @@ -65,7 +65,7 @@ ) # Run the workflow - outputs = ewb.run(parallel_config={"backend": "loky", "n_jobs": 1}) + outputs = ewb.run_evaluation(parallel_config={"backend": "loky", "n_jobs": 1}) # Print the outputs; can be saved if desired outputs.to_csv("freeze_outputs.csv") diff --git a/docs/examples/applied_heatwave.py b/docs/examples/applied_heatwave.py index e74d7cbe..9f7a7dd1 100644 --- a/docs/examples/applied_heatwave.py +++ b/docs/examples/applied_heatwave.py @@ -65,5 +65,5 @@ ) # Run the workflow - outputs = ewb.run(parallel_config={"backend": "loky", "n_jobs": 2}) + outputs = ewb.run_evaluation(parallel_config={"backend": "loky", "n_jobs": 2}) outputs.to_csv("applied_heatwave_outputs.csv") diff --git a/docs/examples/applied_severe.py b/docs/examples/applied_severe.py index 72cb52fd..40c00770 100644 --- a/docs/examples/applied_severe.py +++ b/docs/examples/applied_severe.py @@ -1,6 +1,5 @@ import logging - from extremeweatherbench import cases, derived, evaluate, inputs, metrics # Set the logger level to INFO @@ -84,7 +83,7 @@ logger.info("Starting EWB run") # Run the workflow with parllel_config backend set to dask - outputs = ewb.run(parallel_config={"backend": "loky", "n_jobs": 3}) + outputs = ewb.run_evaluation(parallel_config={"backend": "loky", "n_jobs": 3}) # Save the results to a CSV file 
outputs.to_csv("applied_severe_convection_results.csv") diff --git a/docs/examples/applied_tc.py b/docs/examples/applied_tc.py index 88c88d8a..955c45a5 100644 --- a/docs/examples/applied_tc.py +++ b/docs/examples/applied_tc.py @@ -153,7 +153,7 @@ def _preprocess_hres_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: ) logger.info("Starting EWB run") # Run the workflow with parallel_config backend set to dask - outputs = ewb.run( + outputs = ewb.run_evaluation( parallel_config={"backend": "loky", "n_jobs": 3}, ) outputs.to_csv("tc_metric_test_results.csv") diff --git a/docs/parallelism.md b/docs/parallelism.md index 60ad5603..fc90328f 100644 --- a/docs/parallelism.md +++ b/docs/parallelism.md @@ -35,7 +35,7 @@ ewb = evaluate.ExtremeWeatherBench( # The larger the machine, the larger n_jobs can be (a bit of an oversimplification) parallel_config = {"backend":"loky","n_jobs":len(evaluation_objects)} -outputs = ewb.run(parallel_config=parallel_config) +outputs = ewb.run_evaluation(parallel_config=parallel_config) ``` The _safest_ approach is to run EWB in serial, with `n_jobs` set to 1. `Dask` will still be invoked during each `CaseOperator` when the case executes and computes the directed acyclic graph, only one at a time. That said, for evaluations with more cases this approach would likely be too time-consuming. 
\ No newline at end of file diff --git a/docs/usage.md b/docs/usage.md index a6c83e82..b1891ebb 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -23,7 +23,7 @@ cases = cases.load_ewb_events_yaml_into_case_collection() ewb = ExtremeWeatherBench(cases=cases, evaluation_objects=eval_objects) -outputs = ewb.run() +outputs = ewb.run_evaluation() outputs.to_csv('your_outputs.csv') ``` diff --git a/scripts/brightband_evaluation.py b/scripts/brightband_evaluation.py index 0ce87f49..0a4be555 100644 --- a/scripts/brightband_evaluation.py +++ b/scripts/brightband_evaluation.py @@ -39,5 +39,5 @@ def configure_logger(level=logging.INFO): # Set up parallel configuration parallel_config = {"backend": "loky", "n_jobs": n_processes} - results = ewb.run(parallel_config=parallel_config) + results = ewb.run_evaluation(parallel_config=parallel_config) results.to_csv("brightband_evaluation_results.csv", index=False) diff --git a/src/extremeweatherbench/evaluate.py b/src/extremeweatherbench/evaluate.py index 8bd98ef4..a318762c 100644 --- a/src/extremeweatherbench/evaluate.py +++ b/src/extremeweatherbench/evaluate.py @@ -113,17 +113,20 @@ def run_evaluation( parallel_config: Optional dictionary of joblib parallel configuration. If provided, this takes precedence over n_jobs. If not provided and n_jobs is specified, a default config with threading backend is used. - + **kwargs: Additional arguments to pass to compute_case_operator. Returns: A concatenated dataframe of the evaluation results. 
""" logger.info("Running ExtremeWeatherBench evaluations...") # Check for serial or parallel configuration - run_config_kwargs = _parallel_config_check(n_jobs, parallel_config, **kwargs) + parallel_config = _parallel_serial_config_check(n_jobs, parallel_config) run_results = _run_evaluation( - self.case_operators, cache_dir=self.cache_dir, **run_config_kwargs + self.case_operators, + cache_dir=self.cache_dir, + parallel_config=parallel_config, + **kwargs, ) # If there are results, concatenate them and return, else return an empty @@ -134,28 +137,21 @@ def run_evaluation( return pd.DataFrame(columns=OUTPUT_COLUMNS) -def _parallel_config_check( +def _parallel_serial_config_check( n_jobs: Optional[int] = None, parallel_config: Optional[dict] = None, - **kwargs, -) -> dict: - """Build the run configuration. - - Builds the run configuration for EWB workflows depending on the configuration - provided via arguments. +) -> Optional[dict]: + """Check if running in serial or parallel mode. Args: n_jobs: The number of jobs to run in parallel. If None, defaults to the joblib backend default value. If 1, the workflow will run serially. - Ignored if parallel_config is provided. - parallel_config: Optional dictionary of joblib parallel configuration. - If provided, this takes precedence over n_jobs. If not provided and - n_jobs is specified, a default config with loky backend is used. - cache_dir: Optional directory for caching (serial mode only). - **kwargs: Additional arguments, may include 'parallel_config' dict. - + parallel_config: Optional dictionary of joblib parallel configuration. If + provided, this takes precedence over n_jobs. If not provided and n_jobs is + specified, a default config with loky backend is used. Returns: - Maybe updated kwargs if n_jobs was provided instead of parallel_config. + None if running in serial mode, otherwise a dictionary of joblib parallel + configuration. 
""" # Determine if running in serial or parallel mode # Serial: n_jobs=1 or (parallel_config with n_jobs=1) @@ -175,15 +171,15 @@ def _parallel_config_check( n_jobs, ) parallel_config = {"backend": "loky", "n_jobs": n_jobs} - kwargs["parallel_config"] = parallel_config # Return the maybe updated kwargs - return kwargs + return parallel_config def _run_evaluation( case_operators: list["cases.CaseOperator"], cache_dir: Optional[pathlib.Path] = None, + parallel_config: Optional[dict] = None, **kwargs, ) -> list[pd.DataFrame]: """Run the case operators in parallel or serial. @@ -196,35 +192,23 @@ def _run_evaluation( Returns: List of result DataFrames. """ - with logging_redirect_tqdm(): - # Check if parallel_config is provided - parallel_config = kwargs.get("parallel_config", None) - - # Run in parallel if parallel_config exists and n_jobs != 1 - if parallel_config is not None: + if parallel_config is not None: + with logging_redirect_tqdm(): logger.info("Running case operators in parallel...") - return _run_parallel_evaluation( + run_results = _run_parallel_evaluation( case_operators, cache_dir=cache_dir, parallel_config=parallel_config, **kwargs, ) - else: - logger.info("Running case operators in serial...") - return _run_serial_evaluation(case_operators, cache_dir=cache_dir, **kwargs) - - -def _run_serial_evaluation( - case_operators: list["cases.CaseOperator"], - cache_dir: Optional[pathlib.Path] = None, - **kwargs, -) -> list[pd.DataFrame]: - """Run the case operators in serial.""" - run_results = [] + else: + logger.info("Running case operators in serial...") + run_results = [] + for case_operator in tqdm(case_operators): + run_results.append( + compute_case_operator(case_operator, cache_dir, **kwargs) + ) - # Loop over the case operators - for case_operator in tqdm(case_operators): - run_results.append(compute_case_operator(case_operator, cache_dir, **kwargs)) return run_results diff --git a/src/extremeweatherbench/evaluate_cli.py 
b/src/extremeweatherbench/evaluate_cli.py index c6bc0abe..eafd263d 100644 --- a/src/extremeweatherbench/evaluate_cli.py +++ b/src/extremeweatherbench/evaluate_cli.py @@ -153,7 +153,7 @@ def cli_runner( # Run evaluation click.echo("Running evaluation...") - results = ewb.run( + results = ewb.run_evaluation( n_jobs=n_jobs, parallel_config=parallel_config, ) diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 656b29d0..61d8843b 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -364,7 +364,7 @@ def test_run_serial_evaluation( evaluation_objects=[sample_evaluation_object], ) - result = ewb.run(n_jobs=1) + result = ewb.run_evaluation(n_jobs=1) # Serial mode should not pass parallel_config mock_run_evaluation.assert_called_once_with( @@ -402,7 +402,7 @@ def test_run_parallel_evaluation( evaluation_objects=[sample_evaluation_object], ) - result = ewb.run(n_jobs=2) + result = ewb.run_evaluation(n_jobs=2) mock_run_evaluation.assert_called_once_with( [sample_case_operator], @@ -432,7 +432,7 @@ def test_run_with_kwargs( evaluation_objects=[sample_evaluation_object], ) - result = ewb.run(n_jobs=1, threshold=0.5) + result = ewb.run_evaluation(n_jobs=1, threshold=0.5) # Check that kwargs were passed through call_args = mock_run_evaluation.call_args @@ -455,7 +455,7 @@ def test_run_empty_results( evaluation_objects=[sample_evaluation_object], ) - result = ewb.run() + result = ewb.run_evaluation() assert isinstance(result, pd.DataFrame) assert len(result) == 0 @@ -505,7 +505,7 @@ def mock_compute_with_caching(case_operator, cache_dir_arg, **kwargs): cache_dir=cache_dir, ) - ewb.run(n_jobs=1) + ewb.run_evaluation(n_jobs=1) # Check that cache directory was created assert cache_dir.exists() @@ -539,7 +539,7 @@ def test_run_multiple_cases( evaluation_objects=[sample_evaluation_object], ) - result = ewb.run() + result = ewb.run_evaluation() assert mock_compute_case_operator.call_count == 2 assert len(result) == 2 @@ -1611,7 +1611,7 @@ def 
test_extremeweatherbench_empty_cases(self, sample_evaluation_object): with mock.patch("extremeweatherbench.cases.build_case_operators") as mock_build: mock_build.return_value = [] - result = ewb.run() + result = ewb.run_evaluation() # Should return empty DataFrame when no cases assert isinstance(result, pd.DataFrame) @@ -1757,7 +1757,7 @@ def test_run_method_exception_propagation( ) with pytest.raises(Exception, match="Execution failed"): - ewb.run() + ewb.run_evaluation() @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") @@ -1886,7 +1886,7 @@ def test_end_to_end_workflow( evaluation_objects=[sample_evaluation_object], ) - result = ewb.run() + result = ewb.run_evaluation() # Verify the result structure assert isinstance(result, pd.DataFrame) @@ -1990,7 +1990,7 @@ def test_multiple_variables_and_metrics( evaluation_objects=[eval_obj], ) - result = ewb.run() + result = ewb.run_evaluation() # Should have results for each metric combination assert len(result) >= 2 # At least 2 metrics * 1 case @@ -2033,12 +2033,12 @@ def test_serial_vs_parallel_results_consistency( # Test serial execution mock_compute_case_operator.side_effect = [result_1, result_2] - serial_result = ewb.run(n_jobs=1) + serial_result = ewb.run_evaluation(n_jobs=1) # Reset mock and test parallel execution mock_compute_case_operator.reset_mock() mock_compute_case_operator.side_effect = [result_1, result_2] - parallel_result = ewb.run(n_jobs=2) + parallel_result = ewb.run_evaluation(n_jobs=2) # Both should produce valid DataFrames with same structure assert isinstance(serial_result, pd.DataFrame) diff --git a/tests/test_integration.py b/tests/test_integration.py index 1d269d5a..14b0449a 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -602,7 +602,7 @@ def test_full_workflow_single_variable( evaluation_objects=[evaluation_obj], ) - result = ewb.run() + result = ewb.run_evaluation() # Verify results assert isinstance(result, 
pd.DataFrame) @@ -683,7 +683,7 @@ def test_full_workflow_multiple_variables( evaluation_objects=[evaluation_obj], ) - result = ewb.run() + result = ewb.run_evaluation() # Verify results assert isinstance(result, pd.DataFrame) From 2c2262a1e7d8845159fc65b57bd0209f82984b07 Mon Sep 17 00:00:00 2001 From: taylor Date: Tue, 13 Jan 2026 01:59:12 +0000 Subject: [PATCH 03/13] add run method for backwards compatibility --- src/extremeweatherbench/evaluate.py | 43 +++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/extremeweatherbench/evaluate.py b/src/extremeweatherbench/evaluate.py index a318762c..34a12f7d 100644 --- a/src/extremeweatherbench/evaluate.py +++ b/src/extremeweatherbench/evaluate.py @@ -94,6 +94,49 @@ def case_operators(self) -> list["cases.CaseOperator"]: subset_collection = self.case_metadata return cases.build_case_operators(subset_collection, self.evaluation_objects) + def run( + self, + n_jobs: Optional[int] = None, + parallel_config: Optional[dict] = None, + **kwargs, + ) -> pd.DataFrame: + """Runs the ExtremeWeatherBench evaluation workflow. + + This method will run the evaluation workflow in the order of the case operators, + optionally caching the mid-flight outputs of the workflow if cache_dir was + provided for serial runs. + + Args: + n_jobs: The number of jobs to run in parallel. If None, defaults to the + joblib backend default value. If 1, the workflow will run serially. + Ignored if parallel_config is provided. + parallel_config: Optional dictionary of joblib parallel configuration. + If provided, this takes precedence over n_jobs. If not provided and + n_jobs is specified, a default config with threading backend is used. + **kwargs: Additional arguments to pass to compute_case_operator. + Returns: + A concatenated dataframe of the evaluation results. + """ + logger.warning("The run method is deprecated. 
Use run_evaluation instead.") + logger.info("Running ExtremeWeatherBench evaluations...") + + # Check for serial or parallel configuration + parallel_config = _parallel_serial_config_check(n_jobs, parallel_config) + + run_results = _run_evaluation( + self.case_operators, + cache_dir=self.cache_dir, + parallel_config=parallel_config, + **kwargs, + ) + + # If there are results, concatenate them and return, else return an empty + # DataFrame with the expected columns + if run_results: + return _safe_concat(run_results, ignore_index=True) + else: + return pd.DataFrame(columns=OUTPUT_COLUMNS) + def run_evaluation( self, n_jobs: Optional[int] = None, From b4693576aa33ac0be170a24ffde6958b089e1204 Mon Sep 17 00:00:00 2001 From: taylor Date: Tue, 13 Jan 2026 02:22:45 +0000 Subject: [PATCH 04/13] update tests --- tests/test_evaluate.py | 145 ++++++++++++++++++++++--------------- tests/test_evaluate_cli.py | 30 ++++---- 2 files changed, 101 insertions(+), 74 deletions(-) diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 61d8843b..ed8ca93b 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -366,10 +366,11 @@ def test_run_serial_evaluation( result = ewb.run_evaluation(n_jobs=1) - # Serial mode should not pass parallel_config + # Serial mode should pass parallel_config=None mock_run_evaluation.assert_called_once_with( [sample_case_operator], cache_dir=None, + parallel_config=None, ) assert isinstance(result, pd.DataFrame) assert len(result) == 1 @@ -407,7 +408,7 @@ def test_run_parallel_evaluation( mock_run_evaluation.assert_called_once_with( [sample_case_operator], cache_dir=None, - parallel_config={"backend": "threading", "n_jobs": 2}, + parallel_config={"backend": "loky", "n_jobs": 2}, ) assert isinstance(result, pd.DataFrame) assert len(result) == 1 @@ -549,21 +550,24 @@ def test_run_multiple_cases( class TestRunCaseOperators: """Test the _run_evaluation function.""" - @mock.patch("extremeweatherbench.evaluate._run_serial_evaluation") + 
@mock.patch("extremeweatherbench.evaluate.compute_case_operator") + @mock.patch("tqdm.auto.tqdm") def test_run_evaluation_serial( - self, mock_run_serial_evaluation, sample_case_operator + self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test _run_evaluation routes to serial execution.""" - mock_results = [pd.DataFrame({"value": [1.0]})] - mock_run_serial_evaluation.return_value = mock_results + """Test _run_evaluation executes serially when parallel_config=None.""" + mock_tqdm.return_value = [sample_case_operator] + mock_results = pd.DataFrame({"value": [1.0]}) + mock_compute_case_operator.return_value = mock_results # Serial mode: don't pass parallel_config result = evaluate._run_evaluation([sample_case_operator], cache_dir=None) - mock_run_serial_evaluation.assert_called_once_with( - [sample_case_operator], cache_dir=None + mock_compute_case_operator.assert_called_once_with( + sample_case_operator, None ) - assert result == mock_results + assert len(result) == 1 + assert result[0].equals(mock_results) @mock.patch("extremeweatherbench.evaluate._run_parallel_evaluation") def test_run_evaluation_parallel( @@ -586,13 +590,15 @@ def test_run_evaluation_parallel( ) assert result == mock_results - @mock.patch("extremeweatherbench.evaluate._run_serial_evaluation") + @mock.patch("extremeweatherbench.evaluate.compute_case_operator") + @mock.patch("tqdm.auto.tqdm") def test_run_evaluation_with_kwargs( - self, mock_run_serial_evaluation, sample_case_operator + self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test _run_evaluation passes kwargs correctly.""" - mock_results = [pd.DataFrame({"value": [1.0]})] - mock_run_serial_evaluation.return_value = mock_results + """Test _run_evaluation passes kwargs correctly in serial mode.""" + mock_tqdm.return_value = [sample_case_operator] + mock_results = pd.DataFrame({"value": [1.0]}) + mock_compute_case_operator.return_value = mock_results # Serial mode: don't pass parallel_config 
result = evaluate._run_evaluation( @@ -601,9 +607,9 @@ def test_run_evaluation_with_kwargs( threshold=0.5, ) - call_args = mock_run_serial_evaluation.call_args - assert call_args[0][0] == [sample_case_operator] - assert call_args[1]["cache_dir"] is None + call_args = mock_compute_case_operator.call_args + assert call_args[0][0] == sample_case_operator + assert call_args[0][1] is None # cache_dir assert call_args[1]["threshold"] == 0.5 assert isinstance(result, list) @@ -629,33 +635,28 @@ def test_run_evaluation_parallel_with_kwargs( def test_run_evaluation_empty_list(self): """Test _run_evaluation with empty case operator list.""" - with mock.patch( - "extremeweatherbench.evaluate._run_serial_evaluation" - ) as mock_serial: - mock_serial.return_value = [] - - # Serial mode: don't pass parallel_config - result = evaluate._run_evaluation([], cache_dir=None) - - mock_serial.assert_called_once_with([], cache_dir=None) - assert result == [] + # Serial mode: don't pass parallel_config + result = evaluate._run_evaluation([], cache_dir=None) + assert result == [] class TestRunSerial: - """Test the _run_serial_evaluation function.""" + """Test the serial execution path of _run_evaluation.""" @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") def test_run_serial_evaluation_basic( self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test basic _run_serial_evaluation functionality.""" + """Test basic serial execution functionality.""" # Setup mocks mock_tqdm.return_value = [sample_case_operator] # tqdm returns iterable mock_result = pd.DataFrame({"value": [1.0], "case_id_number": [1]}) mock_compute_case_operator.return_value = mock_result - result = evaluate._run_serial_evaluation([sample_case_operator]) + result = evaluate._run_evaluation( + [sample_case_operator], parallel_config=None + ) mock_compute_case_operator.assert_called_once_with(sample_case_operator, None) assert len(result) == 1 @@ -666,7 +667,7 @@ 
def test_run_serial_evaluation_basic( def test_run_serial_evaluation_multiple_cases( self, mock_tqdm, mock_compute_case_operator ): - """Test _run_serial_evaluation with multiple case operators.""" + """Test serial execution with multiple case operators.""" case_op_1 = mock.Mock() case_op_2 = mock.Mock() case_operators = [case_op_1, case_op_2] @@ -677,7 +678,7 @@ def test_run_serial_evaluation_multiple_cases( pd.DataFrame({"value": [2.0], "case_id_number": [2]}), ] - result = evaluate._run_serial_evaluation(case_operators) + result = evaluate._run_evaluation(case_operators, parallel_config=None) assert mock_compute_case_operator.call_count == 2 assert len(result) == 2 @@ -689,13 +690,16 @@ def test_run_serial_evaluation_multiple_cases( def test_run_serial_evaluation_with_kwargs( self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test _run_serial_evaluation passes kwargs to compute_case_operator.""" + """Test serial execution passes kwargs to compute_case_operator.""" mock_tqdm.return_value = [sample_case_operator] mock_result = pd.DataFrame({"value": [1.0]}) mock_compute_case_operator.return_value = mock_result - result = evaluate._run_serial_evaluation( - [sample_case_operator], threshold=0.7, custom_param="test" + result = evaluate._run_evaluation( + [sample_case_operator], + parallel_config=None, + threshold=0.7, + custom_param="test", ) call_args = mock_compute_case_operator.call_args @@ -705,8 +709,8 @@ def test_run_serial_evaluation_with_kwargs( assert isinstance(result, list) def test_run_serial_evaluation_empty_list(self): - """Test _run_serial_evaluation with empty case operator list.""" - result = evaluate._run_serial_evaluation([]) + """Test serial execution with empty case operator list.""" + result = evaluate._run_evaluation([], parallel_config=None) assert result == [] @@ -1658,16 +1662,18 @@ def test_evaluate_metric_computation_failure( case_operator=sample_case_operator, ) - 
@mock.patch("extremeweatherbench.evaluate._run_serial_evaluation") + @mock.patch("extremeweatherbench.evaluate.compute_case_operator") + @mock.patch("tqdm.auto.tqdm") def test_run_evaluation_serial_exception( - self, mock_run_serial_evaluation, sample_case_operator + self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): """Test _run_evaluation handles exceptions in serial execution.""" - mock_run_serial_evaluation.side_effect = Exception("Serial execution failed") + mock_tqdm.return_value = [sample_case_operator] + mock_compute_case_operator.side_effect = Exception("Serial execution failed") with pytest.raises(Exception, match="Serial execution failed"): # Serial mode: don't pass parallel_config - evaluate._run_evaluation([sample_case_operator], None) + evaluate._run_evaluation([sample_case_operator], parallel_config=None) @mock.patch("extremeweatherbench.evaluate._run_parallel_evaluation") def test_run_evaluation_parallel_exception( @@ -1689,12 +1695,12 @@ def test_run_evaluation_parallel_exception( def test_run_serial_evaluation_case_operator_exception( self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test _run_serial_evaluation handles exceptions from individual case operators.""" + """Test serial execution handles exceptions from individual case operators.""" mock_tqdm.return_value = [sample_case_operator] mock_compute_case_operator.side_effect = Exception("Case operator failed") with pytest.raises(Exception, match="Case operator failed"): - evaluate._run_serial_evaluation([sample_case_operator]) + evaluate._run_evaluation([sample_case_operator], parallel_config=None) @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @@ -1764,7 +1770,7 @@ def test_run_method_exception_propagation( def test_run_serial_evaluation_partial_failure( self, mock_tqdm, mock_compute_case_operator ): - """Test _run_serial_evaluation behavior when some case operators fail.""" + """Test serial execution behavior when 
some case operators fail.""" case_op_1 = mock.Mock() case_op_2 = mock.Mock() case_op_3 = mock.Mock() @@ -1781,7 +1787,7 @@ def test_run_serial_evaluation_partial_failure( # Should fail on the second case operator with pytest.raises(Exception, match="Case operator 2 failed"): - evaluate._run_serial_evaluation(case_operators) + evaluate._run_evaluation(case_operators, parallel_config=None) # Should have tried only the first two assert mock_compute_case_operator.call_count == 2 @@ -2048,12 +2054,16 @@ def test_serial_vs_parallel_results_consistency( assert list(serial_result.columns) == list(parallel_result.columns) @mock.patch("extremeweatherbench.evaluate.compute_case_operator") - def test_execution_method_performance_comparison(self, mock_compute_case_operator): + @mock.patch("tqdm.auto.tqdm") + def test_execution_method_performance_comparison( + self, mock_tqdm, mock_compute_case_operator + ): """Test that both execution methods handle the same workload.""" import time # Create many case operators to simulate realistic workload case_operators = [mock.Mock() for _ in range(10)] + mock_tqdm.return_value = case_operators # Mock results mock_results = [ @@ -2068,10 +2078,12 @@ def test_execution_method_performance_comparison(self, mock_compute_case_operato for i in range(10) ] - # Test serial execution timing - call _run_serial_evaluation directly + # Test serial execution timing - call _run_evaluation in serial mode mock_compute_case_operator.side_effect = mock_results start_time = time.time() - serial_result = evaluate._run_serial_evaluation(case_operators) + serial_result = evaluate._run_evaluation( + case_operators, parallel_config=None + ) serial_time = time.time() - start_time # Test parallel execution timing - call _run_parallel_evaluation directly with mocked @@ -2104,9 +2116,13 @@ def test_execution_method_performance_comparison(self, mock_compute_case_operato assert parallel_time >= 0 @mock.patch("extremeweatherbench.evaluate.compute_case_operator") - def 
test_mixed_execution_parameters(self, mock_compute_case_operator): + @mock.patch("tqdm.auto.tqdm") + def test_mixed_execution_parameters( + self, mock_tqdm, mock_compute_case_operator + ): """Test various parameter combinations for execution methods.""" case_operators = [mock.Mock(), mock.Mock()] + mock_tqdm.return_value = case_operators mock_results = [ pd.DataFrame({"value": [1.0], "case_id_number": [1]}), pd.DataFrame({"value": [2.0], "case_id_number": [2]}), @@ -2129,7 +2145,9 @@ def test_mixed_execution_parameters(self, mock_compute_case_operator): mock_compute_case_operator.side_effect = mock_results if config["method"] == "serial": - result = evaluate._run_serial_evaluation(*config["args"]) + result = evaluate._run_evaluation( + *config["args"], parallel_config=None + ) # All configurations should produce valid results assert isinstance(result, list) assert len(result) == 2 @@ -2178,10 +2196,13 @@ def mock_compute_with_kwargs(case_op, cache_dir, **kwargs): with mock.patch( "extremeweatherbench.evaluate.compute_case_operator", side_effect=mock_compute_with_kwargs, - ): + ), mock.patch("tqdm.auto.tqdm", return_value=[case_operator]): # Test serial kwargs propagation - result = evaluate._run_serial_evaluation( - [case_operator], custom_param="serial_test", threshold=0.9 + result = evaluate._run_evaluation( + [case_operator], + parallel_config=None, + custom_param="serial_test", + threshold=0.9, ) captured = mock_compute_with_kwargs.captured_kwargs @@ -2217,7 +2238,7 @@ def mock_compute_with_kwargs(case_op, cache_dir, **kwargs): def test_empty_case_operators_all_methods(self): """Test that all execution methods handle empty case operator lists.""" - # Test _run_evaluation + # Test _run_evaluation with parallel config result = evaluate._run_evaluation([], parallel_config={"n_jobs": 1}) assert result == [] @@ -2226,8 +2247,8 @@ def test_empty_case_operators_all_methods(self): ) assert result == [] - # Test _run_serial_evaluation - result = 
evaluate._run_serial_evaluation([]) + # Test _run_evaluation in serial mode + result = evaluate._run_evaluation([], parallel_config=None) assert result == [] # Test _run_parallel_evaluation @@ -2244,11 +2265,15 @@ def test_empty_case_operators_all_methods(self): assert result == [] @mock.patch("extremeweatherbench.evaluate.compute_case_operator") - def test_large_case_operator_list_handling(self, mock_compute_case_operator): + @mock.patch("tqdm.auto.tqdm") + def test_large_case_operator_list_handling( + self, mock_tqdm, mock_compute_case_operator + ): """Test handling of large numbers of case operators.""" # Create a large list of case operators num_cases = 100 case_operators = [mock.Mock() for _ in range(num_cases)] + mock_tqdm.return_value = case_operators # Create mock results mock_results = [ @@ -2260,7 +2285,9 @@ def test_large_case_operator_list_handling(self, mock_compute_case_operator): # Test serial execution mock_compute_case_operator.side_effect = mock_results - serial_results = evaluate._run_serial_evaluation(case_operators) + serial_results = evaluate._run_evaluation( + case_operators, parallel_config=None + ) assert len(serial_results) == num_cases assert mock_compute_case_operator.call_count == num_cases diff --git a/tests/test_evaluate_cli.py b/tests/test_evaluate_cli.py index 42c586be..338cbcbd 100644 --- a/tests/test_evaluate_cli.py +++ b/tests/test_evaluate_cli.py @@ -87,7 +87,7 @@ def test_default_mode_basic( # Mock the ExtremeWeatherBench class and its methods mock_ewb = mock.Mock() mock_ewb.case_operators = [mock.Mock(), mock.Mock()] # Mock 2 case operators - mock_ewb.run.return_value = pd.DataFrame({"test": [1, 2]}) + mock_ewb.run_evaluation.return_value = pd.DataFrame({"test": [1, 2]}) mock_ewb_class.return_value = mock_ewb # Mock loading default cases @@ -99,7 +99,7 @@ def test_default_mode_basic( assert result.exit_code == 0 mock_ewb_class.assert_called_once() - mock_ewb.run.assert_called_once() + 
mock_ewb.run_evaluation.assert_called_once() @mock.patch( "extremeweatherbench.defaults.get_brightband_evaluation_objects", @@ -118,7 +118,7 @@ def test_default_mode_with_cache_dir( """Test default mode with cache directory.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = {"cases": []} @@ -144,7 +144,7 @@ def test_config_file_mode_basic( """Test basic config file mode execution.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [mock.Mock()] - mock_ewb.run.return_value = pd.DataFrame({"test": [1]}) + mock_ewb.run_evaluation.return_value = pd.DataFrame({"test": [1]}) mock_ewb_class.return_value = mock_ewb result = runner.invoke( @@ -218,15 +218,15 @@ def test_parallel_execution( """Test parallel execution mode.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [mock.Mock(), mock.Mock(), mock.Mock()] - mock_ewb.run.return_value = pd.DataFrame({"test": [1, 2, 3]}) + mock_ewb.run_evaluation.return_value = pd.DataFrame({"test": [1, 2, 3]}) mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = {"cases": []} result = runner.invoke(evaluate_cli.cli_runner, ["--default", "--n-jobs", "3"]) assert result.exit_code == 0 - # Verify ewb.run was called with parallel config - mock_ewb.run.assert_called_once_with( + # Verify ewb.run_evaluation was called with parallel config + mock_ewb.run_evaluation.assert_called_once_with( n_jobs=3, parallel_config=None, ) @@ -243,7 +243,7 @@ def test_serial_execution_default( """Test that serial execution is default (parallel=1).""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = {"cases": []} @@ -251,7 +251,7 @@ def test_serial_execution_default( assert result.exit_code == 0 # 
Output suppressed - only check exit code - mock_ewb.run.assert_called_once() + mock_ewb.run_evaluation.assert_called_once() class TestCaseOperatorSaving: @@ -277,7 +277,7 @@ def test_save_case_operators( mock_case_op2 = {"id": 2, "type": "test_case_op"} mock_ewb = mock.Mock() mock_ewb.case_operators = [mock_case_op1, mock_case_op2] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = {"cases": []} @@ -316,7 +316,7 @@ def test_save_case_operators_creates_directory( """Test that saving case operators creates parent directories.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = {"cases": []} @@ -368,7 +368,7 @@ def test_output_directory_creation( """Test that output directory is created if it doesn't exist.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = {"cases": []} @@ -394,7 +394,7 @@ def test_default_output_directory( """Test that default output directory is current working directory.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = {"cases": []} @@ -434,7 +434,7 @@ def test_results_saved_to_csv( mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = mock_results + mock_ewb.run_evaluation.return_value = mock_results mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = {"cases": []} @@ -455,7 +455,7 @@ def test_empty_results_handling(self, mock_ewb_class, mock_load_cases, runner): 
"""Test handling when no results are returned.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() # Empty results + mock_ewb.run_evaluation.return_value = pd.DataFrame() # Empty results mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = {"cases": []} From 5654abb986d03114976088e0e82cca5260706d03 Mon Sep 17 00:00:00 2001 From: taylor Date: Tue, 13 Jan 2026 15:15:38 +0000 Subject: [PATCH 05/13] add tests and cover if serial and parallel_config is not None --- src/extremeweatherbench/evaluate.py | 4 +- tests/test_evaluate.py | 88 +++++++++++++++++++++-------- 2 files changed, 69 insertions(+), 23 deletions(-) diff --git a/src/extremeweatherbench/evaluate.py b/src/extremeweatherbench/evaluate.py index 34a12f7d..70ca3288 100644 --- a/src/extremeweatherbench/evaluate.py +++ b/src/extremeweatherbench/evaluate.py @@ -214,7 +214,9 @@ def _parallel_serial_config_check( n_jobs, ) parallel_config = {"backend": "loky", "n_jobs": n_jobs} - + # If running in serial mode, set parallel_config to None if not already + else: + parallel_config = None # Return the maybe updated kwargs return parallel_config diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index ed8ca93b..9f236129 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -563,9 +563,7 @@ def test_run_evaluation_serial( # Serial mode: don't pass parallel_config result = evaluate._run_evaluation([sample_case_operator], cache_dir=None) - mock_compute_case_operator.assert_called_once_with( - sample_case_operator, None - ) + mock_compute_case_operator.assert_called_once_with(sample_case_operator, None) assert len(result) == 1 assert result[0].equals(mock_results) @@ -654,9 +652,7 @@ def test_run_serial_evaluation_basic( mock_result = pd.DataFrame({"value": [1.0], "case_id_number": [1]}) mock_compute_case_operator.return_value = mock_result - result = evaluate._run_evaluation( - [sample_case_operator], parallel_config=None - ) + result = 
evaluate._run_evaluation([sample_case_operator], parallel_config=None) mock_compute_case_operator.assert_called_once_with(sample_case_operator, None) assert len(result) == 1 @@ -2081,9 +2077,7 @@ def test_execution_method_performance_comparison( # Test serial execution timing - call _run_evaluation in serial mode mock_compute_case_operator.side_effect = mock_results start_time = time.time() - serial_result = evaluate._run_evaluation( - case_operators, parallel_config=None - ) + serial_result = evaluate._run_evaluation(case_operators, parallel_config=None) serial_time = time.time() - start_time # Test parallel execution timing - call _run_parallel_evaluation directly with mocked @@ -2117,9 +2111,7 @@ def test_execution_method_performance_comparison( @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_mixed_execution_parameters( - self, mock_tqdm, mock_compute_case_operator - ): + def test_mixed_execution_parameters(self, mock_tqdm, mock_compute_case_operator): """Test various parameter combinations for execution methods.""" case_operators = [mock.Mock(), mock.Mock()] mock_tqdm.return_value = case_operators @@ -2145,9 +2137,7 @@ def test_mixed_execution_parameters( mock_compute_case_operator.side_effect = mock_results if config["method"] == "serial": - result = evaluate._run_evaluation( - *config["args"], parallel_config=None - ) + result = evaluate._run_evaluation(*config["args"], parallel_config=None) # All configurations should produce valid results assert isinstance(result, list) assert len(result) == 2 @@ -2193,10 +2183,13 @@ def mock_compute_with_kwargs(case_op, cache_dir, **kwargs): mock_compute_with_kwargs.captured_kwargs = {} - with mock.patch( - "extremeweatherbench.evaluate.compute_case_operator", - side_effect=mock_compute_with_kwargs, - ), mock.patch("tqdm.auto.tqdm", return_value=[case_operator]): + with ( + mock.patch( + "extremeweatherbench.evaluate.compute_case_operator", + 
side_effect=mock_compute_with_kwargs, + ), + mock.patch("tqdm.auto.tqdm", return_value=[case_operator]), + ): # Test serial kwargs propagation result = evaluate._run_evaluation( [case_operator], @@ -2285,9 +2278,7 @@ def test_large_case_operator_list_handling( # Test serial execution mock_compute_case_operator.side_effect = mock_results - serial_results = evaluate._run_evaluation( - case_operators, parallel_config=None - ) + serial_results = evaluate._run_evaluation(case_operators, parallel_config=None) assert len(serial_results) == num_cases assert mock_compute_case_operator.call_count == num_cases @@ -2768,5 +2759,58 @@ def test_compute_case_operator_no_output_variables(self, sample_individual_case) assert all(result["target_variable"] == "MockDerivedVariableWithOutputs") +class TestParallelSerialConfigCheck: + def test_parallel_serial_config_check_serial(self): + """Test that the parallel_serial_config_check returns None for serial mode. + + If n_jobs == 1 in any of the arguments, parallel_config should always be + None.""" + assert evaluate._parallel_serial_config_check(n_jobs=1) is None + assert ( + evaluate._parallel_serial_config_check(parallel_config={"n_jobs": 1}) + is None + ) + assert ( + evaluate._parallel_serial_config_check( + n_jobs=None, parallel_config={"n_jobs": 1} + ) + is None + ) + assert ( + evaluate._parallel_serial_config_check( + n_jobs=None, parallel_config={"n_jobs": 1} + ) + is None + ) + assert ( + evaluate._parallel_serial_config_check( + n_jobs=None, parallel_config={"backend": "threading", "n_jobs": 1} + ) + is None + ) + + def test_parallel_serial_config_check_parallel(self): + """Test that the parallel_serial_config_check returns a dictionary for parallel mode.""" + assert evaluate._parallel_serial_config_check(n_jobs=2) == { + "backend": "loky", + "n_jobs": 2, + } + assert evaluate._parallel_serial_config_check( + parallel_config={"backend": "threading", "n_jobs": 2} + ) == {"backend": "threading", "n_jobs": 2} + assert 
evaluate._parallel_serial_config_check( + n_jobs=2, parallel_config={"backend": "threading", "n_jobs": 2} + ) == {"backend": "threading", "n_jobs": 2} + assert evaluate._parallel_serial_config_check( + n_jobs=2, parallel_config={"backend": "threading", "n_jobs": 2} + ) == {"backend": "threading", "n_jobs": 2} + assert evaluate._parallel_serial_config_check( + n_jobs=2, parallel_config={"backend": "threading", "n_jobs": 2} + ) == {"backend": "threading", "n_jobs": 2} + assert evaluate._parallel_serial_config_check( + n_jobs=2, parallel_config={"backend": "threading", "n_jobs": 2} + ) == {"backend": "threading", "n_jobs": 2} + + if __name__ == "__main__": pytest.main([__file__]) From 90823181aed5789954053efdb419748feab8bc7a Mon Sep 17 00:00:00 2001 From: taylor Date: Tue, 13 Jan 2026 18:01:38 +0000 Subject: [PATCH 06/13] feat: redesign public API with hierarchical namespace submodules - Add ewb.evaluation() as main entry point (alias for ExtremeWeatherBench) - Create namespace submodules: ewb.targets, ewb.forecasts, ewb.metrics, ewb.derived, ewb.regions, ewb.cases, ewb.defaults - Expose all classes at top level for convenience (ewb.ERA5, etc.) 
- Add ewb.load_cases() convenience alias - Update all example files to use new import pattern - Update usage.md documentation - Maintain backward compatibility with existing imports --- data_prep/ar_bounds.py | 6 +- data_prep/ibtracs_bounds.py | 4 +- .../practically_perfect_hindcast_from_lsr.py | 3 +- data_prep/severe_convection_bounds.py | 3 +- data_prep/subset_heat_cold_events.py | 3 +- docs/examples/applied_ar.py | 54 +-- docs/examples/applied_freeze.py | 30 +- docs/examples/applied_heatwave.py | 30 +- docs/examples/applied_severe.py | 36 +- docs/examples/applied_tc.py | 43 +-- docs/examples/example_config.py | 19 +- docs/usage.md | 87 +++-- src/extremeweatherbench/__init__.py | 343 ++++++++++++++++++ src/extremeweatherbench/cases.py | 5 +- src/extremeweatherbench/defaults.py | 2 +- src/extremeweatherbench/derived.py | 2 +- src/extremeweatherbench/evaluate.py | 50 +-- src/extremeweatherbench/evaluate_cli.py | 4 +- src/extremeweatherbench/inputs.py | 7 +- src/extremeweatherbench/regions.py | 4 +- src/extremeweatherbench/sources/base.py | 2 +- .../sources/pandas_dataframe.py | 4 +- .../sources/polars_lazyframe.py | 4 +- .../sources/xarray_dataarray.py | 3 +- .../sources/xarray_dataset.py | 2 +- 25 files changed, 551 insertions(+), 199 deletions(-) diff --git a/data_prep/ar_bounds.py b/data_prep/ar_bounds.py index a21a0c87..15b8abfd 100644 --- a/data_prep/ar_bounds.py +++ b/data_prep/ar_bounds.py @@ -17,7 +17,11 @@ from dask.distributed import Client from matplotlib.patches import Rectangle -from extremeweatherbench import cases, derived, inputs, regions, utils +import extremeweatherbench.cases as cases +import extremeweatherbench.derived as derived +import extremeweatherbench.inputs as inputs +import extremeweatherbench.regions as regions +import extremeweatherbench.utils as utils from extremeweatherbench.events import atmospheric_river as ar logging.basicConfig() diff --git a/data_prep/ibtracs_bounds.py b/data_prep/ibtracs_bounds.py index cf159250..84d193a0 
100644 --- a/data_prep/ibtracs_bounds.py +++ b/data_prep/ibtracs_bounds.py @@ -15,7 +15,9 @@ from matplotlib.patches import Rectangle import extremeweatherbench.data -from extremeweatherbench import inputs, regions, utils +import extremeweatherbench.inputs as inputs +import extremeweatherbench.regions as regions +import extremeweatherbench.utils as utils logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) diff --git a/data_prep/practically_perfect_hindcast_from_lsr.py b/data_prep/practically_perfect_hindcast_from_lsr.py index db114242..1e96cbda 100644 --- a/data_prep/practically_perfect_hindcast_from_lsr.py +++ b/data_prep/practically_perfect_hindcast_from_lsr.py @@ -11,7 +11,8 @@ from scipy.ndimage import gaussian_filter from tqdm.auto import tqdm -from extremeweatherbench import inputs, utils +import extremeweatherbench.inputs as inputs +import extremeweatherbench.utils as utils def sparse_practically_perfect_hindcast( diff --git a/data_prep/severe_convection_bounds.py b/data_prep/severe_convection_bounds.py index afdf803a..3224c251 100644 --- a/data_prep/severe_convection_bounds.py +++ b/data_prep/severe_convection_bounds.py @@ -17,7 +17,8 @@ import yaml from scipy.ndimage import label -from extremeweatherbench import calc, cases +import extremeweatherbench.calc as calc +import extremeweatherbench.cases as cases # Radius of Earth in km (mean radius) EARTH_RADIUS_KM = 6371.0 diff --git a/data_prep/subset_heat_cold_events.py b/data_prep/subset_heat_cold_events.py index e109889b..ee6b4add 100644 --- a/data_prep/subset_heat_cold_events.py +++ b/data_prep/subset_heat_cold_events.py @@ -13,7 +13,8 @@ from matplotlib import dates as mdates from mpl_toolkits.axes_grid1 import make_axes_locatable -from extremeweatherbench import cases, utils +import extremeweatherbench.cases as cases +import extremeweatherbench.utils as utils sns.set_theme(style="whitegrid", context="talk") diff --git a/docs/examples/applied_ar.py b/docs/examples/applied_ar.py 
index 08c4ad59..9fa749e5 100644 --- a/docs/examples/applied_ar.py +++ b/docs/examples/applied_ar.py @@ -4,7 +4,7 @@ import numpy as np import xarray as xr -from extremeweatherbench import cases, derived, evaluate, inputs, metrics +import extremeweatherbench as ewb # %% @@ -39,85 +39,85 @@ def _preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: # Load case data from the default events.yaml # Users can also define their own cases_dict structure -case_yaml = cases.load_ewb_events_yaml_into_case_collection() +case_yaml = ewb.load_cases() case_yaml = case_yaml.select_cases(by="case_id_number", value=114) case_yaml.cases[0].start_date = datetime.datetime(2022, 12, 27, 11, 0, 0) case_yaml.cases[0].end_date = datetime.datetime(2022, 12, 27, 13, 0, 0) # Define ERA5 target -era5_target = inputs.ERA5( +era5_target = ewb.targets.ERA5( variables=[ - derived.AtmosphericRiverVariables( + ewb.derived.AtmosphericRiverVariables( output_variables=["atmospheric_river_land_intersection"] ) ], ) # Define forecast (HRES) -hres_forecast = inputs.ZarrForecast( +hres_forecast = ewb.forecasts.ZarrForecast( source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", name="HRES", variables=[ - derived.AtmosphericRiverVariables( + ewb.derived.AtmosphericRiverVariables( output_variables=["atmospheric_river_land_intersection"] ) ], - variable_mapping=inputs.HRES_metadata_variable_mapping, + variable_mapping=ewb.HRES_metadata_variable_mapping, ) -grap_forecast = inputs.KerchunkForecast( +grap_forecast = ewb.forecasts.KerchunkForecast( name="Graphcast", source="gs://extremeweatherbench/GRAP_v100_IFS.parq", variables=[ - derived.AtmosphericRiverVariables( + ewb.derived.AtmosphericRiverVariables( output_variables=["atmospheric_river_land_intersection"] ) ], - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, 
preprocess=_preprocess_bb_cira_forecast_dataset, ) -pang_forecast = inputs.KerchunkForecast( +pang_forecast = ewb.forecasts.KerchunkForecast( name="Pangu", source="gs://extremeweatherbench/PANG_v100_IFS.parq", variables=[ - derived.AtmosphericRiverVariables( + ewb.derived.AtmosphericRiverVariables( output_variables=["atmospheric_river_land_intersection"] ) ], - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, preprocess=_preprocess_bb_cira_forecast_dataset, ) # Create a list of evaluation objects for atmospheric river ar_evaluation_objects = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="atmospheric_river", metric_list=[ - metrics.CriticalSuccessIndex(), - metrics.EarlySignal(), - metrics.SpatialDisplacement(), + ewb.metrics.CriticalSuccessIndex(), + ewb.metrics.EarlySignal(), + ewb.metrics.SpatialDisplacement(), ], target=era5_target, forecast=hres_forecast, ), - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="atmospheric_river", metric_list=[ - metrics.CriticalSuccessIndex(), - metrics.EarlySignal(), - metrics.SpatialDisplacement(), + ewb.metrics.CriticalSuccessIndex(), + ewb.metrics.EarlySignal(), + ewb.metrics.SpatialDisplacement(), ], target=era5_target, forecast=grap_forecast, ), - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="atmospheric_river", metric_list=[ - metrics.CriticalSuccessIndex(), - metrics.EarlySignal(), - metrics.SpatialDisplacement(), + ewb.metrics.CriticalSuccessIndex(), + ewb.metrics.EarlySignal(), + ewb.metrics.SpatialDisplacement(), ], target=era5_target, forecast=pang_forecast, @@ -127,13 +127,13 @@ def _preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: if __name__ == "__main__": # Initialize ExtremeWeatherBench; will only run on cases with event_type # atmospheric_river - ar_ewb = evaluate.ExtremeWeatherBench( + ar_ewb = ewb.evaluation( 
case_metadata=case_yaml, evaluation_objects=ar_evaluation_objects, ) # Run the workflow using 3 jobs - outputs = ar_ewb.run_evaluation(parallel_config={"backend": "loky", "n_jobs": 3}) + outputs = ar_ewb.run(parallel_config={"backend": "loky", "n_jobs": 3}) # Save the evaluation outputs to a csv file outputs.to_csv("ar_signal_outputs.csv") diff --git a/docs/examples/applied_freeze.py b/docs/examples/applied_freeze.py index a430bd39..864d76f7 100644 --- a/docs/examples/applied_freeze.py +++ b/docs/examples/applied_freeze.py @@ -1,7 +1,7 @@ import logging import operator -from extremeweatherbench import cases, defaults, evaluate, inputs, metrics +import extremeweatherbench as ewb # Set the logger level to INFO logger = logging.getLogger("extremeweatherbench") @@ -9,47 +9,47 @@ # Load case data from the default events.yaml # Users can also define their own cases_dict structure -case_yaml = cases.load_ewb_events_yaml_into_case_collection() +case_yaml = ewb.load_cases() # Define targets # ERA5 target -era5_freeze_target = inputs.ERA5( +era5_freeze_target = ewb.targets.ERA5( variables=["surface_air_temperature"], chunks=None, ) # GHCN target -ghcn_freeze_target = inputs.GHCN(variables=["surface_air_temperature"]) +ghcn_freeze_target = ewb.targets.GHCN(variables=["surface_air_temperature"]) # Define forecast (FCNv2 CIRA Virtualizarr) -fcnv2_forecast = inputs.KerchunkForecast( +fcnv2_forecast = ewb.forecasts.KerchunkForecast( name="fcnv2_forecast", source="gs://extremeweatherbench/FOUR_v200_GFS.parq", variables=["surface_air_temperature"], - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=defaults._preprocess_bb_cira_forecast_dataset, + preprocess=ewb.defaults._preprocess_bb_cira_forecast_dataset, ) # Load the climatology for DurationMeanError -climatology = defaults.get_climatology(quantile=0.85) +climatology = 
ewb.get_climatology(quantile=0.85) # Define the metrics metrics_list = [ - metrics.RootMeanSquaredError(), - metrics.MinimumMeanAbsoluteError(), - metrics.DurationMeanError(threshold_criteria=climatology, op_func=operator.le), + ewb.metrics.RootMeanSquaredError(), + ewb.metrics.MinimumMeanAbsoluteError(), + ewb.metrics.DurationMeanError(threshold_criteria=climatology, op_func=operator.le), ] # Create a list of evaluation objects for freeze freeze_evaluation_object = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="freeze", metric_list=metrics_list, target=ghcn_freeze_target, forecast=fcnv2_forecast, ), - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="freeze", metric_list=metrics_list, target=era5_freeze_target, @@ -59,13 +59,13 @@ if __name__ == "__main__": # Initialize ExtremeWeatherBench runner instance - ewb = evaluate.ExtremeWeatherBench( + freeze_ewb = ewb.evaluation( case_metadata=case_yaml, evaluation_objects=freeze_evaluation_object, ) # Run the workflow - outputs = ewb.run_evaluation(parallel_config={"backend": "loky", "n_jobs": 1}) + outputs = freeze_ewb.run(parallel_config={"backend": "loky", "n_jobs": 1}) # Print the outputs; can be saved if desired outputs.to_csv("freeze_outputs.csv") diff --git a/docs/examples/applied_heatwave.py b/docs/examples/applied_heatwave.py index 9f7a7dd1..22c9b809 100644 --- a/docs/examples/applied_heatwave.py +++ b/docs/examples/applied_heatwave.py @@ -1,7 +1,7 @@ import logging import operator -from extremeweatherbench import cases, defaults, evaluate, inputs, metrics +import extremeweatherbench as ewb # Set the logger level to INFO logger = logging.getLogger("extremeweatherbench") @@ -9,48 +9,48 @@ # Load case data from the default events.yaml # Users can also define their own cases_dict structure -case_yaml = cases.load_ewb_events_yaml_into_case_collection() +case_yaml = ewb.load_cases() # Define targets # ERA5 target -era5_heatwave_target = inputs.ERA5( +era5_heatwave_target = 
ewb.targets.ERA5( variables=["surface_air_temperature"], chunks=None, ) # GHCN target -ghcn_heatwave_target = inputs.GHCN( +ghcn_heatwave_target = ewb.targets.GHCN( variables=["surface_air_temperature"], ) # Define forecast (HRES) -hres_forecast = inputs.ZarrForecast( +hres_forecast = ewb.forecasts.ZarrForecast( name="hres_forecast", source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", variables=["surface_air_temperature"], - variable_mapping=inputs.HRES_metadata_variable_mapping, + variable_mapping=ewb.HRES_metadata_variable_mapping, ) # Load the climatology for DurationMeanError -climatology = defaults.get_climatology(quantile=0.85) +climatology = ewb.get_climatology(quantile=0.85) # Define the metrics metrics_list = [ - metrics.MaximumMeanAbsoluteError(), - metrics.RootMeanSquaredError(), - metrics.DurationMeanError(threshold_criteria=climatology, op_func=operator.ge), - metrics.MaximumLowestMeanAbsoluteError(), + ewb.metrics.MaximumMeanAbsoluteError(), + ewb.metrics.RootMeanSquaredError(), + ewb.metrics.DurationMeanError(threshold_criteria=climatology, op_func=operator.ge), + ewb.metrics.MaximumLowestMeanAbsoluteError(), ] # Create a list of evaluation objects for heatwave heatwave_evaluation_object = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="heat_wave", metric_list=metrics_list, target=ghcn_heatwave_target, forecast=hres_forecast, ), - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="heat_wave", metric_list=metrics_list, target=era5_heatwave_target, @@ -59,11 +59,11 @@ ] if __name__ == "__main__": # Initialize ExtremeWeatherBench - ewb = evaluate.ExtremeWeatherBench( + heatwave_ewb = ewb.evaluation( case_metadata=case_yaml, evaluation_objects=heatwave_evaluation_object, ) # Run the workflow - outputs = ewb.run_evaluation(parallel_config={"backend": "loky", "n_jobs": 2}) + outputs = heatwave_ewb.run(parallel_config={"backend": "loky", "n_jobs": 2}) outputs.to_csv("applied_heatwave_outputs.csv") diff --git 
a/docs/examples/applied_severe.py b/docs/examples/applied_severe.py index 40c00770..18ad624c 100644 --- a/docs/examples/applied_severe.py +++ b/docs/examples/applied_severe.py @@ -1,6 +1,6 @@ import logging -from extremeweatherbench import cases, derived, evaluate, inputs, metrics +import extremeweatherbench as ewb # Set the logger level to INFO logger = logging.getLogger("extremeweatherbench") @@ -8,45 +8,45 @@ # Load case data from the default events.yaml -case_yaml = cases.load_ewb_events_yaml_into_case_collection() +case_yaml = ewb.load_cases() case_yaml.select_cases("case_id_number", 305, inplace=True) # Define PPH target -pph_target = inputs.PPH( +pph_target = ewb.targets.PPH( variables=["practically_perfect_hindcast"], ) # Define LSR target -lsr_target = inputs.LSR() +lsr_target = ewb.targets.LSR() # Define HRES forecast -hres_forecast = inputs.ZarrForecast( +hres_forecast = ewb.forecasts.ZarrForecast( name="hres_forecast", source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", - variables=[derived.CravenBrooksSignificantSevere(layer_depth=100)], - variable_mapping=inputs.HRES_metadata_variable_mapping, + variables=[ewb.derived.CravenBrooksSignificantSevere(layer_depth=100)], + variable_mapping=ewb.HRES_metadata_variable_mapping, storage_options={"remote_options": {"anon": True}}, ) # Define pph metrics as thresholdmetric to share scores contingency table pph_metrics = [ - metrics.ThresholdMetric( + ewb.metrics.ThresholdMetric( metrics=[ - metrics.CriticalSuccessIndex, - metrics.FalseAlarmRatio, + ewb.metrics.CriticalSuccessIndex, + ewb.metrics.FalseAlarmRatio, ], forecast_threshold=15000, target_threshold=0.3, ), - metrics.EarlySignal(threshold=15000), + ewb.metrics.EarlySignal(threshold=15000), ] # Define LSR metrics as thresholdmetric to share scores contingency table lsr_metrics = [ - metrics.ThresholdMetric( + ewb.metrics.ThresholdMetric( metrics=[ - metrics.TruePositives, - metrics.FalseNegatives, + ewb.metrics.TruePositives, + 
ewb.metrics.FalseNegatives, ], forecast_threshold=15000, target_threshold=0.5, @@ -56,7 +56,7 @@ # Define evaluation objects for severe convection: # One evaluation object for PPH pph_evaluation_objects = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="severe_convection", metric_list=pph_metrics, target=pph_target, @@ -66,7 +66,7 @@ # One evaluation object for LSR lsr_evaluation_objects = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="severe_convection", metric_list=lsr_metrics, target=lsr_target, @@ -76,14 +76,14 @@ if __name__ == "__main__": # Initialize ExtremeWeatherBench with both evaluation objects - ewb = evaluate.ExtremeWeatherBench( + severe_ewb = ewb.evaluation( case_metadata=case_yaml, evaluation_objects=lsr_evaluation_objects + pph_evaluation_objects, ) logger.info("Starting EWB run") # Run the workflow with parllel_config backend set to dask - outputs = ewb.run_evaluation(parallel_config={"backend": "loky", "n_jobs": 3}) + outputs = severe_ewb.run(parallel_config={"backend": "loky", "n_jobs": 3}) # Save the results to a CSV file outputs.to_csv("applied_severe_convection_results.csv") diff --git a/docs/examples/applied_tc.py b/docs/examples/applied_tc.py index 955c45a5..8707d14c 100644 --- a/docs/examples/applied_tc.py +++ b/docs/examples/applied_tc.py @@ -3,7 +3,8 @@ import numpy as np import xarray as xr -from extremeweatherbench import calc, cases, derived, evaluate, inputs, metrics +import extremeweatherbench as ewb +from extremeweatherbench import calc # Set the logger level to INFO logger = logging.getLogger("extremeweatherbench") @@ -54,46 +55,46 @@ def _preprocess_hres_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: # Load the case collection from the YAML file -case_yaml = cases.load_ewb_events_yaml_into_case_collection() +case_yaml = ewb.load_cases() # Select single case (TC Ida) case_yaml.select_cases(by="case_id_number", value=220, inplace=True) # Define IBTrACS target, no arguments needed as defaults are 
sufficient -ibtracs_target = inputs.IBTrACS() +ibtracs_target = ewb.targets.IBTrACS() # Define HRES forecast -hres_forecast = inputs.ZarrForecast( +hres_forecast = ewb.forecasts.ZarrForecast( name="hres_forecast", source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", # Define tropical cyclone track derivedvariable to include in the forecast - variables=[derived.TropicalCycloneTrackVariables()], + variables=[ewb.derived.TropicalCycloneTrackVariables()], # Define metadata variable mapping for HRES forecast - variable_mapping=inputs.HRES_metadata_variable_mapping, + variable_mapping=ewb.HRES_metadata_variable_mapping, storage_options={"remote_options": {"anon": True}}, # Preprocess the HRES forecast to include geopotential thickness calculation preprocess=_preprocess_hres_forecast_dataset, ) # Define FCNv2 forecast -fcnv2_forecast = inputs.KerchunkForecast( +fcnv2_forecast = ewb.forecasts.KerchunkForecast( name="fcn_forecast", source="gs://extremeweatherbench/FOUR_v200_GFS.parq", - variables=[derived.TropicalCycloneTrackVariables()], + variables=[ewb.derived.TropicalCycloneTrackVariables()], # Define metadata variable mapping for FCNv2 forecast - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, # Preprocess the FCNv2 forecast to include geopotential thickness calculation preprocess=_preprocess_bb_cira_tc_forecast_dataset, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, ) # Define Pangu forecast -pangu_forecast = inputs.KerchunkForecast( +pangu_forecast = ewb.forecasts.KerchunkForecast( name="pangu_forecast", source="gs://extremeweatherbench/PANG_v100_GFS.parq", - variables=[derived.TropicalCycloneTrackVariables()], + variables=[ewb.derived.TropicalCycloneTrackVariables()], # Define metadata variable mapping for Pangu forecast - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, # Preprocess the Pangu 
forecast to include geopotential thickness calculation # which uses the same preprocessing function as the FCNv2 forecast preprocess=_preprocess_bb_cira_tc_forecast_dataset, @@ -106,11 +107,11 @@ def _preprocess_hres_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: # the evaluation to occur, in the case of multiple landfalls, for the next landfall in # time to be evaluated against composite_landfall_metrics = [ - metrics.LandfallMetric( + ewb.metrics.LandfallMetric( metrics=[ - metrics.LandfallIntensityMeanAbsoluteError, - metrics.LandfallTimeMeanError, - metrics.LandfallDisplacement, + ewb.metrics.LandfallIntensityMeanAbsoluteError, + ewb.metrics.LandfallTimeMeanError, + ewb.metrics.LandfallDisplacement, ], approach="next", # Set the intensity variable to use for the metric @@ -123,21 +124,21 @@ def _preprocess_hres_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: # the relevant cases inside the events YAML file tc_evaluation_object = [ # HRES forecast - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="tropical_cyclone", metric_list=composite_landfall_metrics, target=ibtracs_target, forecast=hres_forecast, ), # Pangu forecast - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="tropical_cyclone", metric_list=composite_landfall_metrics, target=ibtracs_target, forecast=pangu_forecast, ), # FCNv2 forecast - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="tropical_cyclone", metric_list=composite_landfall_metrics, target=ibtracs_target, @@ -147,13 +148,13 @@ def _preprocess_hres_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: if __name__ == "__main__": # Initialize ExtremeWeatherBench - ewb = evaluate.ExtremeWeatherBench( + tc_ewb = ewb.evaluation( case_metadata=case_yaml, evaluation_objects=tc_evaluation_object, ) logger.info("Starting EWB run") # Run the workflow with parallel_config backend set to dask - outputs = ewb.run_evaluation( + outputs = tc_ewb.run( parallel_config={"backend": "loky", "n_jobs": 3}, ) 
outputs.to_csv("tc_metric_test_results.csv") diff --git a/docs/examples/example_config.py b/docs/examples/example_config.py index eee19463..ad94d3d6 100644 --- a/docs/examples/example_config.py +++ b/docs/examples/example_config.py @@ -7,32 +7,29 @@ ewb --config-file example_config.py """ -from extremeweatherbench import inputs, metrics -from extremeweatherbench.cases import load_ewb_events_yaml_into_case_collection +import extremeweatherbench as ewb # Define targets (observation data) -era5_heatwave_target = inputs.ERA5( +era5_heatwave_target = ewb.targets.ERA5( variables=["surface_air_temperature"], chunks=None, ) # Define forecasts -fcnv2_forecast = inputs.KerchunkForecast( +fcnv2_forecast = ewb.forecasts.KerchunkForecast( name="fcnv2_forecast", source="gs://extremeweatherbench/FOUR_v200_GFS.parq", variables=["surface_air_temperature"], - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, ) # Define evaluation objects evaluation_objects = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="heat_wave", metric_list=[ - metrics.MaximumMeanAbsoluteError(), - metrics.RootMeanSquaredError(), - metrics.OnsetMeanError(), - metrics.DurationMeanError(), + ewb.metrics.MaximumMeanAbsoluteError(), + ewb.metrics.RootMeanSquaredError(), ], target=era5_heatwave_target, forecast=fcnv2_forecast, @@ -41,7 +38,7 @@ # Load case data from the default events.yaml # Users can also define their own cases_dict structure -cases_dict = load_ewb_events_yaml_into_case_collection() +cases_dict = ewb.load_cases() # Alternatively, users could define custom cases like this: # cases_dict = { diff --git a/docs/usage.md b/docs/usage.md index b1891ebb..874edbd2 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -15,19 +15,44 @@ ewb --default or: ```python -from extremeweatherbench import evaluate, defaults, cases +import extremeweatherbench as ewb -eval_objects = defaults.get_brightband_evaluation_objects() +eval_objects = 
ewb.get_brightband_evaluation_objects() +cases = ewb.load_cases() -cases = cases.load_ewb_events_yaml_into_case_collection() -ewb = ExtremeWeatherBench(cases=cases, -evaluation_objects=eval_objects) - -outputs = ewb.run_evaluation() +runner = ewb.evaluation( + case_metadata=cases, + evaluation_objects=eval_objects +) +outputs = runner.run() outputs.to_csv('your_outputs.csv') ``` +## API Overview + +ExtremeWeatherBench provides a hierarchical API for accessing its components: + +```python +import extremeweatherbench as ewb + +# Main evaluation entry point +ewb.evaluation(...) # Alias for ExtremeWeatherBench class + +# Hierarchical access via namespaces +ewb.targets.ERA5(...) # Target classes +ewb.forecasts.ZarrForecast(...) # Forecast classes +ewb.metrics.MeanAbsoluteError() # Metric classes +ewb.derived.AtmosphericRiverVariables() # Derived variables +ewb.regions.BoundingBoxRegion(...) # Region classes +ewb.cases.IndividualCase # Case metadata classes + +# Also available at top level for convenience +ewb.ERA5(...) +ewb.ZarrForecast(...) +ewb.load_cases() +``` + ## Running an Evaluation for a Single Event Type ExtremeWeatherBench has default event types and cases for heat waves, freezes, severe convection, tropical cyclones, and atmospheric rivers. @@ -35,20 +60,21 @@ ExtremeWeatherBench has default event types and cases for heat waves, freezes, s To run an evaluation, there are three components required: a forecast, a target, and an evaluation object. ```python -from extremeweatherbench import inputs +import extremeweatherbench as ewb ``` + There are two built-in `ForecastBase` classes to set up a forecast: `ZarrForecast` and `KerchunkForecast`. 
Here is an example of a `ZarrForecast`, using Weatherbench2's HRES zarr store: ```python -hres_forecast = inputs.ZarrForecast( +hres_forecast = ewb.forecasts.ZarrForecast( source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", name="HRES", variables=["surface_air_temperature"], - variable_mapping=inputs.HRES_metadata_variable_mapping, # built-in mapping available + variable_mapping=ewb.HRES_metadata_variable_mapping, # built-in mapping available storage_options={"remote_options": {"anon": True}}, - ) ``` + There are required arguments, namely: - `source` @@ -58,59 +84,64 @@ There are required arguments, namely: * `variables` can be defined within one or more metrics instead of in a `ForecastBase` object. -A forecast needs a `source`, which is a link to the zarr store in this case. A `name` is required to identify the outputs. It also needs `variables` defined, which are based on CF Conventions. A list of variable namings exists in `defaults.py` as `DEFAULT_VARIABLE_NAMES`. Each forecast will likely have different names for their variables, so a `variable_mapping` dictionary is also essential to process the variables, as well as the coordinates and dimensions. EWB uses `lead_time`, `init_time`, and `valid_time` as time coordinates. The HRES data is mapped from `prediction_timedelta` to `lead_time`, as an example. `storage_options` define access patterns for the data if needed. These are passed to the opening function, e.g. `xarray.open_zarr`. +A forecast needs a `source`, which is a link to the zarr store in this case. A `name` is required to identify the outputs. It also needs `variables` defined, which are based on CF Conventions. A list of variable namings exists in `ewb.DEFAULT_VARIABLE_NAMES`. Each forecast will likely have different names for their variables, so a `variable_mapping` dictionary is also essential to process the variables, as well as the coordinates and dimensions. 
EWB uses `lead_time`, `init_time`, and `valid_time` as time coordinates. The HRES data is mapped from `prediction_timedelta` to `lead_time`, as an example. `storage_options` define access patterns for the data if needed. These are passed to the opening function, e.g. `xarray.open_zarr`. Next, a target dataset must be defined as well to evaluate against. For this evaluation, we'll use ERA5: ```python -era5_heatwave_target = inputs.ERA5( - source=inputs.ARCO_ERA5_FULL_URI, +era5_heatwave_target = ewb.targets.ERA5( + source=ewb.ARCO_ERA5_FULL_URI, variables=["surface_air_temperature"], storage_options={"remote_options": {"anon": True}}, chunks=None, ) ``` -Similarly to forecasts, we need to define the `source`, which here is the ARCO ERA5 provided by Google. `variables` are again required to be set for the `inputs.ERA5` class; `variable_mapping` defaults to `inputs.ERA5_metadata_variable_mapping` for many existing variables and likely is not required to be set unless your use case is for less common variables. Both forecasts and targets, if relevant, have an optional `chunks` parameter which defaults to what should be the most efficient value - usually `None` or `'auto'`, but can be changed as seen above. +Similarly to forecasts, we need to define the `source`, which here is the ARCO ERA5 provided by Google. `variables` are again required to be set for the `ewb.targets.ERA5` class; `variable_mapping` defaults to `ewb.ERA5_metadata_variable_mapping` for many existing variables and likely is not required to be set unless your use case is for less common variables. Both forecasts and targets, if relevant, have an optional `chunks` parameter which defaults to what should be the most efficient value - usually `None` or `'auto'`, but can be changed as seen above. 
We then set up an `EvaluationObject` list: ```python -from extremeweatherbench import metrics - heatwave_evaluation_list = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="heat_wave", metric_list=[ - metrics.MaximumMeanAbsoluteError(), - metrics.RootMeanSquaredError(), - metrics.MaximumLowestMeanAbsoluteError() + ewb.metrics.MaximumMeanAbsoluteError(), + ewb.metrics.RootMeanSquaredError(), + ewb.metrics.MaximumLowestMeanAbsoluteError() ], target=era5_heatwave_target, forecast=hres_forecast, ), ] ``` + Which includes the event_type of interest (as defined in the case dictionary or YAML file used), the list of metrics to run, one target, and one forecast. There can be multiple `EvaluationObjects` which are used for an evaluation run. Plugging these all in: ```python -from extremeweatherbench import cases, evaluate -case_yaml = cases.load_ewb_events_yaml_into_case_collection() - +case_yaml = ewb.load_cases() -ewb_instance = evaluate.ExtremeWeatherBench( - cases=case_yaml, +ewb_instance = ewb.evaluation( + case_metadata=case_yaml, evaluation_objects=heatwave_evaluation_list, ) outputs = ewb_instance.run() - outputs.to_csv('your_file_name.csv') ``` -Where the EWB default events YAML file is loaded in using a built-in utility helper function, then applied to an instance of `evaluate.ExtremeWeatherBench` along with the `EvaluationObject` list. Finally, we run the evaluation with the `.run()` method, where defaults are typically sufficient to run with a small to moderate-sized virtual machine. after subsetting and prior to metric calculation. +Where the EWB default events YAML file is loaded in using `ewb.load_cases()`, then applied to an instance of `ewb.evaluation` along with the `EvaluationObject` list. Finally, we run the evaluation with the `.run()` method, where defaults are typically sufficient to run with a small to moderate-sized virtual machine. 
The outputs are returned as a pandas DataFrame and can be manipulated in the script, a notebook, or post-hoc after saving it. + +## Backward Compatibility + +All existing import patterns remain functional: + +```python +from extremeweatherbench import evaluate, inputs, cases, metrics # Still works +from extremeweatherbench.evaluate import ExtremeWeatherBench # Still works +``` diff --git a/src/extremeweatherbench/__init__.py b/src/extremeweatherbench/__init__.py index e69de29b..57ea8cf0 100644 --- a/src/extremeweatherbench/__init__.py +++ b/src/extremeweatherbench/__init__.py @@ -0,0 +1,343 @@ +"""ExtremeWeatherBench: A benchmarking framework for extreme weather forecasts. + +This module provides the public API for ExtremeWeatherBench. Users can import +the package and access all key functionality: + + import extremeweatherbench as ewb + + # Main entry point for evaluation + ewb.evaluation(case_metadata=..., evaluation_objects=...) + + # Hierarchical access via namespace submodules + ewb.targets.ERA5(...) + ewb.forecasts.ZarrForecast(...) + ewb.metrics.MeanAbsoluteError(...) + + # Also available at top level + ewb.ERA5(...) 
+ ewb.load_cases() +""" + +from types import SimpleNamespace + +from extremeweatherbench.cases import ( + CaseOperator, + IndividualCase, + IndividualCaseCollection, + build_case_operators, + load_ewb_events_yaml_into_case_collection, + load_individual_cases, + load_individual_cases_from_yaml, + read_incoming_yaml, +) +from extremeweatherbench.defaults import ( + DEFAULT_COORDINATE_VARIABLES, + DEFAULT_VARIABLE_NAMES, + _preprocess_bb_ar_cira_forecast_dataset, + _preprocess_bb_cira_forecast_dataset, + _preprocess_bb_cira_tc_forecast_dataset, + _preprocess_bb_hres_tc_forecast_dataset, + _preprocess_bb_severe_cira_forecast_dataset, + cira_atmospheric_river_forecast, + cira_freeze_forecast, + cira_heatwave_forecast, + cira_severe_convection_forecast, + cira_tropical_cyclone_forecast, + era5_atmospheric_river_target, + era5_freeze_target, + era5_heatwave_target, + get_brightband_evaluation_objects, + get_climatology, + ghcn_freeze_target, + ghcn_heatwave_target, + ibtracs_target, + lsr_target, + pph_target, +) +from extremeweatherbench.derived import ( + AtmosphericRiverVariables, + CravenBrooksSignificantSevere, + DerivedVariable, + TropicalCycloneTrackVariables, + maybe_derive_variables, + maybe_include_variables_from_derived_input, +) +from extremeweatherbench.evaluate import ExtremeWeatherBench +from extremeweatherbench.inputs import ( + ARCO_ERA5_FULL_URI, + DEFAULT_GHCN_URI, + ERA5, + GHCN, + IBTRACS_URI, + LSR, + LSR_URI, + PPH, + PPH_URI, + CIRA_metadata_variable_mapping, + ERA5_metadata_variable_mapping, + EvaluationObject, + ForecastBase, + HRES_metadata_variable_mapping, + IBTrACS, + IBTrACS_metadata_variable_mapping, + InputBase, + KerchunkForecast, + TargetBase, + XarrayForecast, + ZarrForecast, + align_forecast_to_target, + check_for_missing_data, + maybe_subset_variables, + open_kerchunk_reference, + zarr_target_subsetter, +) +from extremeweatherbench.metrics import ( + Accuracy, + BaseMetric, + CompositeMetric, + CriticalSuccessIndex, + 
DurationMeanError, + EarlySignal, + FalseAlarmRatio, + FalseNegatives, + FalsePositives, + LandfallDisplacement, + LandfallIntensityMeanAbsoluteError, + LandfallMetric, + LandfallTimeMeanError, + MaximumLowestMeanAbsoluteError, + MaximumMeanAbsoluteError, + MeanAbsoluteError, + MeanError, + MeanSquaredError, + MinimumMeanAbsoluteError, + RootMeanSquaredError, + SpatialDisplacement, + ThresholdMetric, + TrueNegatives, + TruePositives, +) +from extremeweatherbench.regions import ( + REGION_TYPES, + BoundingBoxRegion, + CenteredRegion, + Region, + RegionSubsetter, + ShapefileRegion, + map_to_create_region, + subset_cases_to_region, + subset_results_to_region, +) + +# Aliases +evaluation = ExtremeWeatherBench +load_cases = load_ewb_events_yaml_into_case_collection + +# Namespace submodules +targets = SimpleNamespace( + TargetBase=TargetBase, + ERA5=ERA5, + GHCN=GHCN, + IBTrACS=IBTrACS, + LSR=LSR, + PPH=PPH, +) + +forecasts = SimpleNamespace( + ForecastBase=ForecastBase, + ZarrForecast=ZarrForecast, + KerchunkForecast=KerchunkForecast, + XarrayForecast=XarrayForecast, +) + +metrics = SimpleNamespace( + BaseMetric=BaseMetric, + CompositeMetric=CompositeMetric, + ThresholdMetric=ThresholdMetric, + LandfallMetric=LandfallMetric, + MeanSquaredError=MeanSquaredError, + MeanAbsoluteError=MeanAbsoluteError, + MeanError=MeanError, + RootMeanSquaredError=RootMeanSquaredError, + CriticalSuccessIndex=CriticalSuccessIndex, + FalseAlarmRatio=FalseAlarmRatio, + Accuracy=Accuracy, + TruePositives=TruePositives, + FalsePositives=FalsePositives, + TrueNegatives=TrueNegatives, + FalseNegatives=FalseNegatives, + EarlySignal=EarlySignal, + MaximumMeanAbsoluteError=MaximumMeanAbsoluteError, + MinimumMeanAbsoluteError=MinimumMeanAbsoluteError, + MaximumLowestMeanAbsoluteError=MaximumLowestMeanAbsoluteError, + DurationMeanError=DurationMeanError, + SpatialDisplacement=SpatialDisplacement, + LandfallDisplacement=LandfallDisplacement, + LandfallTimeMeanError=LandfallTimeMeanError, + 
LandfallIntensityMeanAbsoluteError=LandfallIntensityMeanAbsoluteError, +) + +cases = SimpleNamespace( + IndividualCase=IndividualCase, + IndividualCaseCollection=IndividualCaseCollection, + CaseOperator=CaseOperator, + build_case_operators=build_case_operators, + load_individual_cases=load_individual_cases, + load_individual_cases_from_yaml=load_individual_cases_from_yaml, + load_ewb_events_yaml_into_case_collection=load_ewb_events_yaml_into_case_collection, + read_incoming_yaml=read_incoming_yaml, +) + +derived = SimpleNamespace( + DerivedVariable=DerivedVariable, + TropicalCycloneTrackVariables=TropicalCycloneTrackVariables, + CravenBrooksSignificantSevere=CravenBrooksSignificantSevere, + AtmosphericRiverVariables=AtmosphericRiverVariables, + maybe_derive_variables=maybe_derive_variables, + maybe_include_variables_from_derived_input=maybe_include_variables_from_derived_input, +) + +regions = SimpleNamespace( + Region=Region, + CenteredRegion=CenteredRegion, + BoundingBoxRegion=BoundingBoxRegion, + ShapefileRegion=ShapefileRegion, + RegionSubsetter=RegionSubsetter, + REGION_TYPES=REGION_TYPES, + map_to_create_region=map_to_create_region, + subset_cases_to_region=subset_cases_to_region, + subset_results_to_region=subset_results_to_region, +) + +defaults = SimpleNamespace( + DEFAULT_COORDINATE_VARIABLES=DEFAULT_COORDINATE_VARIABLES, + DEFAULT_VARIABLE_NAMES=DEFAULT_VARIABLE_NAMES, + get_climatology=get_climatology, + get_brightband_evaluation_objects=get_brightband_evaluation_objects, + _preprocess_bb_cira_forecast_dataset=_preprocess_bb_cira_forecast_dataset, + _preprocess_bb_cira_tc_forecast_dataset=_preprocess_bb_cira_tc_forecast_dataset, + _preprocess_bb_hres_tc_forecast_dataset=_preprocess_bb_hres_tc_forecast_dataset, + _preprocess_bb_ar_cira_forecast_dataset=_preprocess_bb_ar_cira_forecast_dataset, + _preprocess_bb_severe_cira_forecast_dataset=_preprocess_bb_severe_cira_forecast_dataset, + era5_heatwave_target=era5_heatwave_target, + 
era5_freeze_target=era5_freeze_target, + era5_atmospheric_river_target=era5_atmospheric_river_target, + ghcn_heatwave_target=ghcn_heatwave_target, + ghcn_freeze_target=ghcn_freeze_target, + lsr_target=lsr_target, + pph_target=pph_target, + ibtracs_target=ibtracs_target, + cira_heatwave_forecast=cira_heatwave_forecast, + cira_freeze_forecast=cira_freeze_forecast, + cira_tropical_cyclone_forecast=cira_tropical_cyclone_forecast, + cira_atmospheric_river_forecast=cira_atmospheric_river_forecast, + cira_severe_convection_forecast=cira_severe_convection_forecast, +) + +__all__ = [ + "evaluation", + "ExtremeWeatherBench", + "targets", + "forecasts", + "metrics", + "cases", + "derived", + "regions", + "load_cases", + "IndividualCase", + "IndividualCaseCollection", + "CaseOperator", + "build_case_operators", + "load_individual_cases", + "load_individual_cases_from_yaml", + "load_ewb_events_yaml_into_case_collection", + "read_incoming_yaml", + "InputBase", + "ForecastBase", + "TargetBase", + "EvaluationObject", + "ZarrForecast", + "KerchunkForecast", + "XarrayForecast", + "ERA5", + "GHCN", + "IBTrACS", + "LSR", + "PPH", + "ERA5_metadata_variable_mapping", + "CIRA_metadata_variable_mapping", + "HRES_metadata_variable_mapping", + "IBTrACS_metadata_variable_mapping", + "ARCO_ERA5_FULL_URI", + "DEFAULT_GHCN_URI", + "LSR_URI", + "PPH_URI", + "IBTRACS_URI", + "open_kerchunk_reference", + "zarr_target_subsetter", + "align_forecast_to_target", + "maybe_subset_variables", + "check_for_missing_data", + "BaseMetric", + "CompositeMetric", + "ThresholdMetric", + "LandfallMetric", + "MeanSquaredError", + "MeanAbsoluteError", + "MeanError", + "RootMeanSquaredError", + "CriticalSuccessIndex", + "FalseAlarmRatio", + "Accuracy", + "TruePositives", + "FalsePositives", + "TrueNegatives", + "FalseNegatives", + "EarlySignal", + "MaximumMeanAbsoluteError", + "MinimumMeanAbsoluteError", + "MaximumLowestMeanAbsoluteError", + "DurationMeanError", + "SpatialDisplacement", + "LandfallDisplacement", + 
"LandfallTimeMeanError", + "LandfallIntensityMeanAbsoluteError", + "DerivedVariable", + "TropicalCycloneTrackVariables", + "CravenBrooksSignificantSevere", + "AtmosphericRiverVariables", + "maybe_derive_variables", + "maybe_include_variables_from_derived_input", + "Region", + "CenteredRegion", + "BoundingBoxRegion", + "ShapefileRegion", + "RegionSubsetter", + "REGION_TYPES", + "map_to_create_region", + "subset_cases_to_region", + "subset_results_to_region", + "defaults", + "DEFAULT_COORDINATE_VARIABLES", + "DEFAULT_VARIABLE_NAMES", + "get_climatology", + "get_brightband_evaluation_objects", + "_preprocess_bb_cira_forecast_dataset", + "_preprocess_bb_cira_tc_forecast_dataset", + "_preprocess_bb_hres_tc_forecast_dataset", + "_preprocess_bb_ar_cira_forecast_dataset", + "_preprocess_bb_severe_cira_forecast_dataset", + "era5_heatwave_target", + "era5_freeze_target", + "era5_atmospheric_river_target", + "ghcn_heatwave_target", + "ghcn_freeze_target", + "lsr_target", + "pph_target", + "ibtracs_target", + "cira_heatwave_forecast", + "cira_freeze_forecast", + "cira_tropical_cyclone_forecast", + "cira_atmospheric_river_forecast", + "cira_severe_convection_forecast", +] diff --git a/src/extremeweatherbench/cases.py b/src/extremeweatherbench/cases.py index 1c10c7bf..9d6428a7 100644 --- a/src/extremeweatherbench/cases.py +++ b/src/extremeweatherbench/cases.py @@ -14,10 +14,11 @@ import dacite import yaml # type: ignore[import] -from extremeweatherbench import regions +import extremeweatherbench.regions as regions if TYPE_CHECKING: - from extremeweatherbench import inputs, metrics + import extremeweatherbench.inputs as inputs + import extremeweatherbench.metrics as metrics logger = logging.getLogger(__name__) diff --git a/src/extremeweatherbench/defaults.py b/src/extremeweatherbench/defaults.py index 3280fd22..b0f12a12 100644 --- a/src/extremeweatherbench/defaults.py +++ b/src/extremeweatherbench/defaults.py @@ -302,7 +302,7 @@ def get_brightband_evaluation_objects() -> 
list[inputs.EvaluationObject]: routine. """ # Import metrics here to avoid circular import - from extremeweatherbench import metrics + import extremeweatherbench.metrics as metrics heatwave_metric_list: list[metrics.BaseMetric] = [ metrics.MaximumMeanAbsoluteError(), diff --git a/src/extremeweatherbench/derived.py b/src/extremeweatherbench/derived.py index 94c97a8d..f94eefee 100644 --- a/src/extremeweatherbench/derived.py +++ b/src/extremeweatherbench/derived.py @@ -10,7 +10,7 @@ from extremeweatherbench.events import tropical_cyclone if TYPE_CHECKING: - from extremeweatherbench import cases + import extremeweatherbench.cases as cases logger = logging.getLogger(__name__) diff --git a/src/extremeweatherbench/evaluate.py b/src/extremeweatherbench/evaluate.py index ba5d2be5..64fd9cbe 100644 --- a/src/extremeweatherbench/evaluate.py +++ b/src/extremeweatherbench/evaluate.py @@ -15,10 +15,15 @@ from tqdm.contrib.logging import logging_redirect_tqdm from tqdm.dask import TqdmCallback -from extremeweatherbench import cases, derived, inputs, metrics, sources, utils +import extremeweatherbench.cases as cases +import extremeweatherbench.derived as derived +import extremeweatherbench.inputs as inputs +import extremeweatherbench.metrics as metrics +import extremeweatherbench.sources as sources +import extremeweatherbench.utils as utils if TYPE_CHECKING: - from extremeweatherbench import regions + import extremeweatherbench.regions as regions logger = logging.getLogger(__name__) @@ -970,44 +975,3 @@ def _safe_concat( return pd.concat(valid_dfs, ignore_index=ignore_index) else: return pd.DataFrame(columns=OUTPUT_COLUMNS) - - -def _parallel_serial_config_check( - n_jobs: Optional[int] = None, - parallel_config: Optional[dict] = None, -) -> Optional[dict]: - """Check if running in serial or parallel mode. - - Args: - n_jobs: The number of jobs to run in parallel. If None, defaults to the - joblib backend default value. If 1, the workflow will run serially. 
- parallel_config: Optional dictionary of joblib parallel configuration. If - provided, this takes precedence over n_jobs. If not provided and n_jobs is - specified, a default config with loky backend is used. - Returns: - None if running in serial mode, otherwise a dictionary of joblib parallel - configuration. - """ - # Determine if running in serial or parallel mode - # Serial: n_jobs=1 or (parallel_config with n_jobs=1) - # Parallel: n_jobs>1 or (parallel_config with n_jobs>1) - is_serial = ( - (n_jobs == 1) - or (parallel_config is not None and parallel_config.get("n_jobs") == 1) - or (n_jobs is None and parallel_config is None) - ) - logger.debug("Running in %s mode.", "serial" if is_serial else "parallel") - - if not is_serial: - # Build parallel_config if not provided - if parallel_config is None and n_jobs is not None: - logger.debug( - "No parallel_config provided, using loky backend and %s jobs.", - n_jobs, - ) - parallel_config = {"backend": "loky", "n_jobs": n_jobs} - # If running in serial mode, set parallel_config to None if not already - else: - parallel_config = None - # Return the maybe updated kwargs - return parallel_config diff --git a/src/extremeweatherbench/evaluate_cli.py b/src/extremeweatherbench/evaluate_cli.py index eafd263d..3809096f 100644 --- a/src/extremeweatherbench/evaluate_cli.py +++ b/src/extremeweatherbench/evaluate_cli.py @@ -7,7 +7,9 @@ import click import pandas as pd -from extremeweatherbench import cases, defaults, evaluate +import extremeweatherbench.cases as cases +import extremeweatherbench.defaults as defaults +import extremeweatherbench.evaluate as evaluate @click.command() diff --git a/src/extremeweatherbench/inputs.py b/src/extremeweatherbench/inputs.py index eb1c8425..addc2bd1 100644 --- a/src/extremeweatherbench/inputs.py +++ b/src/extremeweatherbench/inputs.py @@ -18,10 +18,13 @@ import polars as pl import xarray as xr -from extremeweatherbench import cases, derived, sources, utils +import extremeweatherbench.cases 
as cases +import extremeweatherbench.derived as derived +import extremeweatherbench.sources as sources +import extremeweatherbench.utils as utils if TYPE_CHECKING: - from extremeweatherbench import metrics + import extremeweatherbench.metrics as metrics logger = logging.getLogger(__name__) diff --git a/src/extremeweatherbench/regions.py b/src/extremeweatherbench/regions.py index 48f8b81f..24842a66 100644 --- a/src/extremeweatherbench/regions.py +++ b/src/extremeweatherbench/regions.py @@ -16,7 +16,7 @@ from extremeweatherbench import utils if TYPE_CHECKING: - from extremeweatherbench import cases + import extremeweatherbench.cases as cases logger = logging.getLogger(__name__) @@ -527,7 +527,7 @@ def subset_case_collection( A new IndividualCaseCollection with cases subset to the region """ # Avoid circular import - from extremeweatherbench import cases + import extremeweatherbench.cases as cases filtered_cases = [] diff --git a/src/extremeweatherbench/sources/base.py b/src/extremeweatherbench/sources/base.py index 0abed12c..3278038b 100644 --- a/src/extremeweatherbench/sources/base.py +++ b/src/extremeweatherbench/sources/base.py @@ -1,7 +1,7 @@ import datetime from typing import Any, Protocol, runtime_checkable -from extremeweatherbench import regions +import extremeweatherbench.regions as regions @runtime_checkable diff --git a/src/extremeweatherbench/sources/pandas_dataframe.py b/src/extremeweatherbench/sources/pandas_dataframe.py index 31bc4062..b6eb91a9 100644 --- a/src/extremeweatherbench/sources/pandas_dataframe.py +++ b/src/extremeweatherbench/sources/pandas_dataframe.py @@ -8,7 +8,7 @@ from extremeweatherbench import utils if TYPE_CHECKING: - from extremeweatherbench import regions + import extremeweatherbench.regions as regions def safely_pull_variables( @@ -43,7 +43,7 @@ def safely_pull_variables( >>> list(result.columns) ['temp'] """ - from extremeweatherbench import defaults + import extremeweatherbench.defaults as defaults # Get column names from 
DataFrame available_columns = list(data.columns) diff --git a/src/extremeweatherbench/sources/polars_lazyframe.py b/src/extremeweatherbench/sources/polars_lazyframe.py index e9e56cf4..f0caa41e 100644 --- a/src/extremeweatherbench/sources/polars_lazyframe.py +++ b/src/extremeweatherbench/sources/polars_lazyframe.py @@ -8,7 +8,7 @@ from extremeweatherbench import utils if TYPE_CHECKING: - from extremeweatherbench import regions + import extremeweatherbench.regions as regions def safely_pull_variables( @@ -47,7 +47,7 @@ def safely_pull_variables( >>> result.collect().columns ['temp'] """ - from extremeweatherbench import defaults + import extremeweatherbench.defaults as defaults # Get column names from LazyFrame available_columns = data.collect_schema().names() diff --git a/src/extremeweatherbench/sources/xarray_dataarray.py b/src/extremeweatherbench/sources/xarray_dataarray.py index e58d82d6..f3b4e734 100644 --- a/src/extremeweatherbench/sources/xarray_dataarray.py +++ b/src/extremeweatherbench/sources/xarray_dataarray.py @@ -5,7 +5,8 @@ import pandas as pd import xarray as xr -from extremeweatherbench import regions, utils +import extremeweatherbench.regions as regions +import extremeweatherbench.utils as utils def safely_pull_variables( diff --git a/src/extremeweatherbench/sources/xarray_dataset.py b/src/extremeweatherbench/sources/xarray_dataset.py index 56d52618..ae8e8b89 100644 --- a/src/extremeweatherbench/sources/xarray_dataset.py +++ b/src/extremeweatherbench/sources/xarray_dataset.py @@ -9,7 +9,7 @@ from extremeweatherbench import utils if TYPE_CHECKING: - from extremeweatherbench import regions + import extremeweatherbench.regions as regions def safely_pull_variables( From a0dcb134041a7fe2fa87d9ef78eef8ab705957c0 Mon Sep 17 00:00:00 2001 From: taylor Date: Fri, 23 Jan 2026 22:26:03 +0000 Subject: [PATCH 07/13] ruff/linting. 
add utils to init --- data_prep/ibtracs_bounds.py | 8 ++- src/extremeweatherbench/__init__.py | 105 ++++++++++++++++++++++------ tests/test_evaluate.py | 2 +- 3 files changed, 90 insertions(+), 25 deletions(-) diff --git a/data_prep/ibtracs_bounds.py b/data_prep/ibtracs_bounds.py index b8513adf..0dc962d9 100644 --- a/data_prep/ibtracs_bounds.py +++ b/data_prep/ibtracs_bounds.py @@ -4,6 +4,7 @@ import logging import re from importlib import resources +from typing import TYPE_CHECKING import cartopy.crs as ccrs import cartopy.feature as cfeature @@ -17,6 +18,9 @@ import extremeweatherbench as ewb import extremeweatherbench.data +if TYPE_CHECKING: + from extremeweatherbench.regions import Region + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -67,7 +71,7 @@ def calculate_extent_bounds( bottom_lat: float, top_lat: float, extent_buffer: float = 250, -) -> ewb.regions.Region: +) -> Region: """Calculate extent bounds with buffer. Args: @@ -177,7 +181,7 @@ def load_and_process_ibtracs_data(): # Get all storms from 2020 - 2025 seasons all_storms_2020_2025_lf = IBTRACS_lf.filter( (pl.col("SEASON").cast(pl.Int32) >= 2020) - ).select(inputs.IBTrACS_metadata_variable_mapping.values()) + ).select(ewb.inputs.IBTrACS_metadata_variable_mapping.values()) schema = all_storms_2020_2025_lf.collect_schema() # Convert pressure and surface wind columns to float, replacing " " with null diff --git a/src/extremeweatherbench/__init__.py b/src/extremeweatherbench/__init__.py index 57ea8cf0..e28885d8 100644 --- a/src/extremeweatherbench/__init__.py +++ b/src/extremeweatherbench/__init__.py @@ -20,12 +20,24 @@ from types import SimpleNamespace +from extremeweatherbench.calc import ( + convert_from_cartesian_to_latlon, + geopotential_thickness, + great_circle_mask, + haversine_distance, + maybe_calculate_wind_speed, + mixing_ratio, + orography, + pressure_at_surface, + saturation_mixing_ratio, + saturation_vapor_pressure, + 
specific_humidity_from_relative_humidity, +) from extremeweatherbench.cases import ( CaseOperator, IndividualCase, - IndividualCaseCollection, build_case_operators, - load_ewb_events_yaml_into_case_collection, + load_ewb_events_yaml_into_case_list, load_individual_cases, load_individual_cases_from_yaml, read_incoming_yaml, @@ -33,11 +45,6 @@ from extremeweatherbench.defaults import ( DEFAULT_COORDINATE_VARIABLES, DEFAULT_VARIABLE_NAMES, - _preprocess_bb_ar_cira_forecast_dataset, - _preprocess_bb_cira_forecast_dataset, - _preprocess_bb_cira_tc_forecast_dataset, - _preprocess_bb_hres_tc_forecast_dataset, - _preprocess_bb_severe_cira_forecast_dataset, cira_atmospheric_river_forecast, cira_freeze_forecast, cira_heatwave_forecast, @@ -128,10 +135,50 @@ subset_cases_to_region, subset_results_to_region, ) +from extremeweatherbench.utils import ( + check_for_vars, + convert_day_yearofday_to_time, + convert_init_time_to_valid_time, + convert_longitude_to_180, + convert_longitude_to_360, + convert_valid_time_to_init_time, + derive_indices_from_init_time_and_lead_time, + determine_temporal_resolution, + extract_tc_names, + filter_kwargs_for_callable, + find_common_init_times, + idx_to_coords, + interp_climatology_to_target, + is_valid_landfall, + load_land_geometry, + maybe_cache_and_compute, + maybe_densify_dataarray, + maybe_get_closest_timestamp_to_center_of_valid_times, + maybe_get_operator, + min_if_all_timesteps_present, + min_if_all_timesteps_present_forecast, + read_event_yaml, + remove_ocean_gridpoints, + stack_dataarray_from_dims, +) + +calc = SimpleNamespace( + geopotential_thickness=geopotential_thickness, + specific_humidity_from_relative_humidity=specific_humidity_from_relative_humidity, + convert_from_cartesian_to_latlon=convert_from_cartesian_to_latlon, + great_circle_mask=great_circle_mask, + maybe_calculate_wind_speed=maybe_calculate_wind_speed, + mixing_ratio=mixing_ratio, + orography=orography, + pressure_at_surface=pressure_at_surface, + 
saturation_mixing_ratio=saturation_mixing_ratio, + saturation_vapor_pressure=saturation_vapor_pressure, + haversine_distance=haversine_distance, +) # Aliases evaluation = ExtremeWeatherBench -load_cases = load_ewb_events_yaml_into_case_collection +load_cases = load_ewb_events_yaml_into_case_list # Namespace submodules targets = SimpleNamespace( @@ -179,12 +226,11 @@ cases = SimpleNamespace( IndividualCase=IndividualCase, - IndividualCaseCollection=IndividualCaseCollection, CaseOperator=CaseOperator, build_case_operators=build_case_operators, load_individual_cases=load_individual_cases, load_individual_cases_from_yaml=load_individual_cases_from_yaml, - load_ewb_events_yaml_into_case_collection=load_ewb_events_yaml_into_case_collection, + load_ewb_events_yaml_into_case_list=load_ewb_events_yaml_into_case_list, read_incoming_yaml=read_incoming_yaml, ) @@ -214,11 +260,6 @@ DEFAULT_VARIABLE_NAMES=DEFAULT_VARIABLE_NAMES, get_climatology=get_climatology, get_brightband_evaluation_objects=get_brightband_evaluation_objects, - _preprocess_bb_cira_forecast_dataset=_preprocess_bb_cira_forecast_dataset, - _preprocess_bb_cira_tc_forecast_dataset=_preprocess_bb_cira_tc_forecast_dataset, - _preprocess_bb_hres_tc_forecast_dataset=_preprocess_bb_hres_tc_forecast_dataset, - _preprocess_bb_ar_cira_forecast_dataset=_preprocess_bb_ar_cira_forecast_dataset, - _preprocess_bb_severe_cira_forecast_dataset=_preprocess_bb_severe_cira_forecast_dataset, era5_heatwave_target=era5_heatwave_target, era5_freeze_target=era5_freeze_target, era5_atmospheric_river_target=era5_atmospheric_river_target, @@ -234,6 +275,32 @@ cira_severe_convection_forecast=cira_severe_convection_forecast, ) +utils = SimpleNamespace( + maybe_get_operator=maybe_get_operator, + find_common_init_times=find_common_init_times, + is_valid_landfall=is_valid_landfall, + load_land_geometry=load_land_geometry, + maybe_cache_and_compute=maybe_cache_and_compute, + maybe_densify_dataarray=maybe_densify_dataarray, + 
maybe_get_closest_timestamp_to_center_of_valid_times=maybe_get_closest_timestamp_to_center_of_valid_times, + min_if_all_timesteps_present=min_if_all_timesteps_present, + min_if_all_timesteps_present_forecast=min_if_all_timesteps_present_forecast, + determine_temporal_resolution=determine_temporal_resolution, + convert_init_time_to_valid_time=convert_init_time_to_valid_time, + convert_valid_time_to_init_time=convert_valid_time_to_init_time, + convert_day_yearofday_to_time=convert_day_yearofday_to_time, + interp_climatology_to_target=interp_climatology_to_target, + check_for_vars=check_for_vars, + idx_to_coords=idx_to_coords, + extract_tc_names=extract_tc_names, + stack_dataarray_from_dims=stack_dataarray_from_dims, + convert_longitude_to_360=convert_longitude_to_360, + convert_longitude_to_180=convert_longitude_to_180, + derive_indices_from_init_time_and_lead_time=derive_indices_from_init_time_and_lead_time, + filter_kwargs_for_callable=filter_kwargs_for_callable, + remove_ocean_gridpoints=remove_ocean_gridpoints, + read_event_yaml=read_event_yaml, +) __all__ = [ "evaluation", "ExtremeWeatherBench", @@ -245,12 +312,11 @@ "regions", "load_cases", "IndividualCase", - "IndividualCaseCollection", "CaseOperator", "build_case_operators", "load_individual_cases", "load_individual_cases_from_yaml", - "load_ewb_events_yaml_into_case_collection", + "load_ewb_events_yaml_into_case_list", "read_incoming_yaml", "InputBase", "ForecastBase", @@ -322,11 +388,6 @@ "DEFAULT_VARIABLE_NAMES", "get_climatology", "get_brightband_evaluation_objects", - "_preprocess_bb_cira_forecast_dataset", - "_preprocess_bb_cira_tc_forecast_dataset", - "_preprocess_bb_hres_tc_forecast_dataset", - "_preprocess_bb_ar_cira_forecast_dataset", - "_preprocess_bb_severe_cira_forecast_dataset", "era5_heatwave_target", "era5_freeze_target", "era5_atmospheric_river_target", diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 5ff808ff..18569e6b 100644 --- a/tests/test_evaluate.py +++ 
b/tests/test_evaluate.py @@ -1394,7 +1394,7 @@ def test_run_pipeline_target( def test_run_pipeline_invalid_source(self, sample_case_operator): """Test run_pipeline function with invalid input source.""" with pytest.raises(AttributeError, match="'str' object has no attribute"): - evaluate.run_pipeline(sample_case_operator.case_metadata, "invalid") # type: ignore + evaluate.run_pipeline(sample_case_operator.case_metadata, "invalid") # type: ignore def test_maybe_cache_and_compute_with_cache_dir( self, sample_forecast_dataset, sample_target_dataset, sample_individual_case From deced3e8bad050c167e66c02af66cfadde7e5fd9 Mon Sep 17 00:00:00 2001 From: taylor Date: Sat, 24 Jan 2026 02:04:21 +0000 Subject: [PATCH 08/13] add test coverage for module loading patterns --- src/extremeweatherbench/__init__.py | 307 +++++++++++----------------- tests/test_evaluate_cli.py | 4 +- tests/test_init.py | 261 +++++++++++++++++++++++ 3 files changed, 385 insertions(+), 187 deletions(-) create mode 100644 tests/test_init.py diff --git a/src/extremeweatherbench/__init__.py b/src/extremeweatherbench/__init__.py index e28885d8..8f1ac54a 100644 --- a/src/extremeweatherbench/__init__.py +++ b/src/extremeweatherbench/__init__.py @@ -20,6 +20,10 @@ from types import SimpleNamespace +# Import actual modules for backwards compatibility +from extremeweatherbench import calc, cases, defaults, derived, metrics, regions, utils + +# Import specific items for top-level access from extremeweatherbench.calc import ( convert_from_cartesian_to_latlon, geopotential_thickness, @@ -162,25 +166,11 @@ stack_dataarray_from_dims, ) -calc = SimpleNamespace( - geopotential_thickness=geopotential_thickness, - specific_humidity_from_relative_humidity=specific_humidity_from_relative_humidity, - convert_from_cartesian_to_latlon=convert_from_cartesian_to_latlon, - great_circle_mask=great_circle_mask, - maybe_calculate_wind_speed=maybe_calculate_wind_speed, - mixing_ratio=mixing_ratio, - orography=orography, - 
pressure_at_surface=pressure_at_surface, - saturation_mixing_ratio=saturation_mixing_ratio, - saturation_vapor_pressure=saturation_vapor_pressure, - haversine_distance=haversine_distance, -) - # Aliases evaluation = ExtremeWeatherBench load_cases = load_ewb_events_yaml_into_case_list -# Namespace submodules +# Namespace submodules for convenient grouping (these don't shadow actual modules) targets = SimpleNamespace( TargetBase=TargetBase, ERA5=ERA5, @@ -197,208 +187,153 @@ XarrayForecast=XarrayForecast, ) -metrics = SimpleNamespace( - BaseMetric=BaseMetric, - CompositeMetric=CompositeMetric, - ThresholdMetric=ThresholdMetric, - LandfallMetric=LandfallMetric, - MeanSquaredError=MeanSquaredError, - MeanAbsoluteError=MeanAbsoluteError, - MeanError=MeanError, - RootMeanSquaredError=RootMeanSquaredError, - CriticalSuccessIndex=CriticalSuccessIndex, - FalseAlarmRatio=FalseAlarmRatio, - Accuracy=Accuracy, - TruePositives=TruePositives, - FalsePositives=FalsePositives, - TrueNegatives=TrueNegatives, - FalseNegatives=FalseNegatives, - EarlySignal=EarlySignal, - MaximumMeanAbsoluteError=MaximumMeanAbsoluteError, - MinimumMeanAbsoluteError=MinimumMeanAbsoluteError, - MaximumLowestMeanAbsoluteError=MaximumLowestMeanAbsoluteError, - DurationMeanError=DurationMeanError, - SpatialDisplacement=SpatialDisplacement, - LandfallDisplacement=LandfallDisplacement, - LandfallTimeMeanError=LandfallTimeMeanError, - LandfallIntensityMeanAbsoluteError=LandfallIntensityMeanAbsoluteError, -) - -cases = SimpleNamespace( - IndividualCase=IndividualCase, - CaseOperator=CaseOperator, - build_case_operators=build_case_operators, - load_individual_cases=load_individual_cases, - load_individual_cases_from_yaml=load_individual_cases_from_yaml, - load_ewb_events_yaml_into_case_list=load_ewb_events_yaml_into_case_list, - read_incoming_yaml=read_incoming_yaml, -) - -derived = SimpleNamespace( - DerivedVariable=DerivedVariable, - TropicalCycloneTrackVariables=TropicalCycloneTrackVariables, - 
CravenBrooksSignificantSevere=CravenBrooksSignificantSevere, - AtmosphericRiverVariables=AtmosphericRiverVariables, - maybe_derive_variables=maybe_derive_variables, - maybe_include_variables_from_derived_input=maybe_include_variables_from_derived_input, -) - -regions = SimpleNamespace( - Region=Region, - CenteredRegion=CenteredRegion, - BoundingBoxRegion=BoundingBoxRegion, - ShapefileRegion=ShapefileRegion, - RegionSubsetter=RegionSubsetter, - REGION_TYPES=REGION_TYPES, - map_to_create_region=map_to_create_region, - subset_cases_to_region=subset_cases_to_region, - subset_results_to_region=subset_results_to_region, -) - -defaults = SimpleNamespace( - DEFAULT_COORDINATE_VARIABLES=DEFAULT_COORDINATE_VARIABLES, - DEFAULT_VARIABLE_NAMES=DEFAULT_VARIABLE_NAMES, - get_climatology=get_climatology, - get_brightband_evaluation_objects=get_brightband_evaluation_objects, - era5_heatwave_target=era5_heatwave_target, - era5_freeze_target=era5_freeze_target, - era5_atmospheric_river_target=era5_atmospheric_river_target, - ghcn_heatwave_target=ghcn_heatwave_target, - ghcn_freeze_target=ghcn_freeze_target, - lsr_target=lsr_target, - pph_target=pph_target, - ibtracs_target=ibtracs_target, - cira_heatwave_forecast=cira_heatwave_forecast, - cira_freeze_forecast=cira_freeze_forecast, - cira_tropical_cyclone_forecast=cira_tropical_cyclone_forecast, - cira_atmospheric_river_forecast=cira_atmospheric_river_forecast, - cira_severe_convection_forecast=cira_severe_convection_forecast, -) - -utils = SimpleNamespace( - maybe_get_operator=maybe_get_operator, - find_common_init_times=find_common_init_times, - is_valid_landfall=is_valid_landfall, - load_land_geometry=load_land_geometry, - maybe_cache_and_compute=maybe_cache_and_compute, - maybe_densify_dataarray=maybe_densify_dataarray, - maybe_get_closest_timestamp_to_center_of_valid_times=maybe_get_closest_timestamp_to_center_of_valid_times, - min_if_all_timesteps_present=min_if_all_timesteps_present, - 
min_if_all_timesteps_present_forecast=min_if_all_timesteps_present_forecast, - determine_temporal_resolution=determine_temporal_resolution, - convert_init_time_to_valid_time=convert_init_time_to_valid_time, - convert_valid_time_to_init_time=convert_valid_time_to_init_time, - convert_day_yearofday_to_time=convert_day_yearofday_to_time, - interp_climatology_to_target=interp_climatology_to_target, - check_for_vars=check_for_vars, - idx_to_coords=idx_to_coords, - extract_tc_names=extract_tc_names, - stack_dataarray_from_dims=stack_dataarray_from_dims, - convert_longitude_to_360=convert_longitude_to_360, - convert_longitude_to_180=convert_longitude_to_180, - derive_indices_from_init_time_and_lead_time=derive_indices_from_init_time_and_lead_time, - filter_kwargs_for_callable=filter_kwargs_for_callable, - remove_ocean_gridpoints=remove_ocean_gridpoints, - read_event_yaml=read_event_yaml, -) __all__ = [ + # Core evaluation "evaluation", "ExtremeWeatherBench", - "targets", - "forecasts", - "metrics", + # Modules + "calc", "cases", + "defaults", "derived", + "metrics", "regions", + "utils", + # Namespace submodules + "targets", + "forecasts", + # Aliases "load_cases", - "IndividualCase", + # calc + "convert_from_cartesian_to_latlon", + "geopotential_thickness", + "great_circle_mask", + "haversine_distance", + "maybe_calculate_wind_speed", + "mixing_ratio", + "orography", + "pressure_at_surface", + "saturation_mixing_ratio", + "saturation_vapor_pressure", + "specific_humidity_from_relative_humidity", + # cases "CaseOperator", + "IndividualCase", "build_case_operators", + "load_ewb_events_yaml_into_case_list", "load_individual_cases", "load_individual_cases_from_yaml", - "load_ewb_events_yaml_into_case_list", "read_incoming_yaml", - "InputBase", - "ForecastBase", - "TargetBase", - "EvaluationObject", - "ZarrForecast", - "KerchunkForecast", - "XarrayForecast", + # defaults + "DEFAULT_COORDINATE_VARIABLES", + "DEFAULT_VARIABLE_NAMES", + "cira_atmospheric_river_forecast", + 
"cira_freeze_forecast", + "cira_heatwave_forecast", + "cira_severe_convection_forecast", + "cira_tropical_cyclone_forecast", + "era5_atmospheric_river_target", + "era5_freeze_target", + "era5_heatwave_target", + "get_brightband_evaluation_objects", + "get_climatology", + "ghcn_freeze_target", + "ghcn_heatwave_target", + "ibtracs_target", + "lsr_target", + "pph_target", + # derived + "AtmosphericRiverVariables", + "CravenBrooksSignificantSevere", + "DerivedVariable", + "TropicalCycloneTrackVariables", + "maybe_derive_variables", + "maybe_include_variables_from_derived_input", + # inputs + "ARCO_ERA5_FULL_URI", + "CIRA_metadata_variable_mapping", + "DEFAULT_GHCN_URI", "ERA5", - "GHCN", - "IBTrACS", - "LSR", - "PPH", "ERA5_metadata_variable_mapping", - "CIRA_metadata_variable_mapping", + "EvaluationObject", + "ForecastBase", + "GHCN", "HRES_metadata_variable_mapping", + "IBTrACS", "IBTrACS_metadata_variable_mapping", - "ARCO_ERA5_FULL_URI", - "DEFAULT_GHCN_URI", + "IBTRACS_URI", + "InputBase", + "KerchunkForecast", + "LSR", "LSR_URI", + "PPH", "PPH_URI", - "IBTRACS_URI", - "open_kerchunk_reference", - "zarr_target_subsetter", + "TargetBase", + "XarrayForecast", + "ZarrForecast", "align_forecast_to_target", - "maybe_subset_variables", "check_for_missing_data", + "maybe_subset_variables", + "open_kerchunk_reference", + "zarr_target_subsetter", + # metrics + "Accuracy", "BaseMetric", "CompositeMetric", - "ThresholdMetric", - "LandfallMetric", - "MeanSquaredError", - "MeanAbsoluteError", - "MeanError", - "RootMeanSquaredError", "CriticalSuccessIndex", + "DurationMeanError", + "EarlySignal", "FalseAlarmRatio", - "Accuracy", - "TruePositives", - "FalsePositives", - "TrueNegatives", "FalseNegatives", - "EarlySignal", + "FalsePositives", + "LandfallDisplacement", + "LandfallIntensityMeanAbsoluteError", + "LandfallMetric", + "LandfallTimeMeanError", + "MaximumLowestMeanAbsoluteError", "MaximumMeanAbsoluteError", + "MeanAbsoluteError", + "MeanError", + "MeanSquaredError", 
"MinimumMeanAbsoluteError", - "MaximumLowestMeanAbsoluteError", - "DurationMeanError", + "RootMeanSquaredError", "SpatialDisplacement", - "LandfallDisplacement", - "LandfallTimeMeanError", - "LandfallIntensityMeanAbsoluteError", - "DerivedVariable", - "TropicalCycloneTrackVariables", - "CravenBrooksSignificantSevere", - "AtmosphericRiverVariables", - "maybe_derive_variables", - "maybe_include_variables_from_derived_input", - "Region", - "CenteredRegion", + "ThresholdMetric", + "TrueNegatives", + "TruePositives", + # regions "BoundingBoxRegion", - "ShapefileRegion", - "RegionSubsetter", + "CenteredRegion", "REGION_TYPES", + "Region", + "RegionSubsetter", + "ShapefileRegion", "map_to_create_region", "subset_cases_to_region", "subset_results_to_region", - "defaults", - "DEFAULT_COORDINATE_VARIABLES", - "DEFAULT_VARIABLE_NAMES", - "get_climatology", - "get_brightband_evaluation_objects", - "era5_heatwave_target", - "era5_freeze_target", - "era5_atmospheric_river_target", - "ghcn_heatwave_target", - "ghcn_freeze_target", - "lsr_target", - "pph_target", - "ibtracs_target", - "cira_heatwave_forecast", - "cira_freeze_forecast", - "cira_tropical_cyclone_forecast", - "cira_atmospheric_river_forecast", - "cira_severe_convection_forecast", + # utils + "check_for_vars", + "convert_day_yearofday_to_time", + "convert_init_time_to_valid_time", + "convert_longitude_to_180", + "convert_longitude_to_360", + "convert_valid_time_to_init_time", + "derive_indices_from_init_time_and_lead_time", + "determine_temporal_resolution", + "extract_tc_names", + "filter_kwargs_for_callable", + "find_common_init_times", + "idx_to_coords", + "interp_climatology_to_target", + "is_valid_landfall", + "load_land_geometry", + "maybe_cache_and_compute", + "maybe_densify_dataarray", + "maybe_get_closest_timestamp_to_center_of_valid_times", + "maybe_get_operator", + "min_if_all_timesteps_present", + "min_if_all_timesteps_present_forecast", + "read_event_yaml", + "remove_ocean_gridpoints", + 
"stack_dataarray_from_dims", ] diff --git a/tests/test_evaluate_cli.py b/tests/test_evaluate_cli.py index 11c87493..2bda6879 100644 --- a/tests/test_evaluate_cli.py +++ b/tests/test_evaluate_cli.py @@ -469,7 +469,9 @@ def test_empty_results_handling(self, mock_ewb_class, mock_load_cases, runner): class TestHelperFunctions: """Test helper function functionality.""" - @mock.patch("extremeweatherbench.cases.load_ewb_events_yaml_into_case_list") + @mock.patch( + "extremeweatherbench.evaluate_cli.cases.load_ewb_events_yaml_into_case_list" + ) def test_load_default_cases(self, mock_load_yaml): """Test _load_default_cases function.""" mock_cases = [{"id": 1}] diff --git a/tests/test_init.py b/tests/test_init.py new file mode 100644 index 00000000..d66b84a8 --- /dev/null +++ b/tests/test_init.py @@ -0,0 +1,261 @@ +"""Tests for the extremeweatherbench package __init__.py API.""" + +import types + +import pytest + + +class TestModuleImports: + """Test that submodules are importable and are actual modules.""" + + def test_calc_is_module(self): + """Test that calc is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import calc + + assert isinstance(calc, types.ModuleType) + + def test_utils_is_module(self): + """Test that utils is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import utils + + assert isinstance(utils, types.ModuleType) + + def test_metrics_is_module(self): + """Test that metrics is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import metrics + + assert isinstance(metrics, types.ModuleType) + + def test_regions_is_module(self): + """Test that regions is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import regions + + assert isinstance(regions, types.ModuleType) + + def test_derived_is_module(self): + """Test that derived is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import derived + + assert isinstance(derived, types.ModuleType) + + def 
test_defaults_is_module(self): + """Test that defaults is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import defaults + + assert isinstance(defaults, types.ModuleType) + + def test_cases_is_module(self): + """Test that cases is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import cases + + assert isinstance(cases, types.ModuleType) + + +class TestModuleAccessPatterns: + """Test both import patterns work identically.""" + + def test_ewb_dot_notation_equals_direct_import_calc(self): + """Test ewb.calc is the same object as direct import.""" + import extremeweatherbench as ewb + from extremeweatherbench import calc + + assert ewb.calc is calc + + def test_ewb_dot_notation_equals_direct_import_metrics(self): + """Test ewb.metrics is the same object as direct import.""" + import extremeweatherbench as ewb + from extremeweatherbench import metrics + + assert ewb.metrics is metrics + + def test_ewb_dot_notation_equals_direct_import_utils(self): + """Test ewb.utils is the same object as direct import.""" + import extremeweatherbench as ewb + from extremeweatherbench import utils + + assert ewb.utils is utils + + +class TestModuleLevelConstants: + """Test that module-level constants are accessible.""" + + def test_calc_g0_accessible(self): + """Test that calc.g0 constant is accessible.""" + from extremeweatherbench import calc + + assert hasattr(calc, "g0") + assert calc.g0 == 9.80665 + + def test_calc_epsilon_accessible(self): + """Test that calc.epsilon constant is accessible.""" + from extremeweatherbench import calc + + assert hasattr(calc, "epsilon") + assert isinstance(calc.epsilon, float) + + +class TestPrivateFunctionAccess: + """Test that private functions are accessible for testing purposes.""" + + def test_calc_private_functions_accessible(self): + """Test that private functions in calc are accessible.""" + from extremeweatherbench import calc + + assert hasattr(calc, "_is_true_landfall") + assert hasattr(calc, 
"_detect_landfalls_wrapper") + assert hasattr(calc, "_mask_init_time_boundaries") + assert hasattr(calc, "_interpolate_and_format_landfalls") + + def test_utils_private_functions_accessible(self): + """Test that private functions in utils are accessible.""" + from extremeweatherbench import utils + + assert hasattr(utils, "_create_nan_dataarray") + assert hasattr(utils, "_cache_maybe_densify_helper") + + def test_derived_private_functions_accessible(self): + """Test that private functions in derived are accessible.""" + from extremeweatherbench import derived + + assert hasattr(derived, "_maybe_convert_variable_to_string") + + def test_defaults_private_functions_accessible(self): + """Test that private functions in defaults are accessible.""" + from extremeweatherbench import defaults + + assert hasattr(defaults, "_preprocess_cira_forecast_dataset") + + def test_regions_private_functions_accessible(self): + """Test that private functions in regions are accessible.""" + from extremeweatherbench import regions + + assert hasattr(regions, "_adjust_bounds_to_dataset_convention") + + +class TestPublicFunctionAccess: + """Test that all public functions are accessible via module.""" + + def test_calc_public_functions(self): + """Test public functions in calc are accessible.""" + from extremeweatherbench import calc + + assert hasattr(calc, "find_landfalls") + assert hasattr(calc, "nantrapezoid") + assert hasattr(calc, "dewpoint_from_specific_humidity") + assert hasattr(calc, "find_land_intersection") + assert hasattr(calc, "haversine_distance") + + def test_utils_public_functions(self): + """Test public functions in utils are accessible.""" + from extremeweatherbench import utils + + assert hasattr(utils, "reduce_dataarray") + assert hasattr(utils, "stack_dataarray_from_dims") + assert hasattr(utils, "convert_longitude_to_360") + + +class TestTopLevelImports: + """Test that top-level imports work for commonly used items.""" + + def test_top_level_metric_imports(self): + 
"""Test that metrics can be imported at top level.""" + from extremeweatherbench import ( + MeanAbsoluteError, + MeanError, + MeanSquaredError, + RootMeanSquaredError, + ) + + assert MeanAbsoluteError is not None + assert MeanError is not None + assert MeanSquaredError is not None + assert RootMeanSquaredError is not None + + def test_top_level_input_imports(self): + """Test that input classes can be imported at top level.""" + from extremeweatherbench import ERA5, GHCN, IBTrACS, ZarrForecast + + assert ERA5 is not None + assert GHCN is not None + assert IBTrACS is not None + assert ZarrForecast is not None + + def test_top_level_region_imports(self): + """Test that region classes can be imported at top level.""" + from extremeweatherbench import BoundingBoxRegion, CenteredRegion, Region + + assert Region is not None + assert BoundingBoxRegion is not None + assert CenteredRegion is not None + + def test_top_level_case_imports(self): + """Test that case classes can be imported at top level.""" + from extremeweatherbench import CaseOperator, IndividualCase + + assert IndividualCase is not None + assert CaseOperator is not None + + def test_evaluation_alias(self): + """Test that evaluation alias works.""" + from extremeweatherbench import ExtremeWeatherBench, evaluation + + assert evaluation is ExtremeWeatherBench + + def test_load_cases_alias(self): + """Test that load_cases alias works.""" + from extremeweatherbench import ( + load_cases, + load_ewb_events_yaml_into_case_list, + ) + + assert load_cases is load_ewb_events_yaml_into_case_list + + +class TestNamespaceSubmodules: + """Test the convenience namespace submodules.""" + + def test_targets_namespace(self): + """Test targets SimpleNamespace contains expected items.""" + from extremeweatherbench import targets + + assert isinstance(targets, types.SimpleNamespace) + assert hasattr(targets, "ERA5") + assert hasattr(targets, "GHCN") + assert hasattr(targets, "IBTrACS") + assert hasattr(targets, "TargetBase") + + 
def test_forecasts_namespace(self): + """Test forecasts SimpleNamespace contains expected items.""" + from extremeweatherbench import forecasts + + assert isinstance(forecasts, types.SimpleNamespace) + assert hasattr(forecasts, "ZarrForecast") + assert hasattr(forecasts, "KerchunkForecast") + assert hasattr(forecasts, "ForecastBase") + + +class TestMockPatching: + """Test that mock.patch.object works with module imports.""" + + def test_mock_patch_object_on_calc(self): + """Test that mock.patch.object works on calc module.""" + from unittest import mock + + from extremeweatherbench import calc + + with mock.patch.object(calc, "haversine_distance") as mock_func: + mock_func.return_value = 42.0 + result = calc.haversine_distance([0, 0], [1, 1]) + assert result == 42.0 + mock_func.assert_called_once() + + def test_mock_patch_string_on_calc(self): + """Test that mock.patch with string path works on calc module.""" + from unittest import mock + + with mock.patch("extremeweatherbench.calc.haversine_distance") as mock_func: + mock_func.return_value = 100.0 + from extremeweatherbench import calc + + result = calc.haversine_distance([0, 0], [1, 1]) + assert result == 100.0 From 0a4a653e84545d98724ae8210aa2528e122099fe Mon Sep 17 00:00:00 2001 From: taylor Date: Sat, 24 Jan 2026 02:15:22 +0000 Subject: [PATCH 09/13] ruff --- tests/test_init.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_init.py b/tests/test_init.py index d66b84a8..00517ae8 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -2,8 +2,6 @@ import types -import pytest - class TestModuleImports: """Test that submodules are importable and are actual modules.""" From 6a12dcdc934fc8ee52fc17570bd4398a4f60197e Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Fri, 23 Jan 2026 21:25:26 -0500 Subject: [PATCH 10/13] Cleanup docstrings in repo (#318) * update these docstrings * remove docstring changes markdown * update docstrings * update other docstrings * remove individualcasecollection 
reference, update based on develop changes --- src/extremeweatherbench/calc.py | 2 +- src/extremeweatherbench/cases.py | 62 ++- src/extremeweatherbench/derived.py | 96 ++-- src/extremeweatherbench/evaluate.py | 15 +- src/extremeweatherbench/evaluate_cli.py | 19 +- src/extremeweatherbench/inputs.py | 94 +++- src/extremeweatherbench/metrics.py | 581 +++++++++++++++--------- src/extremeweatherbench/regions.py | 96 ++-- src/extremeweatherbench/sources/base.py | 12 +- 9 files changed, 597 insertions(+), 380 deletions(-) diff --git a/src/extremeweatherbench/calc.py b/src/extremeweatherbench/calc.py index ef314349..28fcf94f 100644 --- a/src/extremeweatherbench/calc.py +++ b/src/extremeweatherbench/calc.py @@ -259,7 +259,7 @@ def geopotential_thickness( pressure_dim: The name of the pressure dimension. Default is "level". Returns: - The geopotential thickness in metersas an xarray DataArray. + The geopotential thickness in meters as an xarray DataArray. """ geopotential_heights = da.sel({pressure_dim: top_level}) geopotential_height_bottom = da.sel({pressure_dim: bottom_level}) diff --git a/src/extremeweatherbench/cases.py b/src/extremeweatherbench/cases.py index 42a6370f..236104c9 100644 --- a/src/extremeweatherbench/cases.py +++ b/src/extremeweatherbench/cases.py @@ -25,18 +25,15 @@ @dataclasses.dataclass class IndividualCase: - """Container for metadata defining a single or individual case. - - An IndividualCase defines the relevant metadata for a single case study for a - given extreme weather event; it is designed to be easily instantiable through a - simple YAML-based configuration file. + """Container for metadata defining a single case study. Attributes: - case_id_number: A unique numerical identifier for the event. - start_date: The start date of the case, for use in subsetting data for analysis. - end_date: The end date of the case, for use in subsetting data for analysis. - location: A Location dataclass representing the location of a case. 
- event_type: A string representing the type of extreme weather event. + case_id_number: Unique numerical identifier for the event. + title: Title of the case study. + start_date: Start date for subsetting data for analysis. + end_date: End date for subsetting data for analysis. + location: Region object representing the case location. + event_type: String representing the type of extreme weather event. """ case_id_number: int @@ -49,18 +46,13 @@ class IndividualCase: @dataclasses.dataclass class CaseOperator: - """A class which stores the graph to process an individual case. - - This class is used to store the graph to process an individual case. The purpose of - this class is to be a one-stop-shop for the evaluation of a single case. Multiple - CaseOperators can be run in parallel to evaluate multiple cases, or run through the - ExtremeWeatherBench.run() method to evaluate all cases in an evaluation in serial. + """Operator dataclass for an evaluation of a single evaluation object. Attributes: - case_metadata: IndividualCase metadata - metric_list: A list of metrics that are to be evaluated for the case operator - target_config: A TargetConfig object - forecast_config: A ForecastConfig object + case_metadata: IndividualCase metadata for this operator. + metric_list: List of metrics to evaluate for this case. + target: TargetBase object for ground truth data. + forecast: ForecastBase object for forecast data. """ case_metadata: IndividualCase @@ -76,8 +68,7 @@ def build_case_operators( """Build a CaseOperator from the case metadata and metric evaluation objects. Args: - cases: The case metadata to use for the case operators as a dictionary of cases - or a list of IndividualCases. + case_list: List of IndividualCase objects defining cases to process. evaluation_objects: The evaluation objects to apply to the case operators. 
Returns: @@ -109,7 +100,7 @@ def load_individual_cases( Will pass through existing IndividualCase objects and convert dictionaries to IndividualCase objects. Args: - cases: A dictionary of cases based on the IndividualCase dataclass. + cases: A list of cases as either dicts or IndividualCase objects. Returns: A list of IndividualCase objects. @@ -147,19 +138,18 @@ def load_individual_cases_from_yaml( Example of a yaml file: ```yaml - cases: - - case_id_number: 1 - title: Event 1 - start_date: 2021-01-01 00:00:00 - end_date: 2021-01-03 00:00:00 - location: - type: bounded_region - parameters: - latitude_min: 10.0 - latitude_max: 55.6 - longitude_min: 265.0 - longitude_max: 283.3 - event_type: tropical_cyclone + - case_id_number: 1 + title: Event 1 + start_date: 2021-01-01 00:00:00 + end_date: 2021-01-03 00:00:00 + location: + type: bounded_region + parameters: + latitude_min: 10.0 + latitude_max: 55.6 + longitude_min: 265.0 + longitude_max: 283.3 + event_type: tropical_cyclone ``` Args: diff --git a/src/extremeweatherbench/derived.py b/src/extremeweatherbench/derived.py index f94eefee..7517d232 100644 --- a/src/extremeweatherbench/derived.py +++ b/src/extremeweatherbench/derived.py @@ -16,21 +16,26 @@ class DerivedVariable(abc.ABC): - """An abstract base class defining the interface for ExtremeWeatherBench - derived variables. - - A DerivedVariable is any variable or transform that requires extra computation than - what is provided in analysis or forecast data. Some examples include the - practically perfect hindcast, MLCAPE, IVT, or atmospheric river masks. - - Attributes: - variables: A list of variables that are used to build the - derived variable. - output_variables: Optional list of variable names that specify - which outputs to use from the derived computation. - compute: A method that generates the derived variable from the variables. - derive_variable: An abstract method that defines the computation to - derive the derived_variable from variables. 
+ """Abstract base class for ExtremeWeatherBench derived variables. + + A DerivedVariable is any variable or transform that requires extra + computation beyond what is provided in analysis or forecast data. Examples + include the practically perfect hindcast, MLCAPE, IVT, or atmospheric + river masks. + + Class attributes: + variables: List of variables used to build the derived variable + + Instance attributes: + name: The name of the derived variable + output_variables: Optional list of variable names specifying which + outputs to use from the derived computation + + Public methods: + compute: Build the derived variable from input variables + + Abstract methods: + derive_variable: Define the computation to derive the variable """ variables: List[str] @@ -81,33 +86,28 @@ def compute(self, data: xr.Dataset, *args, **kwargs) -> xr.DataArray: class TropicalCycloneTrackVariables(DerivedVariable): - """A derived variable abstract class for tropical cyclone (TC) variables. - - This class serves as a parent for TC-related derived variables and provides - shared track computation with caching to avoid reprocessing the same data - multiple times across different child classes. + """Derived variable class for tropical cyclone track-based variables. - The track data is computed once and cached, then child classes can extract - specific variables (like sea level pressure, wind speed) from the cached - track dataset. + Extends DerivedVariable to provide shared track computation with caching + for TC-related derived variables, avoiding reprocessing across child + classes. Track data is computed once and cached, then child classes can + extract specific variables (sea level pressure, wind speed, etc.). 
- Deriving the track locations using default TempestExtremes criteria: + Uses default TempestExtremes criteria for track identification: https://doi.org/10.5194/gmd-14-5023-2021 - For forecast data, when track data is provided, the valid candidates - approach is filtered to only include candidates within 5 great circle - degrees of track data points and within 48 hours of the valid_time. + For forecasts with track data, valid candidates are filtered to include + only those within 5 great circle degrees and 48 hours of track points. - Track data is automatically obtained from the target dataset when using - the evaluation pipeline (via `requires_target_dataset=True` flag). + Track data is automatically obtained from target dataset via + `requires_target_dataset=True` flag in evaluation pipeline. - Attributes: - output_variables: Optional list of variable names that specify - which outputs to use from the derived computation. - name: The name of the derived variable. Defaults to class-level - name attribute if present, otherwise the class name. - requires_target_dataset: If True, target dataset will be passed to - this derived variable via kwargs. + Class attributes: + requires_target_dataset: If True, target dataset passed via kwargs + + Instance attributes: + output_variables: Optional list specifying which outputs to use + name: Name of the derived variable """ # required variables for TC track identification @@ -287,8 +287,10 @@ def derive_variable(self, data: xr.Dataset, *args, **kwargs) -> xr.DataArray: class CravenBrooksSignificantSevere(DerivedVariable): - """A derived variable that computes the Craven-Brooks significant severe - convection index. + """Derived variable for Craven-Brooks significant severe convection index. + + Extends DerivedVariable to compute the Craven-Brooks index for assessing + significant severe convection potential. 
""" variables = [ @@ -391,18 +393,18 @@ def derive_variable( class AtmosphericRiverVariables(DerivedVariable): - """A derived variable that computes atmospheric river related variables. + """Derived variable for atmospheric river detection and characterization. - Calculates the IVT (Integrated Vapor Transport), atmospheric river mask, and land - intersection. IVT is calculated using the method described in Newell et al. 1992 and - elsewhere (e.g. Mo 2024). + Extends DerivedVariable to compute IVT (Integrated Vapor Transport), + atmospheric river mask, and land intersection. IVT calculation follows + Newell et al. 1992 and elsewhere (e.g. Mo 2024). - Output variables are: integrated_vapor_transport, atmospheric_river_mask, and - atmospheric_river_land_intersection. Users must declare at least one of the output - variables they want when calling the derived variable. + Output variables: integrated_vapor_transport, atmospheric_river_mask, + atmospheric_river_land_intersection. Users must declare at least one + output variable when calling the derived variable. - The Laplacian of IVT is calculated using a Gaussian blurring kernel with a - sigma of 3 grid points, meant to smooth out 0.25 degree grid scale features. + The Laplacian of IVT uses a Gaussian blurring kernel with sigma of 3 + grid points to smooth 0.25 degree grid scale features. """ variables = [ diff --git a/src/extremeweatherbench/evaluate.py b/src/extremeweatherbench/evaluate.py index 4412bdc7..504d7694 100644 --- a/src/extremeweatherbench/evaluate.py +++ b/src/extremeweatherbench/evaluate.py @@ -50,7 +50,7 @@ class ExtremeWeatherBench: results. Attributes: - case_metadata: A dictionary of cases or a list of IndividualCase objects to run. + case_metadata: A list of case dicts or IndividualCase objects to run. evaluation_objects: A list of evaluation objects to run. cache_dir: An optional directory to cache the mid-flight outputs of the workflow for serial runs. 
@@ -66,6 +66,16 @@ def __init__( cache_dir: Optional[Union[str, pathlib.Path]] = None, region_subsetter: Optional["regions.RegionSubsetter"] = None, ): + """Initialize the ExtremeWeatherBench workflow. + + Args: + case_metadata: List of case dicts or IndividualCase objects. + evaluation_objects: List of evaluation objects to run. + cache_dir: Optional directory for caching mid-flight outputs in + serial runs. + region_subsetter: Optional RegionSubsetter to filter cases by + spatial region. + """ # Load the case metadata from the input self.case_metadata = cases.load_individual_cases(case_metadata) self.evaluation_objects = evaluation_objects @@ -228,7 +238,8 @@ def _run_evaluation( Args: case_operators: List of case operators to run. cache_dir: Optional directory for caching (serial mode only). - **kwargs: Additional arguments, may include 'parallel_config' dict. + parallel_config: Optional dict of joblib parallel configuration. + **kwargs: Additional keyword arguments passed to case operators. Returns: List of result DataFrames. diff --git a/src/extremeweatherbench/evaluate_cli.py b/src/extremeweatherbench/evaluate_cli.py index a4045062..1a96dfc1 100644 --- a/src/extremeweatherbench/evaluate_cli.py +++ b/src/extremeweatherbench/evaluate_cli.py @@ -81,18 +81,17 @@ def cli_runner( save CaseOperator objects for later use or inspection. Args: - default: Use default Brightband evaluation objects with current directory as - output - config_file: Path to a config.py file containing evaluation objects - output_dir: Directory for analysis outputs (default: current directory) + default: Use default Brightband evaluation objects with current directory + as output. + config_file: Path to a config.py file containing evaluation objects. + output_dir: Directory for analysis outputs (default: current directory). cache_dir: Optional directory for caching intermediate data. When set, datasets or dataarrays are computed and cached as zarrs. 
- parallel_config: Parallel configuration using joblib (default: {'backend': - 'threading', 'n_jobs': 8}) - save_case_operators: Save CaseOperator objects to a pickle file at this path - n_jobs: Number of parallel jobs to run (default: 1 for serial execution) - parallel_config: Advanced parallel configuration using joblib. Takes precedence - over --n-jobs if provided. + n_jobs: Number of parallel jobs to run (default: 1 for serial execution). + parallel_config: Advanced parallel configuration using joblib. Takes + precedence over n_jobs if provided. + save_case_operators: Save CaseOperator objects to a pickle file at this + path. Examples: # Use default evaluation objects $ ewb --default diff --git a/src/extremeweatherbench/inputs.py b/src/extremeweatherbench/inputs.py index b275d2fa..e1353fc1 100644 --- a/src/extremeweatherbench/inputs.py +++ b/src/extremeweatherbench/inputs.py @@ -156,15 +156,29 @@ def _default_preprocess(input_data: IncomingDataInput) -> IncomingDataInput: @dataclasses.dataclass class InputBase(abc.ABC): - """An abstract base dataclass for target and forecast data. + """Abstract base dataclass for target and forecast data. + + This class provides the foundational interface for loading and processing + forecast and target datasets in ExtremeWeatherBench. Attributes: - source: The source of the data, which can be a local path or a remote URL/URI. + source: The source of the data, which can be a local path or a + remote URL/URI. name: The name of the input data source. variables: A list of variables to select from the data. variable_mapping: A dictionary of variable names to map to the data. storage_options: Storage/access options for the data. preprocess: A function to preprocess the data. 
+ + Public methods: + open_and_maybe_preprocess_data_from_source: Open and preprocess data + maybe_convert_to_dataset: Convert input data to xarray Dataset + add_source_to_dataset_attrs: Add source name to dataset attributes + maybe_map_variable_names: Map variable names if mapping provided + + Abstract methods: + _open_data_from_source: Open the input data from source + subset_data_to_case: Subset data to case metadata """ source: str @@ -307,7 +321,17 @@ def maybe_map_variable_names(self, data: IncomingDataInput) -> IncomingDataInput @dataclasses.dataclass class ForecastBase(InputBase): - """A class defining the interface for ExtremeWeatherBench forecast data.""" + """Forecast data interface for ExtremeWeatherBench. + + Extends InputBase to provide functionality for forecast datasets with + init_time and lead_time dimensions. + + Attributes: + chunks: Chunking strategy for dask arrays. Defaults to "auto". + + Public methods: + subset_data_to_case: Subset forecast data to case (overrides parent) + """ chunks: Optional[Union[dict, str]] = "auto" @@ -404,7 +428,11 @@ class EvaluationObject: @dataclasses.dataclass class KerchunkForecast(ForecastBase): - """Forecast class for kerchunked forecast data.""" + """Forecast class for kerchunk-referenced forecast data. + + Extends ForecastBase for forecast data accessed via kerchunk references, + enabling efficient access to cloud-optimized datasets. + """ chunks: Optional[Union[dict, str]] = "auto" storage_options: dict = dataclasses.field(default_factory=dict) @@ -419,7 +447,10 @@ def _open_data_from_source(self) -> IncomingDataInput: @dataclasses.dataclass class ZarrForecast(ForecastBase): - """Forecast class for zarr forecast data.""" + """Forecast class for zarr-format forecast data. + + Extends ForecastBase for forecast data stored in zarr format. 
+ """ chunks: Optional[Union[dict, str]] = "auto" @@ -434,11 +465,11 @@ def _open_data_from_source(self) -> IncomingDataInput: @dataclasses.dataclass class XarrayForecast(ForecastBase): - """Forecast class for datasets that were previously constructed and opened using xarray. + """Forecast class for pre-opened xarray datasets. - This class is intended for situations where the user has to manually prepare a dataset to - use in their evaluation. This can happen when the user is manually constructed such a - dataset from a collection of NetCDF or Zarr archives which need to be assembled into a + Extends ForecastBase for datasets previously constructed and opened using + xarray. Intended for situations where users manually prepare datasets from + collections of NetCDF or Zarr archives that need assembly into a single, master dataset. Attributes: @@ -483,12 +514,15 @@ def _open_data_from_source(self) -> xr.Dataset: @dataclasses.dataclass class TargetBase(InputBase): - """An abstract base class for target data. + """Target (truth) data interface for ExtremeWeatherBench. + + Extends InputBase to provide functionality for target datasets that serve + as ground truth for evaluation. Target data can be gridded datasets, point + observations, or any reference dataset. Targets need not match forecast + variables but must share a compatible coordinate system for evaluation. - A TargetBase is data that acts as the "truth" for a case. It can be a gridded - dataset, a point observation dataset, or any other reference dataset. Targets in EWB - are not required to be the same variable as the forecast dataset, but they must be - in the same coordinate system for evaluation. 
+ Public methods: + maybe_align_forecast_to_target: Align forecast to target coordinates """ def maybe_align_forecast_to_target( @@ -516,8 +550,10 @@ def maybe_align_forecast_to_target( @dataclasses.dataclass class ERA5(TargetBase): - """Target class for ERA5 gridded data, ideally using the ARCO ERA5 dataset provided - by Google. Otherwise, either a different zarr source for ERA5. + """Target class for ERA5 gridded reanalysis data. + + Extends TargetBase for ERA5 data, optimized for the ARCO ERA5 dataset + provided by Google or other zarr-based ERA5 sources. """ name: str = "ERA5" @@ -572,10 +608,10 @@ def maybe_align_forecast_to_target( @dataclasses.dataclass class GHCN(TargetBase): - """Target class for GHCN tabular data. + """Target class for GHCN (Global Historical Climatology Network) data. - Data is processed using polars to maintain the lazy loading paradigm in - open_data_from_source and to separate the subsetting into subset_data_to_case. + Extends TargetBase for tabular GHCN station observation data. Uses polars + for lazy loading and efficient subsetting of large tabular datasets. """ name: str = "GHCN" @@ -646,10 +682,11 @@ def maybe_align_forecast_to_target( @dataclasses.dataclass class LSR(TargetBase): - """Target class for local storm report (LSR) tabular data. + """Target class for Local Storm Report (LSR) tabular data. - run_pipeline() returns a dataset with LSRs as mapped to numeric values (1=wind, 2=hail, 3=tor). IndividualCase date ranges for LSRs should be 12 UTC to - the next day at 12 UTC (exclusive) to match SPC's reporting window. + Extends TargetBase for SPC local storm reports. Returns dataset with LSRs + mapped to numeric values (1=wind, 2=hail, 3=tornado). IndividualCase date + ranges should be 12 UTC to next day 12 UTC to match SPC reporting window. 
""" name: str = "local_storm_reports" @@ -748,7 +785,10 @@ def maybe_align_forecast_to_target( # TODO: get PPH connector working properly @dataclasses.dataclass class PPH(TargetBase): - """Target class for practically perfect hindcast data.""" + """Target class for Practically Perfect Hindcast (PPH) data. + + Extends TargetBase for practically perfect hindcast datasets. + """ name: str = "practically_perfect_hindcast" source: str = PPH_URI @@ -880,7 +920,11 @@ def _ibtracs_preprocess(data: IncomingDataInput) -> IncomingDataInput: @dataclasses.dataclass class IBTrACS(TargetBase): - """Target class for IBTrACS data.""" + """Target class for IBTrACS tropical cyclone best track data. + + Extends TargetBase for International Best Track Archive for Climate + Stewardship (IBTrACS) tropical cyclone track and intensity data. + """ name: str = "IBTrACS" preprocess: Callable = _ibtracs_preprocess @@ -1065,6 +1109,7 @@ def open_icechunk_dataset_from_datatree( group: The group within the datatree to open. branch: The icechunk branch to open. Defaults to "main". chunks: The chunk pattern for the datatree. defaults to "auto". + Returns: The dataset for the specified group. """ @@ -1089,6 +1134,7 @@ def zarr_target_subsetter( data: The dataset to subset. case_metadata: The case metadata to subset the dataset to. time_variable: The time variable to use; defaults to "valid_time". + drop: Whether to drop masked values. Defaults to False. Returns: The subset dataset. diff --git a/src/extremeweatherbench/metrics.py b/src/extremeweatherbench/metrics.py index 21bed2b5..873009af 100644 --- a/src/extremeweatherbench/metrics.py +++ b/src/extremeweatherbench/metrics.py @@ -50,19 +50,22 @@ def _compute_metric_with_docstring(self, *args, **kwargs): class BaseMetric(abc.ABC, metaclass=ComputeDocstringMetaclass): - """A BaseMetric class is an abstract class that defines the foundational interface - for all metrics. 
- - Metrics are general operations applied between a forecast and analysis xarray - DataArray. EWB metrics prioritize the use of any arbitrary sets of forecasts and - analyses, so long as the spatiotemporal dimensions are the same. - - Args: - name: The name of the metric. - preserve_dims: The dimensions to preserve in the computation. Defaults to - "lead_time". - forecast_variable: The forecast variable to use in the computation. - target_variable: The target variable to use in the computation. + """Abstract base class defining the foundational interface for all metrics. + + Metrics are general operations applied between forecast and analysis xarray + DataArrays. EWB metrics prioritize the use of any arbitrary sets of + forecasts and analyses, so long as the spatiotemporal dimensions are the + same. + + Public methods: + compute_metric: Public interface to compute the metric + maybe_expand_composite: Expand composite metrics into individual metrics + is_composite: Check if this is a composite metric + __repr__: String representation of the metric + __eq__: Check equality with another metric + + Abstract methods: + _compute_metric: Logic to compute the metric (must be implemented) """ def __init__( @@ -72,6 +75,16 @@ def __init__( forecast_variable: Optional[str | derived.DerivedVariable] = None, target_variable: Optional[str | derived.DerivedVariable] = None, ): + """Initialize the base metric. + + Args: + name: The name of the metric. + preserve_dims: The dimensions to preserve in the computation. + Defaults to "lead_time". + forecast_variable: The forecast variable to use in the + computation. + target_variable: The target variable to use in the computation. + """ # Store the original variables (str or DerivedVariable instances) # Do NOT convert to string to preserve output_variables info self.name = name @@ -179,13 +192,27 @@ def maybe_prepare_composite_kwargs( class CompositeMetric(BaseMetric): - """Base class for composite metrics. 
+ """Base class for composite metrics that can contain multiple sub-metrics. + + Extends BaseMetric to provide functionality for composite metrics that + aggregate multiple individual metrics for efficient evaluation. - This class provides common functionality for composite metrics. - Accepts the same arguments as BaseMetric. + Public methods: + maybe_expand_composite: Expand into individual metrics (overrides base) + is_composite: Check if has sub-metrics (overrides base) + + Abstract methods: + maybe_prepare_composite_kwargs: Prepare kwargs for composite evaluation + _compute_metric: Compute the metric (must be implemented by subclasses) """ def __init__(self, *args, **kwargs): + """Initialize the composite metric. + + Args: + *args: Positional arguments passed to BaseMetric.__init__ + **kwargs: Keyword arguments passed to BaseMetric.__init__ + """ super().__init__(*args, **kwargs) self._metric_instances: list["BaseMetric"] = [] @@ -242,36 +269,31 @@ def _compute_metric( class ThresholdMetric(CompositeMetric): - """Base class for threshold-based metrics. - - This class provides common functionality for metrics that require - forecast and target thresholds for binarization. - - Args: - name: The name of the metric. Defaults to "threshold_metrics". - preserve_dims: The dimensions to preserve in the computation. Defaults to - "lead_time". - forecast_variable: The forecast variable to use in the computation. - target_variable: The target variable to use in the computation. - forecast_threshold: The threshold for binarizing the forecast. Defaults to 0.5. - target_threshold: The threshold for binarizing the target. Defaults to 0.5. - metrics: A list of metrics to use as a composite. Defaults to None. - - Can be used in two ways: - 1. As a base class for specific threshold metrics (CriticalSuccessIndex, - FalseAlarmRatio, etc.) - 2. As a composite metric to compute multiple threshold metrics - efficiently by reusing the transformed contingency manager. 
- - Example of composite usage: + """Base class for threshold-based metrics with binary classification. + + Extends CompositeMetric to provide functionality for metrics that require + forecast and target thresholds for binarization. Can be used as a base + class for specific threshold metrics or as a composite metric. + + Public methods: + transformed_contingency_manager: Create contingency manager + maybe_prepare_composite_kwargs: Prepare kwargs (overrides parent) + __call__: Make instances callable with configured thresholds + + Abstract methods: + _compute_metric: Compute the metric (must be implemented by subclasses) + + Usage patterns: + 1. As a base class for specific metrics (CriticalSuccessIndex, etc.) + 2. As a composite metric to compute multiple threshold metrics + efficiently by reusing the transformed contingency manager + + Example: composite = ThresholdMetric( metrics=[CriticalSuccessIndex, FalseAlarmRatio, Accuracy], forecast_threshold=0.7, target_threshold=0.5 ) - results = composite.compute_metric(forecast, target) - # Returns: {"critical_success_index": ..., - # "false_alarm_ratio": ..., "accuracy": ...} """ def __init__( @@ -285,6 +307,23 @@ def __init__( metrics: Optional[list[Type["ThresholdMetric"]]] = None, **kwargs, ): + """Initialize the threshold metric. + + Args: + name: The name of the metric. Defaults to "threshold_metrics". + preserve_dims: The dimensions to preserve in the computation. + Defaults to "lead_time". + forecast_variable: The forecast variable to use in the + computation. + target_variable: The target variable to use in the computation. + forecast_threshold: The threshold for binarizing the forecast. + Defaults to 0.5. + target_threshold: The threshold for binarizing the target. + Defaults to 0.5. + metrics: A list of metrics to use as a composite. Defaults to + None. + **kwargs: Additional keyword arguments passed to parent. 
+ """ super().__init__( name, preserve_dims=preserve_dims, @@ -430,16 +469,22 @@ def _compute_metric( class CriticalSuccessIndex(ThresholdMetric): - """Critical Success Index metric. + """Compute Critical Success Index (CSI) from binary classifications. - The Critical Success Index is computed between the forecast and target using the - preserve_dims dimensions. - - Args: - name: The name of the metric. Defaults to "CriticalSuccessIndex". + Extends ThresholdMetric to compute CSI between forecast and target using + the preserve_dims dimensions. CSI measures the fraction of correctly + predicted events. """ def __init__(self, name: str = "CriticalSuccessIndex", *args, **kwargs): + """Initialize the Critical Success Index metric. + + Args: + name: The name of the metric. Defaults to + "CriticalSuccessIndex". + *args: Additional positional arguments passed to ThresholdMetric. + **kwargs: Additional keyword arguments passed to ThresholdMetric. + """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -462,16 +507,21 @@ def _compute_metric( class FalseAlarmRatio(ThresholdMetric): - """False Alarm Ratio metric. - - The False Alarm Ratio is computed between the forecast and target using the - preserve_dims dimensions. Note that this is not the same as the False Alarm Rate. + """Compute False Alarm Ratio (FAR) from binary classifications. - Args: - name: The name of the metric. Defaults to "FalseAlarmRatio". + Extends ThresholdMetric to compute FAR between forecast and target using + the preserve_dims dimensions. FAR measures the fraction of predicted + events that did not occur. Note: FAR is not the same as False Alarm Rate. """ def __init__(self, name: str = "FalseAlarmRatio", *args, **kwargs): + """Initialize the False Alarm Ratio metric. + + Args: + name: The name of the metric. Defaults to "FalseAlarmRatio". + *args: Additional positional arguments passed to ThresholdMetric. + **kwargs: Additional keyword arguments passed to ThresholdMetric. 
+ """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -494,16 +544,21 @@ def _compute_metric( class TruePositives(ThresholdMetric): - """True Positive ratio. - - The True Positive is the number of times the forecast is a true positive (top right - cell in the contingency table) divided by the total number of observations. + """Compute True Positive ratio from binary classifications. - Args: - name: The name of the metric. Defaults to "TruePositives". + Extends ThresholdMetric to compute the ratio of true positives (correctly + predicted events) to the total number of observations. Corresponds to the + top right cell in the contingency table. """ def __init__(self, name: str = "TruePositives", *args, **kwargs): + """Initialize the True Positives metric. + + Args: + name: The name of the metric. Defaults to "TruePositives". + *args: Additional positional arguments passed to ThresholdMetric. + **kwargs: Additional keyword arguments passed to ThresholdMetric. + """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -527,16 +582,20 @@ def _compute_metric( class FalsePositives(ThresholdMetric): - """False Positive ratio. + """Compute False Positive ratio from binary classifications. - The False Positive is the number of times the forecast is a false positive divided - by the total number of observations. - - Args: - name: The name of the metric. Defaults to "FalsePositives". + Extends ThresholdMetric to compute the ratio of false positives + (incorrectly predicted events) to the total number of observations. """ def __init__(self, name: str = "FalsePositives", *args, **kwargs): + """Initialize the False Positives metric. + + Args: + name: The name of the metric. Defaults to "FalsePositives". + *args: Additional positional arguments passed to ThresholdMetric. + **kwargs: Additional keyword arguments passed to ThresholdMetric. 
+ """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -560,16 +619,20 @@ def _compute_metric( class TrueNegatives(ThresholdMetric): - """True Negative ratio. - - The True Negative is the number of times the forecast is a true negative divided by - the total number of observations. + """Compute True Negative ratio from binary classifications. - Args: - name: The name of the metric. Defaults to "TrueNegatives". + Extends ThresholdMetric to compute the ratio of true negatives (correctly + predicted non-events) to the total number of observations. """ def __init__(self, name: str = "TrueNegatives", *args, **kwargs): + """Initialize the True Negatives metric. + + Args: + name: The name of the metric. Defaults to "TrueNegatives". + *args: Additional positional arguments passed to ThresholdMetric. + **kwargs: Additional keyword arguments passed to ThresholdMetric. + """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -593,16 +656,21 @@ def _compute_metric( class FalseNegatives(ThresholdMetric): - """False Negative ratio. - - The False Negative is the number of times the forecast is a false negative (top left - cell in the contingency table) divided by the total number of observations. + """Compute False Negative ratio from binary classifications. - Args: - name: The name of the metric. Defaults to "FalseNegatives". + Extends ThresholdMetric to compute the ratio of false negatives (missed + events) to the total number of observations. Corresponds to the top left + cell in the contingency table. """ def __init__(self, name: str = "FalseNegatives", *args, **kwargs): + """Initialize the False Negatives metric. + + Args: + name: The name of the metric. Defaults to "FalseNegatives". + *args: Additional positional arguments passed to ThresholdMetric. + **kwargs: Additional keyword arguments passed to ThresholdMetric. 
+ """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -626,17 +694,21 @@ def _compute_metric( class Accuracy(ThresholdMetric): - """Accuracy metric. + """Compute classification accuracy from binary classifications. - The Accuracy is the number of times the forecast is correct (top right or bottom - right cell in the contingency table) divided by the total number of observations, or - (true positives + true negatives) / (total number of samples). - - Args: - name: The name of the metric. Defaults to "Accuracy". + Extends ThresholdMetric to compute the ratio of correct predictions (true + positives + true negatives) to the total number of observations. Measures + overall correctness of the forecast. """ def __init__(self, name: str = "Accuracy", *args, **kwargs): + """Initialize the Accuracy metric. + + Args: + name: The name of the metric. Defaults to "Accuracy". + *args: Additional positional arguments passed to ThresholdMetric. + **kwargs: Additional keyword arguments passed to ThresholdMetric. + """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -659,19 +731,10 @@ def _compute_metric( class MeanSquaredError(BaseMetric): - """Mean Squared Error metric. - - Args: - name: The name of the metric. Defaults to "MeanSquaredError". - interval_where_one: From scores, endpoints of the interval where the threshold - weights are 1. Must be increasing. Infinite endpoints are permissible. By - supplying a tuple of arrays, endpoints can vary with dimension. - interval_where_positive: From scores, endpoints of the interval where the - threshold weights are positive. Must be increasing. Infinite endpoints are - only permissible when the corresponding interval_where_one endpoint is - infinite. By supplying a tuple of arrays, endpoints can vary with dimension. - weights: From scores, an array of weights to apply to the score (e.g., weighting - a grid by latitude). If None, no weights are applied. 
+ """Compute Mean Squared Error between forecast and target. + + Extends BaseMetric to calculate MSE with optional interval-based + weighting and custom weights for spatial/temporal averaging. """ def __init__( @@ -687,6 +750,20 @@ def __init__( *args, **kwargs, ): + """Initialize the Mean Squared Error metric. + + Args: + name: The name of the metric. Defaults to "MeanSquaredError". + interval_where_one: Endpoints of the interval where threshold + weights are 1. Must be increasing. Infinite endpoints + permissible. + interval_where_positive: Endpoints of the interval where threshold + weights are positive. Must be increasing. + weights: Array of weights to apply to the score (e.g., latitude + weighting). If None, no weights are applied. + *args: Additional positional arguments passed to BaseMetric. + **kwargs: Additional keyword arguments passed to BaseMetric. + """ super().__init__(name, *args, **kwargs) self.interval_where_one = interval_where_one self.interval_where_positive = interval_where_positive @@ -711,19 +788,10 @@ def _compute_metric( class MeanAbsoluteError(BaseMetric): - """Mean Absolute Error metric. - - Args: - name: The name of the metric. Defaults to "MeanAbsoluteError". - interval_where_one: From scores, endpoints of the interval where the threshold - weights are 1. Must be increasing. Infinite endpoints are permissible. By - supplying a tuple of arrays, endpoints can vary with dimension. - interval_where_positive: From scores, endpoints of the interval where the - threshold weights are positive. Must be increasing. Infinite endpoints are - only permissible when the corresponding interval_where_one endpoint is - infinite. By supplying a tuple of arrays, endpoints can vary with dimension. - weights: From scores, an array of weights to apply to the score (e.g., weighting - a grid by latitude). If None, no weights are applied. + """Compute Mean Absolute Error between forecast and target. 
+ + Extends BaseMetric to calculate MAE with optional interval-based + weighting and custom weights for spatial/temporal averaging. """ def __init__( @@ -739,6 +807,20 @@ def __init__( *args, **kwargs, ): + """Initialize the Mean Absolute Error metric. + + Args: + name: The name of the metric. Defaults to "MeanAbsoluteError". + interval_where_one: Endpoints of the interval where threshold + weights are 1. Must be increasing. Infinite endpoints + permissible. + interval_where_positive: Endpoints of the interval where threshold + weights are positive. Must be increasing. + weights: Array of weights to apply to the score (e.g., latitude + weighting). If None, no weights are applied. + *args: Additional positional arguments passed to BaseMetric. + **kwargs: Additional keyword arguments passed to BaseMetric. + """ self.interval_where_one = interval_where_one self.interval_where_positive = interval_where_positive self.weights = weights @@ -772,16 +854,20 @@ def _compute_metric( class MeanError(BaseMetric): - """Mean Error (bias) metric. + """Compute Mean Error (bias) between forecast and target. - The mean error (or mean bias error) is computed between the forecast and target - using the preserve_dims dimensions. - - Args: - name: The name of the metric. Defaults to "MeanError". + Extends BaseMetric to calculate mean error (bias) using the preserve_dims + dimensions. Positive values indicate forecast exceeds target. """ def __init__(self, name: str = "MeanError", *args, **kwargs): + """Initialize the Mean Error metric. + + Args: + name: The name of the metric. Defaults to "MeanError". + *args: Additional positional arguments passed to BaseMetric. + **kwargs: Additional keyword arguments passed to BaseMetric. + """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -805,16 +891,20 @@ def _compute_metric( class RootMeanSquaredError(BaseMetric): - """Root Mean Square Error metric. 
- - The Root Mean Square Error is computed between the forecast and target using the - preserve_dims dimensions. + """Compute Root Mean Squared Error between forecast and target. - Args: - name: The name of the metric. Defaults to "RootMeanSquaredError". + Extends BaseMetric to calculate RMSE using the preserve_dims dimensions. + RMSE is the square root of the mean squared error. """ def __init__(self, name: str = "RootMeanSquaredError", *args, **kwargs): + """Initialize the Root Mean Squared Error metric. + + Args: + name: The name of the metric. Defaults to "RootMeanSquaredError". + *args: Additional positional arguments passed to BaseMetric. + **kwargs: Additional keyword arguments passed to BaseMetric. + """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -838,20 +928,11 @@ def _compute_metric( class EarlySignal(BaseMetric): - """Early Signal detection metric. - - This metric finds the first occurrence where a signal is detected based on - threshold criteria and returns the corresponding init_time, lead_time, and - valid_time information. The metric is designed to be flexible for different - signal detection criteria that can be specified in applied metrics downstream. - - Args: - name: The name of the metric. - comparison_operator: The comparison operator to use for signal detection. - threshold: The threshold value for signal detection. - spatial_aggregation: The spatial aggregation method to use for signal detection. - Options are "any" (any gridpoint meets criteria), "all" (all gridpoints - meet criteria), or "half" (at least half of gridpoints meet criteria). + """Detect first occurrence of signal exceeding threshold criteria. + + Extends BaseMetric to find the earliest time when a signal is detected + based on threshold criteria, returning init_time, lead_time, and + valid_time information. Flexible for different signal detection criteria. 
""" def __init__( @@ -864,6 +945,17 @@ def __init__( spatial_aggregation: Literal["any", "all", "half"] = "any", **kwargs, ): + """Initialize the Early Signal detection metric. + + Args: + name: The name of the metric. Defaults to "EarlySignal". + comparison_operator: The comparison operator for signal detection. + threshold: The threshold value for signal detection. + spatial_aggregation: Spatial aggregation method. Options: "any" + (any gridpoint meets criteria), "all" (all gridpoints meet + criteria), or "half" (at least half meet criteria). + **kwargs: Additional keyword arguments passed to BaseMetric. + """ # Extract threshold params before passing to super self.comparison_operator = utils.maybe_get_operator(comparison_operator) self.threshold = threshold @@ -929,19 +1021,11 @@ def _compute_metric( class MaximumMeanAbsoluteError(MeanAbsoluteError): - """Computes the mean absolute error between the forecast and target maximum values. - - The forecast is filtered to a time window around the target's maximum using - tolerance_range_hours (in the event of variation between the timing between the - target and forecast maximum values). The mean absolute error is computed between the - filtered forecast and target maximum value. - - Args: - tolerance_range_hours: The time window (hours) around the target's maximum - value to search for forecast minimum. Defaults to 24 hours. - reduce_spatial_dims: The spatial dimensions to reduce. Defaults to - ["latitude", "longitude"]. - name: The name of the metric. Defaults to "MaximumMeanAbsoluteError". + """Compute MAE between forecast and target maximum values. + + Extends MeanAbsoluteError to filter forecast to a time window around the + target's maximum using tolerance_range_hours. Useful for evaluating peak + value timing and magnitude. """ def __init__( @@ -952,6 +1036,20 @@ def __init__( *args, **kwargs, ): + """Initialize the Maximum Mean Absolute Error metric. 
+ + Args: + tolerance_range_hours: Time window (hours) around target's + maximum to search for forecast maximum. Defaults to 24. + reduce_spatial_dims: Spatial dimensions to reduce. Defaults to + ["latitude", "longitude"]. + name: The name of the metric. Defaults to + "MaximumMeanAbsoluteError". + *args: Additional positional arguments passed to + MeanAbsoluteError. + **kwargs: Additional keyword arguments passed to + MeanAbsoluteError. + """ self.tolerance_range_hours = tolerance_range_hours self.reduce_spatial_dims = reduce_spatial_dims super().__init__(name, *args, **kwargs) @@ -1007,19 +1105,11 @@ def _compute_metric( class MinimumMeanAbsoluteError(MeanAbsoluteError): - """Computes the mean absolute error between the forecast and target minimum values. - - The forecast is filtered to a time window around the target's minimum using - tolerance_range_hours (in the event of variation between the timing between the - target and forecast minimum values). The mean absolute error is computed between the - filtered forecast and target minimum value. - - Args: - tolerance_range_hours: The time window (hours) around the target's minimum - value to search for forecast minimum. Defaults to 24 hours. - reduce_spatial_dims: The spatial dimensions to reduce. Defaults to - ["latitude", "longitude"]. - name: The name of the metric. Defaults to "MinimumMeanAbsoluteError". + """Compute MAE between forecast and target minimum values. + + Extends MeanAbsoluteError to filter forecast to a time window around the + target's minimum using tolerance_range_hours. Useful for evaluating + minimum value timing and magnitude. """ def __init__( @@ -1030,6 +1120,20 @@ def __init__( *args, **kwargs, ): + """Initialize the Minimum Mean Absolute Error metric. + + Args: + tolerance_range_hours: Time window (hours) around target's + minimum to search for forecast minimum. Defaults to 24. + reduce_spatial_dims: Spatial dimensions to reduce. Defaults to + ["latitude", "longitude"]. 
+ name: The name of the metric. Defaults to + "MinimumMeanAbsoluteError". + *args: Additional positional arguments passed to + MeanAbsoluteError. + **kwargs: Additional keyword arguments passed to + MeanAbsoluteError. + """ self.tolerance_range_hours = tolerance_range_hours self.reduce_spatial_dims = reduce_spatial_dims super().__init__(name, *args, **kwargs) @@ -1082,16 +1186,11 @@ def _compute_metric( class MaximumLowestMeanAbsoluteError(MeanAbsoluteError): - """Mean Absolute Error of the maximum of aggregated minimum values. - - Meant for heatwave evaluation by aggregating the minimum values over a day and then - computing the MeanAbsoluteError between the warmest nighttime (daily minimum) - temperature in the target and forecast. + """Compute MAE of maximum aggregated minimum values for heatwaves. - Args: - tolerance_range_hours: The time window (hours) around the target's max-min - value to search for forecast max-min. Defaults to 24 hours. - name: The name of the metric. Defaults to "MaximumLowestMeanAbsoluteError". + Extends MeanAbsoluteError for heatwave evaluation by aggregating daily + minimum values and computing MAE between the warmest nighttime (daily + minimum) temperature in target and forecast. """ def __init__( @@ -1101,6 +1200,18 @@ def __init__( *args, **kwargs, ): + """Initialize the Maximum Lowest Mean Absolute Error metric. + + Args: + tolerance_range_hours: Time window (hours) around target's + max-min value to search for forecast max-min. Defaults to 24. + name: The name of the metric. Defaults to + "MaximumLowestMeanAbsoluteError". + *args: Additional positional arguments passed to + MeanAbsoluteError. + **kwargs: Additional keyword arguments passed to + MeanAbsoluteError. + """ self.tolerance_range_hours = tolerance_range_hours super().__init__(name, *args, **kwargs) @@ -1184,22 +1295,10 @@ def _compute_metric( class DurationMeanError(MeanError): - """Compute the duration of a case's event. 
- - This metric computes the mean error between the forecast and target durations. - - Args: - threshold_criteria: The criteria for event detection. Can be either a DataArray - of a climatology with dimensions (dayofyear, hour, latitude, longitude) or a - float value representing a fixed threshold. - reduce_spatial_dims: The spatial dimensions to reduce prior to applying threshold - criteria. Defaults to ["latitude", "longitude"]. - op_func: Comparison operator or string (e.g., operator.ge for >=). - name: Name of the metric. - preserve_dims: Dimensions to preserve during aggregation. Defaults to - "init_time". - product_time_resolution_hours: Whether to product the duration by the time - resolution of the forecast (in hours). Defaults to False. + """Compute mean error of event duration between forecast and target. + + Extends MeanError to compute the mean error between forecast and target + event durations based on threshold criteria and spatial aggregation. """ def __init__( @@ -1211,6 +1310,23 @@ def __init__( preserve_dims: str = "init_time", product_time_resolution_hours: bool = False, ): + """Initialize the Duration Mean Error metric. + + Args: + threshold_criteria: Criteria for event detection. Either a + DataArray of climatology with dimensions (dayofyear, hour, + latitude, longitude) or a float fixed threshold. + reduce_spatial_dims: Spatial dimensions to reduce prior to + applying threshold criteria. Defaults to ["latitude", + "longitude"]. + op_func: Comparison operator or string (e.g., operator.ge for + >=). + name: Name of the metric. Defaults to "DurationMeanError". + preserve_dims: Dimensions to preserve during aggregation. + Defaults to "init_time". + product_time_resolution_hours: Whether to multiply duration by + time resolution of forecast (in hours). Defaults to False. 
+ """ super().__init__(name=name, preserve_dims=preserve_dims) self.reduce_spatial_dims = reduce_spatial_dims self.threshold_criteria = threshold_criteria @@ -1307,15 +1423,17 @@ def _compute_metric( class LandfallMetric(CompositeMetric): - """Base class for landfall metrics. + """Base class for tropical cyclone landfall metrics. + + Extends CompositeMetric to compute landfalls using calc.find_landfalls, + which utilizes land geometry and line segments based on track data to + determine intersections. - Landfall metrics compute landfalls using the calc.find_landfalls function, which - utilizes a land geometry and line segments based on track data to determine - intersections. + Can be used as a base class for custom landfall metrics, as a mixin with + other metrics, or as a composite metric for multiple landfall metrics. - Can be used as a base class for custom landfall metrics, as a mixin with other - metrics, or as a composite metric for multiple landfall metrics (which utilize - identical landfalling locations). + Public methods: + maybe_prepare_composite_kwargs: Prepare kwargs for landfall composites """ def __init__( @@ -1521,13 +1639,11 @@ def _compute_metric( class SpatialDisplacement(BaseMetric): - """Spatial displacement error metric for atmospheric rivers and similar events. - - Computes the great circle distance between the center of mass of forecast - and target spatial patterns. + """Compute spatial displacement between forecast and target patterns. - Args: - name: The name of the metric. Defaults to "spatial_displacement". + Extends BaseMetric to compute great circle distance between centers of + mass of forecast and target spatial patterns. Useful for atmospheric + rivers and similar spatial features. """ def __init__( @@ -1535,6 +1651,13 @@ def __init__( name: str = "spatial_displacement", **kwargs: Any, ): + """Initialize the Spatial Displacement metric. + + Args: + name: The name of the metric. Defaults to + "spatial_displacement". 
+ **kwargs: Additional keyword arguments passed to BaseMetric. + """ super().__init__(name, **kwargs) def _compute_metric( @@ -1621,13 +1744,10 @@ def center_of_mass_ufunc(data): class LandfallDisplacement(LandfallMetric): - """Calculate the distance between forecast and target landfall positions. + """Compute distance between forecast and target landfall positions. - This metric computes the distance between the forecast and target - landfall positions, defaulting to kilometers. - - Args: - name: The name of the metric. Defaults to "landfall_displacement". + Extends LandfallMetric to calculate the spatial distance between forecast + and target landfall positions, defaulting to kilometers. """ def __init__( @@ -1636,6 +1756,14 @@ def __init__( *args, **kwargs, ): + """Initialize the Landfall Displacement metric. + + Args: + name: The name of the metric. Defaults to + "landfall_displacement". + *args: Additional positional arguments passed to LandfallMetric. + **kwargs: Additional keyword arguments passed to LandfallMetric. + """ super().__init__(name, *args, **kwargs) self.units = kwargs.get("units", "km") @@ -1717,15 +1845,11 @@ def _compute_metric( class LandfallTimeMeanError(LandfallMetric): - """Landfall time mean error. - - This metric computes the mean error between the forecast and target landfall times. - A positive value indicates the forecast landfall time is later than the target - landfall time, a negative value indicates the forecast landfall time is earlier than - the target landfall time. + """Compute mean error between forecast and target landfall times. - Args: - name: The name of the metric. Defaults to "landfall_time_me". + Extends LandfallMetric to calculate timing difference. Positive values + indicate forecast landfall is later than target; negative values indicate + forecast landfall is earlier than target. """ def __init__( @@ -1734,6 +1858,13 @@ def __init__( *args, **kwargs, ): + """Initialize the Landfall Time Mean Error metric. 
+ + Args: + name: The name of the metric. Defaults to "landfall_time_me". + *args: Additional positional arguments passed to LandfallMetric. + **kwargs: Additional keyword arguments passed to LandfallMetric. + """ super().__init__(name, *args, **kwargs) def calculate_time_difference( @@ -1791,18 +1922,14 @@ def _compute_metric( class LandfallIntensityMeanAbsoluteError(LandfallMetric, MeanAbsoluteError): - """Compute the MeanAbsoluteError between forecast and target. + """Compute MAE of forecast and target intensity at landfall. - This metric computes the mean absolute error between forecast and target - intensity at landfall. + Extends both LandfallMetric and MeanAbsoluteError to calculate mean + absolute error between forecast and target intensity at landfall time. The intensity variable is determined by forecast_variable and - target_variable. To evaluate multiple intensity variables (e.g., - surface_wind_speed and air_pressure_at_mean_sea_level), create - separate metric instances for each variable. - - Args: - name: The name of the metric. Defaults to "landfall_intensity_mae". + target_variable. For multiple intensity variables, create separate metric + instances for each variable. """ def __init__( @@ -1811,6 +1938,14 @@ def __init__( *args, **kwargs, ): + """Initialize the Landfall Intensity Mean Absolute Error metric. + + Args: + name: The name of the metric. Defaults to + "landfall_intensity_mae". + *args: Additional positional arguments passed to parent classes. + **kwargs: Additional keyword arguments passed to parent classes. + """ super().__init__(name, *args, **kwargs) def _compute_metric( diff --git a/src/extremeweatherbench/regions.py b/src/extremeweatherbench/regions.py index e150e42f..4e6c7730 100644 --- a/src/extremeweatherbench/regions.py +++ b/src/extremeweatherbench/regions.py @@ -22,7 +22,21 @@ class Region(abc.ABC): - """Base class for different region representations.""" + """Base class for different region representations. 
+ + This abstract class defines the interface for geographic regions used in + ExtremeWeatherBench. Regions can be centered, bounding boxes, or defined + by shapefiles. + + Public methods: + create_region: Abstract factory method to create a region + as_geopandas: Convert region to GeoDataFrame representation + get_adjusted_bounds: Get region bounds adjusted to dataset convention + mask: Mask a dataset to this region + intersects: Check if this region intersects another region + contains: Check if this region contains another region + area_overlap_fraction: Calculate area overlap with another region + """ @classmethod @abc.abstractmethod @@ -159,17 +173,11 @@ def area_overlap_fraction(self, other: "Region") -> float: class CenteredRegion(Region): - """A region defined by a center point and a bounding box. - - bounding_box_degrees is the width (length) of one or all sides, not half size; - e.g., bounding_box_degrees=10.0 means a 10 degree by 10 degree box around - the center point. + """Region defined by center point and bounding box. - Attributes: - latitude: Center latitude - longitude: Center longitude - bounding_box_degrees: Size of bounding box in degrees or tuple of - (lat_degrees, lon_degrees) + Extends Region to define a region using a center point and bounding box + dimensions. The bounding_box_degrees is the full width/height (not half + size); e.g., 10.0 means a 10x10 degree box around the center. """ def __repr__(self): @@ -182,6 +190,15 @@ def __repr__(self): def __init__( self, latitude: float, longitude: float, bounding_box_degrees: float | tuple ): + """Initialize the CenteredRegion. + + Args: + latitude: Center latitude in degrees. + longitude: Center longitude in degrees. + bounding_box_degrees: Size of bounding box in degrees. Either a + single float (square box) or tuple of (lat_degrees, + lon_degrees). 
+ """ self.latitude = latitude self.longitude = longitude self.bounding_box_degrees = bounding_box_degrees @@ -229,13 +246,9 @@ def as_geopandas(self) -> gpd.GeoDataFrame: class BoundingBoxRegion(Region): - """A region defined by explicit latitude and longitude bounds. + """Region defined by explicit latitude and longitude bounds. - Attributes: - latitude_min: Minimum latitude bound - latitude_max: Maximum latitude bound - longitude_min: Minimum longitude bound - longitude_max: Maximum longitude bound + Extends Region to define a region using explicit bounding box coordinates. """ def __repr__(self): @@ -253,6 +266,14 @@ def __init__( longitude_min: float, longitude_max: float, ): + """Initialize the BoundingBoxRegion. + + Args: + latitude_min: Minimum latitude bound in degrees. + latitude_max: Maximum latitude bound in degrees. + longitude_min: Minimum longitude bound in degrees. + longitude_max: Maximum longitude bound in degrees. + """ self.latitude_min = latitude_min self.latitude_max = latitude_max self.longitude_min = longitude_min @@ -286,19 +307,21 @@ def as_geopandas(self) -> gpd.GeoDataFrame: class ShapefileRegion(Region): - """A region defined by a shapefile. - - A geopandas object shapefile is read in and stored as an attribute - on instantiation. + """Region defined by a shapefile. - Attributes: - shapefile_path: Local or remote path to the .shp shapefile + Extends Region to define a region using a shapefile. The shapefile is read + using geopandas on instantiation. """ def __repr__(self): return f"{self.__class__.__name__}(shapefile_path={self.shapefile_path})" def __init__(self, shapefile_path: str | pathlib.Path): + """Initialize the ShapefileRegion. + + Args: + shapefile_path: Local or remote path to the .shp shapefile. + """ self.shapefile_path = pathlib.Path(shapefile_path) @classmethod @@ -465,16 +488,17 @@ def _create_geopandas_from_bounds( class RegionSubsetter: - """A utility class for subsetting ExtremeWeatherBench objects by region. 
+ """Utility class for subsetting ExtremeWeatherBench objects by region. - Attributes: - region: The region to subset to. Can be a Region object or a - dictionary of bounds with keys "latitude_min", "latitude_max", - "longitude_min", and "longitude_max". - method: The method to use for subsetting. Options: - - "intersects": Include cases where ANY part of a case intersects region - - "percent": Include cases where percent of case area overlaps with region. - - "all": Only include cases where entirety of a case is within region + Provides methods for filtering case collections based on spatial overlap + with a specified region using various inclusion criteria. + + Public methods: + subset: Subset a case collection based on region overlap + + Instance attributes: + region: The region to subset to + method: The subsetting method used percent_threshold: Threshold for percent overlap (0.0 to 1.0) """ @@ -491,10 +515,10 @@ def __init__( """Initialize the RegionSubsetter. Args: - region: The region to subset to. Can be a Region object or a - dictionary of bounds with keys "latitude_min", "latitude_max", - "longitude_min", and "longitude_max". - method: The method to use for subsetting. Options: + region: The region to subset to. Can be a Region object or + dictionary with keys "latitude_min", "latitude_max", + "longitude_min", "longitude_max". + method: The subsetting method. Options: - "intersects": Include cases where ANY part of a case intersects region - "percent": Include cases where percent of case area overlaps with region diff --git a/src/extremeweatherbench/sources/base.py b/src/extremeweatherbench/sources/base.py index 3278038b..e7dbda6e 100644 --- a/src/extremeweatherbench/sources/base.py +++ b/src/extremeweatherbench/sources/base.py @@ -6,7 +6,17 @@ @runtime_checkable class Source(Protocol): - """A protocol for input sources.""" + """Protocol defining the interface for input data sources. 
+ + This protocol specifies the methods that input source implementations must + provide for variable extraction, temporal validation, and spatial data + checking. + + Required methods: + safely_pull_variables: Extract specified variables from data + check_for_valid_times: Check if data has valid times in date range + check_for_spatial_data: Check if data has spatial coverage for region + """ def safely_pull_variables( self, From aec88ae896bd854aa45e2d59fcf26560269166fb Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Fri, 23 Jan 2026 21:35:16 -0500 Subject: [PATCH 11/13] add explanation for dim reqs (#320) --- docs/usage.md | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/docs/usage.md b/docs/usage.md index 874edbd2..71e0648e 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -8,11 +8,6 @@ To run the Brightband-based evaluation on an existing AIWP model (FCN v2), which includes the default 337 cases for heat waves, freezes, severe convective days, tropical cyclones, and atmospheric rivers: -```bash -ewb --default -``` - -or: ```python import extremeweatherbench as ewb @@ -29,6 +24,12 @@ outputs = runner.run() outputs.to_csv('your_outputs.csv') ``` +or: + +```bash +ewb --default +``` + ## API Overview ExtremeWeatherBench provides a hierarchical API for accessing its components: @@ -52,13 +53,16 @@ ewb.ERA5(...) ewb.ZarrForecast(...) ewb.load_cases() ``` - ## Running an Evaluation for a Single Event Type ExtremeWeatherBench has default event types and cases for heat waves, freezes, severe convection, tropical cyclones, and atmospheric rivers. To run an evaluation, there are three components required: a forecast, a target, and an evaluation object. +ExtremeWeatherBench requires forecasts to have `init_time`, `lead_time`, `latitude`, and `longitude` dimensions at minimum. If not already in that naming convention, initializing a `ForecastBase` object with a `variable_mapping` to map to those names is required. 
Other dimensions such as pressure level (`level`) can be included. + +Targets require at least a `valid_time` with at least one spatial dimension. Examples include `location`, `station`, or (`latitude`, `longitude`). Forecasts are aligned to targets during the steps immediately prior to evaluating a metric. + ```python import extremeweatherbench as ewb ``` From d24fc2e0a0fa49f8b73651ecf00fe97e13d58479 Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Fri, 23 Jan 2026 23:07:22 -0500 Subject: [PATCH 12/13] Update `defaults` and `inputs` to include new CIRA icechunk store (#319) * more explicit naming, add func and model names var * add test coverage, ruff, linting * update readme for new cira approach * move cira func and model ref to inputs * update docs * module wasnt called for moved func * update tests for moving func and var * ruff * fix mock typos --- README.md | 49 +---- docs/recipes/cira_forecast.md | 39 ++-- docs/usage.md | 25 ++- src/extremeweatherbench/defaults.py | 140 ++++++++------- src/extremeweatherbench/inputs.py | 59 ++++++ tests/test_defaults.py | 95 ++++++++-- tests/test_inputs.py | 267 +++++++++++++++++++++++++++- 7 files changed, 518 insertions(+), 156 deletions(-) diff --git a/README.md b/README.md index f0249f50..d90aff16 100644 --- a/README.md +++ b/README.md @@ -67,48 +67,11 @@ $ ewb --default ```python from extremeweatherbench import cases, inputs, metrics, evaluate, utils -# Select model -model = 'FOUR_v200_GFS' - -# Set up path to directory of file - zarr or kerchunk/virtualizarr json/parquet -forecast_dir = f'gs://extremeweatherbench/{model}.parq' - -# Preprocessing function exclusive to handling the CIRA parquets -def preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: - """Preprocess CIRA kerchunk (parquet) data in the ExtremeWeatherBench bucket. 
- A preprocess function that renames the time coordinate to lead_time, - creates a valid_time coordinate, and sets the lead time range and resolution not - present in the original dataset. - Args: - ds: The forecast dataset to rename. - Returns: - The renamed forecast dataset. - """ - ds = ds.rename({"time": "lead_time"}) - - # The evaluation configuration is used to set the lead time range and resolution. - ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") - - return ds - -# Define a forecast object; in this case, a KerchunkForecast -fcnv2_forecast = inputs.KerchunkForecast( - name="fcnv2_forecast", # identifier for this forecast in results - source=forecast_dir, # source path - variables=["surface_air_temperature"], # variables to use in the evaluation - variable_mapping=inputs.CIRA_metadata_variable_mapping, # mapping to use for variables in forecast dataset to EWB variable names - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, # storage options for access - preprocess=preprocess_bb_cira_forecast_dataset # required preprocessing function for CIRA references -) +# Load in a forecast; here, we load in GFS initialized FCNv2 from the CIRA MLWP archive with a default variable built-in for convenience +fcnv2_heatwave_forecast = defaults.cira_fcnv2_heatwave_forecast -# Load in ERA5; source defaults to the ARCO ERA5 dataset from Google and variable mapping is provided by default as well -era5_heatwave_target = inputs.ERA5( - variables=["surface_air_temperature"], # variable to use in the evaluation - storage_options={"remote_options": {"anon": True}}, # storage options for access - chunks=None, # define chunks for the ERA5 data -) +# Load in ERA5 with another default convenience variable +era5_heatwave_target = defaults.era5_heatwave_target # EvaluationObjects are used to evaluate a single forecast source against a single target source with a defined event type. 
Event types are declared with each case. One or more metrics can be evaluated with each EvaluationObject. heatwave_evaluation_list = [ @@ -120,7 +83,7 @@ heatwave_evaluation_list = [ metrics.MaximumLowestMeanAbsoluteError(), ], target=era5_heatwave_target, - forecast=fcnv2_forecast, + forecast=fcnv2_heatwave_forecast, ), ] # Load in the EWB default list of event cases @@ -134,7 +97,7 @@ ewb_instance = evaluate.ExtremeWeatherBench( # Execute a parallel run and return the evaluation results as a pandas DataFrame heatwave_outputs = ewb_instance.run( - parallel_config={'backend':'loky','n_jobs':16} # Uses 16 jobs with the loky backend + parallel_config={'n_jobs':16} # Uses 16 jobs with the loky backend as default ) # Save the results diff --git a/docs/recipes/cira_forecast.md b/docs/recipes/cira_forecast.md index 43a7a6ac..9cf57b90 100644 --- a/docs/recipes/cira_forecast.md +++ b/docs/recipes/cira_forecast.md @@ -2,22 +2,10 @@ We have a dedicated virtual reference icechunk store for CIRA data **up to May 26th, 2025** available at `gs://extremeweatherbench/cira-icechunk`. Compared to using parquet virtual references, we have seen a speed improvements of around 2x with ~25% more memory usage. 
-## Loading the store - -```python - -from extremeweatherbench import cases, inputs, metrics, evaluate, defaults -import datetime -import icechunk - -storage = icechunk.gcs_storage( - bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True -) -``` - ## Accessing a CIRA Model from the store ```python +from extremeweatherbench import inputs group_list = inputs.list_groups_in_icechunk_datatree(storage) ``` @@ -39,22 +27,33 @@ group_list = inputs.list_groups_in_icechunk_datatree(storage) ```python -# Find FCNv2's name in the group list -fcnv2_group = [n for n in group_list if 'FOUR_v200_GFS' in n][0] - # Helper function to access the virtual dataset -fcnv2 = inputs.open_icechunk_dataset_from_datatree( +fcnv2 = inputs.get_cira_icechunk(model_name='FOUR_v200_IFS') +``` + +`fcnv2` is a `ForecastBase` object ready to be used within EWB's evaluation framework. + +> **Detailed Explanation**: `inputs.get_cira_icechunk` is syntactic sugar for this: +```python +import icechunk + +storage = icechunk.gcs_storage( + bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True +) + +fcnv2_icechunk_ds = inputs.open_icechunk_dataset_from_datatree( storage=storage, - group=fcnv2_group, + group="FOUR_v200_IFS", authorize_virtual_chunk_access=inputs.CIRA_CREDENTIALS ) -fcnv2_icechunk_forecast_object = inputs.XarrayForecast( + +fcnv2 = inputs.XarrayForecast( ds=fcnv2, variable_mapping=inputs.CIRA_metadata_variable_mapping ) ``` -`fcnv2_icechunk_forecast_object` is a `ForecastBase` object ready to be used within EWB's evaluation framework. +Which is a three step process of accessing the icechunk storage, loading the dataset from the datatree/zarr group format, and finally applying that `Dataset` in a `ForecastBase` object. 
## Set up metrics and target for evaluation diff --git a/docs/usage.md b/docs/usage.md index 71e0648e..ddd372fa 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -66,8 +66,7 @@ Targets require at least a `valid_time` with at least one spatial dimension. Exa ```python import extremeweatherbench as ewb ``` - -There are two built-in `ForecastBase` classes to set up a forecast: `ZarrForecast` and `KerchunkForecast`. Here is an example of a `ZarrForecast`, using Weatherbench2's HRES zarr store: +There are three built-in `ForecastBase` classes to set up a forecast: `ZarrForecast`, `XarrayForecast`, and `KerchunkForecast`. Here is an example of a `ZarrForecast`, using Weatherbench2's HRES zarr store: ```python hres_forecast = ewb.forecasts.ZarrForecast( @@ -86,9 +85,9 @@ There are required arguments, namely: - `variables`* - `variable_mapping` -* `variables` can be defined within one or more metrics instead of in a `ForecastBase` object. +* `variables` can alternatively be defined within one or more metrics, instead of in a `ForecastBase` object. -A forecast needs a `source`, which is a link to the zarr store in this case. A `name` is required to identify the outputs. It also needs `variables` defined, which are based on CF Conventions. A list of variable namings exists in `ewb.DEFAULT_VARIABLE_NAMES`. Each forecast will likely have different names for their variables, so a `variable_mapping` dictionary is also essential to process the variables, as well as the coordinates and dimensions. EWB uses `lead_time`, `init_time`, and `valid_time` as time coordinates. The HRES data is mapped from `prediction_timedelta` to `lead_time`, as an example. `storage_options` define access patterns for the data if needed. These are passed to the opening function, e.g. `xarray.open_zarr`. +> **Detailed Explanation**: A forecast needs a `source`, which is a link to the zarr store in this case. A `name` is required to identify the outputs. 
It also needs `variables` defined, which are based on CF Conventions. A list of variable namings exists in `defaults.py` as `DEFAULT_VARIABLE_NAMES`. Each forecast will likely have different names for their variables, so a `variable_mapping` dictionary is also essential to process the variables, as well as the coordinates and dimensions. EWB uses `lead_time`, `init_time`, and `valid_time` as time coordinates. The HRES data is mapped from `prediction_timedelta` to `lead_time`, as an example. `storage_options` define access patterns for the data if needed. These are passed to the opening function, e.g. `xarray.open_zarr`. Next, a target dataset must be defined as well to evaluate against. For this evaluation, we'll use ERA5: @@ -101,7 +100,19 @@ era5_heatwave_target = ewb.targets.ERA5( ) ``` -Similarly to forecasts, we need to define the `source`, which here is the ARCO ERA5 provided by Google. `variables` are again required to be set for the `ewb.targets.ERA5` class; `variable_mapping` defaults to `ewb.ERA5_metadata_variable_mapping` for many existing variables and likely is not required to be set unless your use case is for less common variables. Both forecasts and targets, if relevant, have an optional `chunks` parameter which defaults to what should be the most efficient value - usually `None` or `'auto'`, but can be changed as seen above. +Note that EWB provides defaults for arguments, so most users will be able to instead write this (if defining variables with the intent of it applying to all metrics): + +```python +era5_heatwave_target = inputs.ERA5(variables=['surface_air_temperature']) +``` + +Or (if defining variables as arguments to the metrics): + +```python +era5_heatwave_target = inputs.ERA5() +``` + +> **Detailed Explanation**: Similarly to forecasts, we need to define the `source`, which here is the ARCO ERA5 provided by Google. 
`variables` are used to subset `ewb.inputs.ERA5` in an evaluation; `variable_mapping` defaults to `ewb.inputs.ERA5_metadata_variable_mapping` for many existing variables and likely is not required to be set unless your use case is for less common variables. Both forecasts and targets, if relevant, have an optional `chunks` parameter which defaults to what should be the most efficient value - usually `None` or `'auto'`, but can be changed as seen above. *If using the ARCO ERA5 and setting `chunks=None`, it is critical to order your subsetting by variables -> time -> `.sel` or `.isel` latitude & longitude -> rechunk. [See this Github comment](https://github.com/pydata/xarray/issues/8902#issuecomment-2036435045). We then set up an `EvaluationObject` list: @@ -139,7 +150,9 @@ outputs.to_csv('your_file_name.csv') Where the EWB default events YAML file is loaded in using `ewb.load_cases()`, then applied to an instance of `ewb.evaluation` along with the `EvaluationObject` list. Finally, we run the evaluation with the `.run()` method, where defaults are typically sufficient to run with a small to moderate-sized virtual machine. -The outputs are returned as a pandas DataFrame and can be manipulated in the script, a notebook, or post-hoc after saving it. +Running locally is feasible but is typically bottlenecked heavily by IO and network bandwidth. Even on a gigabit connection, the rate of data access is significantly slower compared to within a cloud provider VM. + +The outputs are returned as a pandas DataFrame and can be manipulated in the script, a notebook, etc. 
## Backward Compatibility diff --git a/src/extremeweatherbench/defaults.py b/src/extremeweatherbench/defaults.py index 4e85c17b..41d20e16 100644 --- a/src/extremeweatherbench/defaults.py +++ b/src/extremeweatherbench/defaults.py @@ -58,28 +58,36 @@ ] -def _preprocess_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_cira_forecast_dataset( + ds: xr.Dataset, kerchunk: bool = True +) -> xr.Dataset: """A preprocess function for CIRA data that renames the time coordinate to lead_time, creates a valid_time coordinate, and sets the lead time range and resolution not present in the original dataset. Args: ds: The forecast dataset to preprocess. - + kerchunk: Whether the dataset is a kerchunk reference. Defaults to True. Returns: The preprocessed forecast dataset. """ - ds = ds.rename({"time": "lead_time"}) - # The evaluation configuration is used to set the lead time range and resolution. - ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") + + # If the dataset is a kerchunk, we need to rename the time coordinate to lead_time + # and set the lead time range and resolution. Otherwise, pass through the dataset. + if kerchunk: + ds = ds.rename({"time": "lead_time"}) + # The evaluation configuration is used to set the lead time range and resolution. + ds["lead_time"] = np.array( + [i for i in range(0, 241, 6)], dtype="timedelta64[h]" + ).astype("timedelta64[ns]") return ds # Preprocessing function for CIRA data that includes geopotential thickness calculation # required for tropical cyclone tracks -def _preprocess_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_cira_tc_forecast_dataset( + ds: xr.Dataset, kerchunk: bool = True +) -> xr.Dataset: """A preprocess function for CIRA data that includes geopotential thickness calculation required for tropical cyclone tracks. 
@@ -89,16 +97,18 @@ def _preprocess_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: Args: ds: The forecast dataset to rename. - + kerchunk: Whether the dataset is a kerchunk reference. Defaults to True. Returns: The renamed forecast dataset. """ - ds = ds.rename({"time": "lead_time"}) - - # The evaluation configuration is used to set the lead time range and resolution. - ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") + # If the dataset is a kerchunk, we need to rename the time coordinate to lead_time + # and set the lead time range and resolution. Otherwise, pass through the dataset. + if kerchunk: + ds = ds.rename({"time": "lead_time"}) + # The evaluation configuration is used to set the lead time range and resolution. + ds["lead_time"] = np.array( + [i for i in range(0, 241, 6)], dtype="timedelta64[h]" + ).astype("timedelta64[ns]") # Calculate the geopotential thickness required for tropical cyclone tracks ds["geopotential_thickness"] = ( @@ -133,23 +143,27 @@ def _preprocess_hres_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: # Preprocess function for CIRA data using Brightband kerchunk parquets -def _preprocess_ar_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_cira_ar_forecast_dataset( + ds: xr.Dataset, kerchunk: bool = True +) -> xr.Dataset: """An example preprocess function that renames the time coordinate to lead_time, creates a valid_time coordinate, and sets the lead time range and resolution not present in the original dataset. Args: ds: The forecast dataset to rename. - + kerchunk: Whether the dataset is a kerchunk reference. Defaults to True. Returns: The renamed forecast dataset. """ - ds = ds.rename({"time": "lead_time"}) - - # The evaluation configuration is used to set the lead time range and resolution. 
- ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") + # If the dataset is a kerchunk, we need to rename the time coordinate to lead_time + # and set the lead time range and resolution. Otherwise, pass through the dataset. + if kerchunk: + ds = ds.rename({"time": "lead_time"}) + # The evaluation configuration is used to set the lead time range and resolution. + ds["lead_time"] = np.array( + [i for i in range(0, 241, 6)], dtype="timedelta64[h]" + ).astype("timedelta64[ns]") if "q" not in ds.variables: # Calculate specific humidity from relative humidity and air temperature ds["specific_humidity"] = calc.specific_humidity_from_relative_humidity( @@ -161,23 +175,27 @@ def _preprocess_ar_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: # Preprocess function for CIRA data using Brightband kerchunk parquets -def _preprocess_severe_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_severe_cira_forecast_dataset( + ds: xr.Dataset, kerchunk: bool = True +) -> xr.Dataset: """An example preprocess function that renames the time coordinate to lead_time, creates a valid_time coordinate, and sets the lead time range and resolution not present in the original dataset. Args: ds: The forecast dataset to rename. - + kerchunk: Whether the dataset is a kerchunk reference. Defaults to True. Returns: The renamed forecast dataset. """ - ds = ds.rename({"time": "lead_time"}) - - # The evaluation configuration is used to set the lead time range and resolution. - ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") + # If the dataset is a kerchunk, we need to rename the time coordinate to lead_time + # and set the lead time range and resolution. Otherwise, pass through the dataset. + if kerchunk: + ds = ds.rename({"time": "lead_time"}) + # The evaluation configuration is used to set the lead time range and resolution. 
+ ds["lead_time"] = np.array( + [i for i in range(0, 241, 6)], dtype="timedelta64[h]" + ).astype("timedelta64[ns]") if "q" not in ds.variables: # Calculate specific humidity from relative humidity and air temperature ds["specific_humidity"] = calc.specific_humidity_from_relative_humidity( @@ -243,51 +261,39 @@ def get_climatology(quantile: float = 0.85) -> xr.DataArray: ibtracs_target = inputs.IBTrACS() # Forecasts -cira_heatwave_forecast = inputs.KerchunkForecast( - name="FourCastNetv2", - source="gs://extremeweatherbench/FOUR_v200_GFS.parq", +cira_fcnv2_heatwave_forecast = inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", variables=["surface_air_temperature"], - variable_mapping=inputs.CIRA_metadata_variable_mapping, - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_cira_forecast_dataset, + name="FourCastNetv2", ) -cira_freeze_forecast = inputs.KerchunkForecast( - name="FourCastNetv2", - source="gs://extremeweatherbench/FOUR_v200_GFS.parq", +cira_fcnv2_freeze_forecast = inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", variables=["surface_air_temperature"], - variable_mapping=inputs.CIRA_metadata_variable_mapping, - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_cira_forecast_dataset, + name="FourCastNetv2", ) -cira_tropical_cyclone_forecast = inputs.KerchunkForecast( - name="FourCastNetv2", - source="gs://extremeweatherbench/FOUR_v200_GFS.parq", +cira_fcnv2_tropical_cyclone_forecast = inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", variables=[derived.TropicalCycloneTrackVariables()], - variable_mapping=inputs.CIRA_metadata_variable_mapping, - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, + name="FourCastNetv2", preprocess=_preprocess_cira_tc_forecast_dataset, ) -cira_atmospheric_river_forecast = inputs.KerchunkForecast( - name="FourCastNetv2", - source="gs://extremeweatherbench/FOUR_v200_GFS.parq", 
+cira_fcnv2_atmospheric_river_forecast = inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", variables=[ derived.AtmosphericRiverVariables( output_variables=["atmospheric_river_land_intersection"] ) ], - variable_mapping=inputs.CIRA_metadata_variable_mapping, - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_ar_cira_forecast_dataset, + name="FourCastNetv2", + preprocess=_preprocess_cira_ar_forecast_dataset, ) -cira_severe_convection_forecast = inputs.KerchunkForecast( - name="FourCastNetv2", - source="gs://extremeweatherbench/FOUR_v200_GFS.parq", +cira_fcnv2_severe_convection_forecast = inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", variables=[derived.CravenBrooksSignificantSevere()], - variable_mapping=inputs.CIRA_metadata_variable_mapping, - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, + name="FourCastNetv2", preprocess=_preprocess_severe_cira_forecast_dataset, ) @@ -363,37 +369,37 @@ def get_brightband_evaluation_objects() -> list[inputs.EvaluationObject]: event_type="heat_wave", metric_list=heatwave_metric_list, target=era5_heatwave_target, - forecast=cira_heatwave_forecast, + forecast=cira_fcnv2_heatwave_forecast, ), inputs.EvaluationObject( event_type="heat_wave", metric_list=heatwave_metric_list, target=ghcn_heatwave_target, - forecast=cira_heatwave_forecast, + forecast=cira_fcnv2_heatwave_forecast, ), inputs.EvaluationObject( event_type="freeze", metric_list=freeze_metric_list, target=era5_freeze_target, - forecast=cira_freeze_forecast, + forecast=cira_fcnv2_freeze_forecast, ), inputs.EvaluationObject( event_type="freeze", metric_list=freeze_metric_list, target=ghcn_freeze_target, - forecast=cira_freeze_forecast, + forecast=cira_fcnv2_freeze_forecast, ), inputs.EvaluationObject( event_type="severe_convection", metric_list=pph_metric_list, target=pph_target, - forecast=cira_severe_convection_forecast, + forecast=cira_fcnv2_severe_convection_forecast, ), 
inputs.EvaluationObject( event_type="severe_convection", metric_list=lsr_metric_list, target=lsr_target, - forecast=cira_severe_convection_forecast, + forecast=cira_fcnv2_severe_convection_forecast, ), inputs.EvaluationObject( event_type="atmospheric_river", @@ -403,12 +409,12 @@ def get_brightband_evaluation_objects() -> list[inputs.EvaluationObject]: metrics.EarlySignal(), ], target=era5_atmospheric_river_target, - forecast=cira_atmospheric_river_forecast, + forecast=cira_fcnv2_atmospheric_river_forecast, ), inputs.EvaluationObject( event_type="tropical_cyclone", metric_list=composite_landfall_metrics, target=ibtracs_target, - forecast=cira_tropical_cyclone_forecast, + forecast=cira_fcnv2_tropical_cyclone_forecast, ), ] diff --git a/src/extremeweatherbench/inputs.py b/src/extremeweatherbench/inputs.py index e1353fc1..37aba5a5 100644 --- a/src/extremeweatherbench/inputs.py +++ b/src/extremeweatherbench/inputs.py @@ -148,6 +148,17 @@ {"s3://noaa-oar-mlwp-data/": icechunk.s3_credentials(anonymous=True)} ) +CIRA_MODEL_NAMES = [ + "AURO_v100_GFS", + "FOUR_v200_IFS", + "PANG_v100_IFS", + "FOUR_v200_GFS", + "GRAP_v100_GFS", + "AURO_v100_IFS", + "PANG_v100_GFS", + "GRAP_v100_IFS", +] + def _default_preprocess(input_data: IncomingDataInput) -> IncomingDataInput: """Default forecast preprocess function that does nothing.""" @@ -1271,3 +1282,51 @@ def check_for_missing_data( return False else: return True + + +def get_cira_icechunk( + model_name: str, + variables: list[Union[str, derived.DerivedVariable]] = [], + preprocess: Callable = _default_preprocess, + name: Optional[str] = None, +) -> XarrayForecast: + """Get a CIRA icechunk forecast object for a given model name. + + Args: + model_name: The name of the model from CIRA to get the forecast object for. For + example, "FOUR_v200_GFS". For a list of available models, see + `extremeweatherbench.defaults.CIRA_MODEL_NAMES`. + variables: The variables to select from the model. Defaults to all variables. 
+ preprocess: The preprocessing function to apply to the model. Defaults to the + default passthrough preprocess function. + name: The name of the forecast object. Defaults to model_name by default unless + `name` is provided. + Returns: + An XarrayForecast object for the given model. + """ + # Check if the model name is valid + if model_name not in CIRA_MODEL_NAMES: + raise ValueError( + f"Model name {model_name} not found in CIRA_MODEL_NAMES. Model names must be one of: {CIRA_MODEL_NAMES}" + ) + + # Get the CIRA icechunkstorage + cira_storage = icechunk.gcs_storage( + bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True + ) + + # The models are distinct groups within the icechunk store; open the group + # corresponding to the model name + cira_model_ds = open_icechunk_dataset_from_datatree( + cira_storage, model_name, authorize_virtual_chunk_access=CIRA_CREDENTIALS + ) + + # Create the XarrayForecast object for the given model + cira_model_forecast = XarrayForecast( + ds=cira_model_ds, + variables=variables, + variable_mapping=CIRA_metadata_variable_mapping, + name=name if name else model_name, + preprocess=preprocess, + ) + return cira_model_forecast diff --git a/tests/test_defaults.py b/tests/test_defaults.py index cbbea8bd..ff452c91 100644 --- a/tests/test_defaults.py +++ b/tests/test_defaults.py @@ -120,10 +120,22 @@ def test_target_objects_exist(self): def test_forecast_objects_exist(self): """Test that forecast objects are properly defined.""" - assert hasattr(defaults, "cira_heatwave_forecast") - assert hasattr(defaults, "cira_freeze_forecast") - assert isinstance(defaults.cira_heatwave_forecast, inputs.KerchunkForecast) - assert isinstance(defaults.cira_freeze_forecast, inputs.KerchunkForecast) + assert hasattr(defaults, "cira_fcnv2_heatwave_forecast") + assert hasattr(defaults, "cira_fcnv2_freeze_forecast") + assert hasattr(defaults, "cira_fcnv2_tropical_cyclone_forecast") + assert hasattr(defaults, 
"cira_fcnv2_atmospheric_river_forecast") + assert hasattr(defaults, "cira_fcnv2_severe_convection_forecast") + assert isinstance(defaults.cira_fcnv2_heatwave_forecast, inputs.XarrayForecast) + assert isinstance(defaults.cira_fcnv2_freeze_forecast, inputs.XarrayForecast) + assert isinstance( + defaults.cira_fcnv2_tropical_cyclone_forecast, inputs.XarrayForecast + ) + assert isinstance( + defaults.cira_fcnv2_atmospheric_river_forecast, inputs.XarrayForecast + ) + assert isinstance( + defaults.cira_fcnv2_severe_convection_forecast, inputs.XarrayForecast + ) def test_era5_heatwave_target_configuration(self): """Test ERA5 heatwave target configuration.""" @@ -149,21 +161,6 @@ def test_era5_freeze_target_configuration(self): for key, value in expected_mapping.items(): assert target.variable_mapping[key] == value - def test_cira_forecasts_have_preprocess_function(self): - """Test that CIRA forecasts have the preprocess function set.""" - assert defaults.cira_heatwave_forecast.preprocess is not None - assert defaults.cira_freeze_forecast.preprocess is not None - - # Test that the preprocess function is the expected one - assert ( - defaults.cira_heatwave_forecast.preprocess - == defaults._preprocess_cira_forecast_dataset - ) - assert ( - defaults.cira_freeze_forecast.preprocess - == defaults._preprocess_cira_forecast_dataset - ) - def test_get_brightband_evaluation_objects_no_exceptions(self): """Test that get_brightband_evaluation_objects runs without exceptions.""" try: @@ -173,3 +170,63 @@ def test_get_brightband_evaluation_objects_no_exceptions(self): assert len(result) > 0 except Exception as e: pytest.fail(f"get_brightband_evaluation_objects raised an exception: {e}") + + +class TestCiraFcnv2PreprocessFunctions: + """Tests that each cira_fcnv2 forecast has the correct preprocessing function.""" + + def test_heatwave_forecast_has_default_preprocess(self): + """Test that cira_fcnv2_heatwave_forecast uses default preprocess.""" + forecast = 
defaults.cira_fcnv2_heatwave_forecast + assert forecast.preprocess == inputs._default_preprocess + + def test_freeze_forecast_has_default_preprocess(self): + """Test that cira_fcnv2_freeze_forecast uses default preprocess.""" + forecast = defaults.cira_fcnv2_freeze_forecast + assert forecast.preprocess == inputs._default_preprocess + + def test_tropical_cyclone_forecast_has_tc_preprocess(self): + """Test that cira_fcnv2_tropical_cyclone_forecast uses TC preprocess.""" + forecast = defaults.cira_fcnv2_tropical_cyclone_forecast + assert forecast.preprocess == defaults._preprocess_cira_tc_forecast_dataset + + def test_atmospheric_river_forecast_has_ar_preprocess(self): + """Test that cira_fcnv2_atmospheric_river_forecast uses AR preprocess.""" + forecast = defaults.cira_fcnv2_atmospheric_river_forecast + assert forecast.preprocess == defaults._preprocess_cira_ar_forecast_dataset + + def test_severe_convection_forecast_has_severe_preprocess(self): + """Test that cira_fcnv2_severe_convection_forecast uses severe preprocess.""" + forecast = defaults.cira_fcnv2_severe_convection_forecast + assert forecast.preprocess == defaults._preprocess_severe_cira_forecast_dataset + + def test_all_forecasts_have_preprocess_attribute(self): + """Test that all cira_fcnv2 forecasts have a preprocess attribute set.""" + forecasts = [ + defaults.cira_fcnv2_heatwave_forecast, + defaults.cira_fcnv2_freeze_forecast, + defaults.cira_fcnv2_tropical_cyclone_forecast, + defaults.cira_fcnv2_atmospheric_river_forecast, + defaults.cira_fcnv2_severe_convection_forecast, + ] + for forecast in forecasts: + assert hasattr(forecast, "preprocess") + assert forecast.preprocess is not None + assert callable(forecast.preprocess) + + def test_preprocess_functions_are_distinct_where_expected(self): + """Test that different event types use different preprocess functions.""" + # TC, AR, and severe should have distinct preprocess functions + tc_preprocess = defaults.cira_fcnv2_tropical_cyclone_forecast.preprocess 
+ ar_preprocess = defaults.cira_fcnv2_atmospheric_river_forecast.preprocess + severe_preprocess = defaults.cira_fcnv2_severe_convection_forecast.preprocess + + assert tc_preprocess != ar_preprocess + assert tc_preprocess != severe_preprocess + # Note: AR and severe could be the same or different depending on impl + + def test_heatwave_and_freeze_use_same_preprocess(self): + """Test that heatwave and freeze forecasts use the same preprocess.""" + heatwave_preprocess = defaults.cira_fcnv2_heatwave_forecast.preprocess + freeze_preprocess = defaults.cira_fcnv2_freeze_forecast.preprocess + assert heatwave_preprocess == freeze_preprocess diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 4531d150..0725afb0 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -2233,7 +2233,9 @@ def test_xarray_forecast_none_handling_for_optional_params( """Test that None values are properly converted to empty defaults.""" # Explicitly pass None to test the None handling in __init__ forecast = inputs.XarrayForecast( - ds=sample_forecast_with_valid_time, variables=None, variable_mapping=None + ds=sample_forecast_with_valid_time, + variables=None, + variable_mapping=None, # type: ignore ) # Should be converted to empty containers @@ -2323,3 +2325,266 @@ def test_default_preprocess(): df = pd.DataFrame({"a": [1, 2, 3]}) result_df = inputs._default_preprocess(df) assert result_df is df + + +class TestGetCIRAIcechunk: + """Tests for get_cira_icechunk function.""" + + def test_invalid_model_name_raises_value_error(self): + """Test that an invalid model name raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + inputs.get_cira_icechunk(model_name="INVALID_MODEL") + + assert "INVALID_MODEL" in str(exc_info.value) + assert "CIRA_MODEL_NAMES" in str(exc_info.value) + + def test_empty_model_name_raises_value_error(self): + """Test that an empty model name raises ValueError.""" + with pytest.raises(ValueError): + inputs.get_cira_icechunk(model_name="") + + def 
test_none_model_name_raises_error(self): + """Test that None as model name raises appropriate error.""" + with pytest.raises((ValueError, TypeError)): + inputs.get_cira_icechunk(model_name=None) # type: ignore + + def test_case_sensitive_model_name(self): + """Test that model name matching is case-sensitive.""" + # Lowercase version of a valid model name should fail + with pytest.raises(ValueError): + inputs.get_cira_icechunk(model_name="four_v200_gfs") + + # Mixed case should fail + with pytest.raises(ValueError): + inputs.get_cira_icechunk(model_name="Four_V200_GFS") + + def test_partial_model_name_raises_value_error(self): + """Test that partial model names are rejected.""" + with pytest.raises(ValueError): + inputs.get_cira_icechunk(model_name="FOUR") + + with pytest.raises(ValueError): + inputs.get_cira_icechunk(model_name="GFS") + + def test_model_name_with_extra_chars_raises_value_error(self): + """Test that model names with extra characters are rejected.""" + with pytest.raises(ValueError): + inputs.get_cira_icechunk(model_name="FOUR_v200_GFS_extra") + + with pytest.raises(ValueError): + inputs.get_cira_icechunk(model_name=" FOUR_v200_GFS") + + def test_error_message_lists_valid_model_names(self): + """Test that the error message includes the list of valid model names.""" + with pytest.raises(ValueError) as exc_info: + inputs.get_cira_icechunk(model_name="BAD_MODEL") + + error_msg = str(exc_info.value) + # Check that at least some valid model names are shown in the error + assert "FOUR_v200_GFS" in error_msg or "CIRA_MODEL_NAMES" in error_msg + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_valid_model_name_four_v200_gfs( + self, mock_forecast, mock_open, mock_storage + ): + """Test that FOUR_v200_GFS is a valid model name.""" + mock_storage.return_value = mock.MagicMock() + 
mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + result = inputs.get_cira_icechunk(model_name="FOUR_v200_GFS") + + assert result is not None + mock_storage.assert_called_once() + mock_open.assert_called_once() + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_valid_model_name_auro_v100_gfs( + self, mock_forecast, mock_open, mock_storage + ): + """Test that AURO_v100_GFS is a valid model name.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + result = inputs.get_cira_icechunk(model_name="AURO_v100_GFS") + + assert result is not None + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_all_cira_model_names_are_valid( + self, mock_forecast, mock_open, mock_storage + ): + """Test that all model names in CIRA_MODEL_NAMES are accepted.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + for model_name in inputs.CIRA_MODEL_NAMES: + result = inputs.get_cira_icechunk(model_name=model_name) + assert result is not None, f"Model {model_name} should be valid" + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_custom_name_parameter(self, mock_forecast, mock_open, mock_storage): + """Test that a custom name parameter is passed to XarrayForecast.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + 
mock_forecast.return_value = mock.MagicMock()
+
+        inputs.get_cira_icechunk(model_name="FOUR_v200_GFS", name="CustomName")
+
+        # Check that XarrayForecast was called with the custom name
+        call_kwargs = mock_forecast.call_args[1]
+        assert call_kwargs["name"] == "CustomName"
+
+    @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage")
+    @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree")
+    @mock.patch("extremeweatherbench.inputs.XarrayForecast")
+    def test_default_name_uses_model_name(self, mock_forecast, mock_open, mock_storage):
+        """Test that name defaults to model_name when not provided."""
+        mock_storage.return_value = mock.MagicMock()
+        mock_open.return_value = mock.MagicMock()
+        mock_forecast.return_value = mock.MagicMock()
+
+        inputs.get_cira_icechunk(model_name="FOUR_v200_GFS")
+
+        call_kwargs = mock_forecast.call_args[1]
+        assert call_kwargs["name"] == "FOUR_v200_GFS"
+
+    @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage")
+    @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree")
+    @mock.patch("extremeweatherbench.inputs.XarrayForecast")
+    def test_empty_variables_list(self, mock_forecast, mock_open, mock_storage):
+        """Test that empty variables list is valid."""
+        mock_storage.return_value = mock.MagicMock()
+        mock_open.return_value = mock.MagicMock()
+        mock_forecast.return_value = mock.MagicMock()
+
+        result = inputs.get_cira_icechunk(model_name="FOUR_v200_GFS", variables=[])
+
+        assert result is not None
+        call_kwargs = mock_forecast.call_args[1]
+        assert call_kwargs["variables"] == []
+
+    @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage")
+    @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree")
+    @mock.patch("extremeweatherbench.inputs.XarrayForecast")
+    def test_custom_variables_list(self, mock_forecast, mock_open, mock_storage):
+        """Test that a custom variables list is passed through."""
+        mock_storage.return_value = mock.MagicMock()
+        
mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + variables = ["surface_air_temperature", "air_pressure"] + inputs.get_cira_icechunk(model_name="FOUR_v200_GFS", variables=variables) + + call_kwargs = mock_forecast.call_args[1] + assert call_kwargs["variables"] == variables + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_custom_preprocess_function(self, mock_forecast, mock_open, mock_storage): + """Test that a custom preprocess function is passed through.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + def custom_preprocess(ds: xr.Dataset) -> xr.Dataset: + return ds + + inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", preprocess=custom_preprocess + ) + + call_kwargs = mock_forecast.call_args[1] + assert call_kwargs["preprocess"] == custom_preprocess + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_returns_xarray_forecast_object( + self, mock_forecast, mock_open, mock_storage + ): + """Test that the function returns an XarrayForecast object.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + expected_forecast = mock.MagicMock() + mock_forecast.return_value = expected_forecast + + result = inputs.get_cira_icechunk(model_name="FOUR_v200_GFS") + + assert result is expected_forecast + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_gcs_storage_configuration(self, mock_forecast, 
mock_open, mock_storage): + """Test that GCS storage is configured with correct parameters.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + inputs.get_cira_icechunk(model_name="FOUR_v200_GFS") + + mock_storage.assert_called_once_with( + bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True + ) + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_uses_cira_variable_mapping(self, mock_forecast, mock_open, mock_storage): + """Test that CIRA metadata variable mapping is used.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + inputs.get_cira_icechunk(model_name="FOUR_v200_GFS") + + call_kwargs = mock_forecast.call_args[1] + assert call_kwargs["variable_mapping"] == inputs.CIRA_metadata_variable_mapping + + +class TestCiraModelNames: + """Tests for CIRA_MODEL_NAMES constant.""" + + def test_cira_model_names_is_list(self): + """Test that CIRA_MODEL_NAMES is a list.""" + assert isinstance(inputs.CIRA_MODEL_NAMES, list) + + def test_cira_model_names_not_empty(self): + """Test that CIRA_MODEL_NAMES is not empty.""" + assert len(inputs.CIRA_MODEL_NAMES) > 0 + + def test_cira_model_names_contains_expected_models(self): + """Test that CIRA_MODEL_NAMES contains expected model names.""" + expected_models = [ + "FOUR_v200_GFS", + "FOUR_v200_IFS", + "AURO_v100_GFS", + "AURO_v100_IFS", + "PANG_v100_GFS", + "PANG_v100_IFS", + "GRAP_v100_GFS", + "GRAP_v100_IFS", + ] + for model in expected_models: + assert model in inputs.CIRA_MODEL_NAMES + + def test_cira_model_names_all_strings(self): + """Test that all entries in CIRA_MODEL_NAMES are strings.""" + for model in inputs.CIRA_MODEL_NAMES: + assert isinstance(model, str) 
+ + def test_cira_model_names_no_duplicates(self): + """Test that CIRA_MODEL_NAMES has no duplicate entries.""" + assert len(inputs.CIRA_MODEL_NAMES) == len(set(inputs.CIRA_MODEL_NAMES)) From a0ce4c56408980933d0e7950687e5a990e2044b3 Mon Sep 17 00:00:00 2001 From: aaTman Date: Mon, 26 Jan 2026 19:15:23 +0000 Subject: [PATCH 13/13] update defaults var refs --- src/extremeweatherbench/__init__.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/extremeweatherbench/__init__.py b/src/extremeweatherbench/__init__.py index 8f1ac54a..af479c2f 100644 --- a/src/extremeweatherbench/__init__.py +++ b/src/extremeweatherbench/__init__.py @@ -49,11 +49,11 @@ from extremeweatherbench.defaults import ( DEFAULT_COORDINATE_VARIABLES, DEFAULT_VARIABLE_NAMES, - cira_atmospheric_river_forecast, - cira_freeze_forecast, - cira_heatwave_forecast, - cira_severe_convection_forecast, - cira_tropical_cyclone_forecast, + cira_fcnv2_atmospheric_river_forecast, + cira_fcnv2_freeze_forecast, + cira_fcnv2_heatwave_forecast, + cira_fcnv2_severe_convection_forecast, + cira_fcnv2_tropical_cyclone_forecast, era5_atmospheric_river_target, era5_freeze_target, era5_heatwave_target, @@ -227,11 +227,11 @@ # defaults "DEFAULT_COORDINATE_VARIABLES", "DEFAULT_VARIABLE_NAMES", - "cira_atmospheric_river_forecast", - "cira_freeze_forecast", - "cira_heatwave_forecast", - "cira_severe_convection_forecast", - "cira_tropical_cyclone_forecast", + "cira_fcnv2_atmospheric_river_forecast", + "cira_fcnv2_freeze_forecast", + "cira_fcnv2_heatwave_forecast", + "cira_fcnv2_severe_convection_forecast", + "cira_fcnv2_tropical_cyclone_forecast", "era5_atmospheric_river_target", "era5_freeze_target", "era5_heatwave_target",