From 8715835c70f50358f18c494aaf2e5d2d928c3d06 Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Sat, 17 Jan 2026 09:50:45 -0500 Subject: [PATCH 01/15] update build-system and project --- pyproject.toml | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4994a179..5b72c306 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,25 @@ name = "extremeweatherbench" version = "0.2.0" description = "Benchmarking weather and weather AI models using extreme events" +keywords = [ + "weather", + "extreme events", + "benchmarking", + "forecasting", + "climate", +] +license = { file = "LICENSE" } readme = "README.md" +classifiers = [ + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Scientific/Engineering :: Atmospheric Science", +] requires-python = ">=3.11,<3.14" dependencies = [ "dacite>=1.8.1", @@ -73,8 +91,8 @@ docs = [ ] [build-system] -requires = ["setuptools", "wheel"] -build-backend = "setuptools.build_meta" +requires = ["hatchling >= 1.26"] +build-backend = "hatchling.build" [tool.setuptools] packages = ["extremeweatherbench"] From 793c482ab4fa8ee7ea706c2863f70b3396e811d7 Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Sat, 17 Jan 2026 10:12:37 -0500 Subject: [PATCH 02/15] update workflows, publish, and pyproject --- .github/workflows/publish.yaml | 56 +++++++++++++++++++ .../{ci.yaml => run-pre-commit.yaml} | 32 ++++------- .github/workflows/run-tests.yaml | 39 +++++++++++++ pyproject.toml | 29 ++++++++++ 4 files changed, 136 insertions(+), 20 deletions(-) create mode 100644 .github/workflows/publish.yaml rename .github/workflows/{ci.yaml => run-pre-commit.yaml} (63%) create mode 100644 .github/workflows/run-tests.yaml 
diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml new file mode 100644 index 00000000..a9324bc7 --- /dev/null +++ b/.github/workflows/publish.yaml @@ -0,0 +1,56 @@ +name: Upload Python Package + +on: + release: + types: [published] + +permissions: + contents: read + +jobs: + release-build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + + - name: Build release distributions + run: | + python -m pip install build + python -m build + + - name: Upload distributions + uses: actions/upload-artifact@v4 + with: + name: release-dists + path: dist/ + + pypi-publish: + runs-on: ubuntu-latest + needs: + - release-build + permissions: + # IMPORTANT: this permission is mandatory for trusted publishing + id-token: write + + # Dedicated environments with protections for publishing are strongly recommended. + # For more information, see: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment#deployment-protection-rules + environment: + name: release + url: https://pypi.org/p/extremeweatherbench + + steps: + - name: Retrieve release distributions + uses: actions/download-artifact@v4 + with: + name: release-dists + path: dist/ + + - name: Publish release distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + packages-dir: dist/ diff --git a/.github/workflows/ci.yaml b/.github/workflows/run-pre-commit.yaml similarity index 63% rename from .github/workflows/ci.yaml rename to .github/workflows/run-pre-commit.yaml index e042a833..a4fe50eb 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/run-pre-commit.yaml @@ -1,23 +1,26 @@ -name: ci +name: Run pre-commit -on: - pull_request: - push: - branches: [main] +on: [push, pull_request] + +permissions: + contents: read jobs: pre-commit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v3 + - uses: 
actions/checkout@v4 + - name: Set up Python env with uv + uses: actions/setup-python@v5 + with: + python-version: "3.13" - uses: pre-commit/action@v3.0.1 test: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 @@ -33,21 +36,10 @@ jobs: python-version-file: "pyproject.toml" - name: Install the project - run: uv sync --all-extras --dev + run: uv sync --all-extras --all-groups - name: Run tests run: uv run pytest - name: Generate Coverage Report run: uv run coverage report -m - - golden-tests: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: extractions/setup-just@v3 - with: - just-version: 1.43.1 - - - name: Run golden tests with just - run: just golden-tests diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml new file mode 100644 index 00000000..98895402 --- /dev/null +++ b/.github/workflows/run-tests.yaml @@ -0,0 +1,39 @@ +name: Run tests + +on: + push: + branches: [main] + pull_request: + branches: [main, develop] + +permissions: + contents: read + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12", "3.13"] + steps: + - uses: actions/checkout@v4 + + - name: Set up Python env with uv + uses: astral-sh/setup-uv@v4 + with: + version: "0.5.6" + enable-cache: true + + - name: "Set up Python" + uses: actions/setup-python@v5 + with: + python-version-file: "pyproject.toml" + + - name: Install the project + run: uv sync --all-extras --all-groups + + - name: Run tests + run: uv run pytest + + - name: Generate Coverage Report + run: uv run coverage report -m diff --git a/pyproject.toml b/pyproject.toml index 5b72c306..b94dd6c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,8 @@ docs = [ "pymdown-extensions>=10.19.1", ] +complete = ["extremeweatherbench[data-prep,multiprocessing]"] + [build-system] requires = ["hatchling >= 1.26"] build-backend = 
"hatchling.build" @@ -102,6 +104,11 @@ include-package-data = true [tool.setuptools.package-data] extremeweatherbench = ["data/**/*", "data/**/.*"] +[project.urls] +Documentation = "https://extremeweatherbench.readthedocs.io/" +Repository = "https://github.com/brightbandtech/extremeweatherbench" + + [project.scripts] ewb = "extremeweatherbench.evaluate_cli:cli_runner" @@ -143,3 +150,25 @@ docstring-code-line-length = "dynamic" [tool.ruff.lint.isort] case-sensitive = true + +[tool.semantic_release] +version_toml = ["pyproject.toml:project.version"] +branch = "main" +dist_path = "dist/" +upload_to_pypi = false +remote = { type = "github" } +commit_author = "semantic-release " +commit_parser = "conventional" +commit_parser_options = { parse_squash_commits = "false", parse_merge_commits = "true" } +minor_tag = "[minor]" +patch_tag = "[patch]" +major_tag = "[major]" +build_command = """ + uv lock --offline + git add uv.lock + uv build +""" +# Only create GitHub releases for the current version, not historical ones +github_release_mode = "latest" +# Ensure assets are only uploaded for the current release, not past ones +upload_assets_for_all_releases = false From a61ddf075f7fb67da5e64cea9171c7f96ad76dc7 Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Sat, 17 Jan 2026 10:19:16 -0500 Subject: [PATCH 03/15] add justfile and twine --- Justfile | 56 +++++++++++++++++++++++++++++++++++++++++++++++--- pyproject.toml | 2 ++ 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/Justfile b/Justfile index f0c45f50..cddbd2d1 100644 --- a/Justfile +++ b/Justfile @@ -1,7 +1,57 @@ +# NOTE: We automatically load a .env file containing the "GH_TOKEN" environment variable +# for use with semantic-release. If this isn't present, then those commands will likely fail. 
+set dotenv-load + # List all available recipes default: @just --list -# Placeholder for golden tests -golden-tests: - @just --list \ No newline at end of file +# Run the complete test suite +test: + @echo "Running tests" + uv run pytest + +# Serve a local build of the project documentation at http://localhost:8000 +serve-docs: + @echo "Serving docs at http://localhost:8000" + uv run --extra docs mkdocs serve + +# Build the project documentation +build-docs: + @echo "Building docs" + uv run --extra docs mkdocs build + +# Run the pre-commit hooks on all files in the repo +pre-commit: + @echo "Running pre-commit hooks" + uv run pre-commit run --all-files + +# Run the coverage report +coverage: + @echo "Running coverage report" + uv run coverage run -m pytest + uv run coverage report + +# Determine the next version number +next-version: + @echo "Determining next version" + uv run semantic-release version --print + +# Create a minor release +minor-release: + @echo "Creating minor release" + uv run semantic-release -vvv --noop version --minor --no-changelog + +# Create a patch release +patch-release: + @echo "Creating patch release" + uv run semantic-release -vvv --noop version --patch --no-changelog + +# Upload a release to PyPI +pypi-upload tag: + @echo "Uploading release {{tag}} to PyPI" + git checkout {{tag}} + rm -rf dist + uv run python -m build + uv run twine upload dist/* + git checkout - \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b94dd6c0..1960a539 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,8 @@ dev = [ "types-pytz>=2025.2.0.20250809", "types-pyyaml>=6.0.12.20241230", "types-tqdm>=4.67.0.20250809", + "python-semantic-release>=10.3.0", + "twine>=5.1.1", ] docs = [ "mkdocs>=1.6.1", From 6bb1153c386f8528ff4a2f0b7b5dc7025e01e11e Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Sat, 17 Jan 2026 10:37:56 -0500 Subject: [PATCH 04/15] update publish yaml --- .github/workflows/publish.yaml | 30 
+++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index a9324bc7..04dfe2e9 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -9,6 +9,7 @@ permissions: jobs: release-build: + name: Build release distribution runs-on: ubuntu-latest steps: @@ -30,7 +31,9 @@ jobs: path: dist/ pypi-publish: + name: Publish release distribution to PyPI runs-on: ubuntu-latest + if: startsWith(github.ref, 'refs/tags/') needs: - release-build permissions: @@ -40,7 +43,7 @@ jobs: # Dedicated environments with protections for publishing are strongly recommended. # For more information, see: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment#deployment-protection-rules environment: - name: release + name: ewb-pypi-release url: https://pypi.org/p/extremeweatherbench steps: @@ -54,3 +57,28 @@ jobs: uses: pypa/gh-action-pypi-publish@release/v1 with: packages-dir: dist/ + + publish-to-testpypi: + name: Publish release distribution to TestPyPI + runs-on: ubuntu-latest + needs: + - release-build + + permissions: + id-token: write + + environment: + name: ewb-testpypi-release + url: https://test.pypi.org/p/extremeweatherbench + + steps: + - name: Retrieve release distributions + uses: actions/download-artifact@v4 + with: + name: release-dists + path: dist/ + + - name: Publish release distributions to TestPyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + repository-url: https://test.pypi.org/legacy/ From 1e2c0fcdef5daab24a5592345d11f8a95410f370 Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Sat, 17 Jan 2026 10:46:57 -0500 Subject: [PATCH 05/15] change to python 3.10 as minimum requirement --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1960a539..e6c41064 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ 
-21,7 +21,7 @@ classifiers = [ "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering :: Atmospheric Science", ] -requires-python = ">=3.11,<3.14" +requires-python = ">=3.10,<3.14" dependencies = [ "dacite>=1.8.1", "gcsfs>=2024.12.0", From 6856f4957dda534ffe8b38bde7245b0a896a908a Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Sat, 17 Jan 2026 10:49:07 -0500 Subject: [PATCH 06/15] kerchunk needs 3.11, swapping pyproject and tests to remove 3.10 --- .github/workflows/run-pre-commit.yaml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/run-pre-commit.yaml b/.github/workflows/run-pre-commit.yaml index a4fe50eb..e66851a0 100644 --- a/.github/workflows/run-pre-commit.yaml +++ b/.github/workflows/run-pre-commit.yaml @@ -20,7 +20,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10", "3.11", "3.12", "3.13"] + python-version: ["3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 diff --git a/pyproject.toml b/pyproject.toml index e6c41064..1960a539 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering :: Atmospheric Science", ] -requires-python = ">=3.10,<3.14" +requires-python = ">=3.11,<3.14" dependencies = [ "dacite>=1.8.1", "gcsfs>=2024.12.0", From a5a80dd8ae1f12b7d5d22670b376dde0ed81209b Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Mon, 19 Jan 2026 10:57:25 -0500 Subject: [PATCH 07/15] change workflows to use version matrix --- .github/workflows/run-pre-commit.yaml | 2 +- .github/workflows/run-tests.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/run-pre-commit.yaml b/.github/workflows/run-pre-commit.yaml index e66851a0..b7dfaa4a 100644 --- a/.github/workflows/run-pre-commit.yaml +++ b/.github/workflows/run-pre-commit.yaml @@ -33,7 +33,7 @@ jobs: - name: "Set up Python" uses: actions/setup-python@v5 
with: - python-version-file: "pyproject.toml" + python-version: ${{ matrix.python-version }} - name: Install the project run: uv sync --all-extras --all-groups diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 98895402..45fbd669 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -27,7 +27,7 @@ jobs: - name: "Set up Python" uses: actions/setup-python@v5 with: - python-version-file: "pyproject.toml" + python-version: ${{ matrix.python-version }} - name: Install the project run: uv sync --all-extras --all-groups From 80de3b20b92f9da169b199262f80683463a6f0ea Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Mon, 19 Jan 2026 11:03:52 -0500 Subject: [PATCH 08/15] align workflows --- .github/workflows/run-pre-commit.yaml | 27 +++++++++------------------ .github/workflows/run-tests.yaml | 6 +++--- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/.github/workflows/run-pre-commit.yaml b/.github/workflows/run-pre-commit.yaml index b7dfaa4a..42e6056f 100644 --- a/.github/workflows/run-pre-commit.yaml +++ b/.github/workflows/run-pre-commit.yaml @@ -1,26 +1,20 @@ name: Run pre-commit -on: [push, pull_request] +on: + pull_request: + branches: [main, develop] + push: + branches: [main] permissions: contents: read jobs: - pre-commit: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Set up Python env with uv - uses: actions/setup-python@v5 - with: - python-version: "3.13" - - uses: pre-commit/action@v3.0.1 - - test: + build: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.11", "3.12", "3.13"] + python-version: ["3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 @@ -38,8 +32,5 @@ jobs: - name: Install the project run: uv sync --all-extras --all-groups - - name: Run tests - run: uv run pytest - - - name: Generate Coverage Report - run: uv run coverage report -m + - name: Run pre-commit hooks + run: uv run pre-commit run --all-files diff 
--git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 45fbd669..99baeb36 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -1,16 +1,16 @@ name: Run tests on: - push: - branches: [main] pull_request: branches: [main, develop] + push: + branches: [main] permissions: contents: read jobs: - test: + build: runs-on: ubuntu-latest strategy: matrix: From 9be0c8ead25ee2d22ce7443a542a3725a0923db2 Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Fri, 23 Jan 2026 21:25:26 -0500 Subject: [PATCH 09/15] Cleanup docstrings in repo (#318) * update these docstrings * remove docstring changes markdown * update docstrings * update other docstrings * remove individualcasecollection reference, update based on develop changes --- src/extremeweatherbench/calc.py | 2 +- src/extremeweatherbench/cases.py | 62 ++- src/extremeweatherbench/derived.py | 96 ++-- src/extremeweatherbench/evaluate.py | 15 +- src/extremeweatherbench/evaluate_cli.py | 19 +- src/extremeweatherbench/inputs.py | 94 +++- src/extremeweatherbench/metrics.py | 581 +++++++++++++++--------- src/extremeweatherbench/regions.py | 96 ++-- src/extremeweatherbench/sources/base.py | 12 +- 9 files changed, 597 insertions(+), 380 deletions(-) diff --git a/src/extremeweatherbench/calc.py b/src/extremeweatherbench/calc.py index ef314349..28fcf94f 100644 --- a/src/extremeweatherbench/calc.py +++ b/src/extremeweatherbench/calc.py @@ -259,7 +259,7 @@ def geopotential_thickness( pressure_dim: The name of the pressure dimension. Default is "level". Returns: - The geopotential thickness in metersas an xarray DataArray. + The geopotential thickness in meters as an xarray DataArray. 
""" geopotential_heights = da.sel({pressure_dim: top_level}) geopotential_height_bottom = da.sel({pressure_dim: bottom_level}) diff --git a/src/extremeweatherbench/cases.py b/src/extremeweatherbench/cases.py index a7800dc5..3f2e858c 100644 --- a/src/extremeweatherbench/cases.py +++ b/src/extremeweatherbench/cases.py @@ -24,18 +24,15 @@ @dataclasses.dataclass class IndividualCase: - """Container for metadata defining a single or individual case. - - An IndividualCase defines the relevant metadata for a single case study for a - given extreme weather event; it is designed to be easily instantiable through a - simple YAML-based configuration file. + """Container for metadata defining a single case study. Attributes: - case_id_number: A unique numerical identifier for the event. - start_date: The start date of the case, for use in subsetting data for analysis. - end_date: The end date of the case, for use in subsetting data for analysis. - location: A Location dataclass representing the location of a case. - event_type: A string representing the type of extreme weather event. + case_id_number: Unique numerical identifier for the event. + title: Title of the case study. + start_date: Start date for subsetting data for analysis. + end_date: End date for subsetting data for analysis. + location: Region object representing the case location. + event_type: String representing the type of extreme weather event. """ case_id_number: int @@ -48,18 +45,13 @@ class IndividualCase: @dataclasses.dataclass class CaseOperator: - """A class which stores the graph to process an individual case. - - This class is used to store the graph to process an individual case. The purpose of - this class is to be a one-stop-shop for the evaluation of a single case. Multiple - CaseOperators can be run in parallel to evaluate multiple cases, or run through the - ExtremeWeatherBench.run() method to evaluate all cases in an evaluation in serial. 
+ """Operator dataclass for an evaluation of a single evaluation object. Attributes: - case_metadata: IndividualCase metadata - metric_list: A list of metrics that are to be evaluated for the case operator - target_config: A TargetConfig object - forecast_config: A ForecastConfig object + case_metadata: IndividualCase metadata for this operator. + metric_list: List of metrics to evaluate for this case. + target: TargetBase object for ground truth data. + forecast: ForecastBase object for forecast data. """ case_metadata: IndividualCase @@ -75,8 +67,7 @@ def build_case_operators( """Build a CaseOperator from the case metadata and metric evaluation objects. Args: - cases: The case metadata to use for the case operators as a dictionary of cases - or a list of IndividualCases. + case_list: List of IndividualCase objects defining cases to process. evaluation_objects: The evaluation objects to apply to the case operators. Returns: @@ -108,7 +99,7 @@ def load_individual_cases( Will pass through existing IndividualCase objects and convert dictionaries to IndividualCase objects. Args: - cases: A dictionary of cases based on the IndividualCase dataclass. + cases: A list of cases as either dicts or IndividualCase objects. Returns: A list of IndividualCase objects. 
@@ -146,19 +137,18 @@ def load_individual_cases_from_yaml( Example of a yaml file: ```yaml - cases: - - case_id_number: 1 - title: Event 1 - start_date: 2021-01-01 00:00:00 - end_date: 2021-01-03 00:00:00 - location: - type: bounded_region - parameters: - latitude_min: 10.0 - latitude_max: 55.6 - longitude_min: 265.0 - longitude_max: 283.3 - event_type: tropical_cyclone + - case_id_number: 1 + title: Event 1 + start_date: 2021-01-01 00:00:00 + end_date: 2021-01-03 00:00:00 + location: + type: bounded_region + parameters: + latitude_min: 10.0 + latitude_max: 55.6 + longitude_min: 265.0 + longitude_max: 283.3 + event_type: tropical_cyclone ``` Args: diff --git a/src/extremeweatherbench/derived.py b/src/extremeweatherbench/derived.py index 94c97a8d..0e609fbe 100644 --- a/src/extremeweatherbench/derived.py +++ b/src/extremeweatherbench/derived.py @@ -16,21 +16,26 @@ class DerivedVariable(abc.ABC): - """An abstract base class defining the interface for ExtremeWeatherBench - derived variables. - - A DerivedVariable is any variable or transform that requires extra computation than - what is provided in analysis or forecast data. Some examples include the - practically perfect hindcast, MLCAPE, IVT, or atmospheric river masks. - - Attributes: - variables: A list of variables that are used to build the - derived variable. - output_variables: Optional list of variable names that specify - which outputs to use from the derived computation. - compute: A method that generates the derived variable from the variables. - derive_variable: An abstract method that defines the computation to - derive the derived_variable from variables. + """Abstract base class for ExtremeWeatherBench derived variables. + + A DerivedVariable is any variable or transform that requires extra + computation beyond what is provided in analysis or forecast data. Examples + include the practically perfect hindcast, MLCAPE, IVT, or atmospheric + river masks. 
+ + Class attributes: + variables: List of variables used to build the derived variable + + Instance attributes: + name: The name of the derived variable + output_variables: Optional list of variable names specifying which + outputs to use from the derived computation + + Public methods: + compute: Build the derived variable from input variables + + Abstract methods: + derive_variable: Define the computation to derive the variable """ variables: List[str] @@ -81,33 +86,28 @@ def compute(self, data: xr.Dataset, *args, **kwargs) -> xr.DataArray: class TropicalCycloneTrackVariables(DerivedVariable): - """A derived variable abstract class for tropical cyclone (TC) variables. - - This class serves as a parent for TC-related derived variables and provides - shared track computation with caching to avoid reprocessing the same data - multiple times across different child classes. + """Derived variable class for tropical cyclone track-based variables. - The track data is computed once and cached, then child classes can extract - specific variables (like sea level pressure, wind speed) from the cached - track dataset. + Extends DerivedVariable to provide shared track computation with caching + for TC-related derived variables, avoiding reprocessing across child + classes. Track data is computed once and cached, then child classes can + extract specific variables (sea level pressure, wind speed, etc.). - Deriving the track locations using default TempestExtremes criteria: + Uses default TempestExtremes criteria for track identification: https://doi.org/10.5194/gmd-14-5023-2021 - For forecast data, when track data is provided, the valid candidates - approach is filtered to only include candidates within 5 great circle - degrees of track data points and within 48 hours of the valid_time. + For forecasts with track data, valid candidates are filtered to include + only those within 5 great circle degrees and 48 hours of track points. 
- Track data is automatically obtained from the target dataset when using - the evaluation pipeline (via `requires_target_dataset=True` flag). + Track data is automatically obtained from target dataset via + `requires_target_dataset=True` flag in evaluation pipeline. - Attributes: - output_variables: Optional list of variable names that specify - which outputs to use from the derived computation. - name: The name of the derived variable. Defaults to class-level - name attribute if present, otherwise the class name. - requires_target_dataset: If True, target dataset will be passed to - this derived variable via kwargs. + Class attributes: + requires_target_dataset: If True, target dataset passed via kwargs + + Instance attributes: + output_variables: Optional list specifying which outputs to use + name: Name of the derived variable """ # required variables for TC track identification @@ -287,8 +287,10 @@ def derive_variable(self, data: xr.Dataset, *args, **kwargs) -> xr.DataArray: class CravenBrooksSignificantSevere(DerivedVariable): - """A derived variable that computes the Craven-Brooks significant severe - convection index. + """Derived variable for Craven-Brooks significant severe convection index. + + Extends DerivedVariable to compute the Craven-Brooks index for assessing + significant severe convection potential. """ variables = [ @@ -391,18 +393,18 @@ def derive_variable( class AtmosphericRiverVariables(DerivedVariable): - """A derived variable that computes atmospheric river related variables. + """Derived variable for atmospheric river detection and characterization. - Calculates the IVT (Integrated Vapor Transport), atmospheric river mask, and land - intersection. IVT is calculated using the method described in Newell et al. 1992 and - elsewhere (e.g. Mo 2024). + Extends DerivedVariable to compute IVT (Integrated Vapor Transport), + atmospheric river mask, and land intersection. IVT calculation follows + Newell et al. 1992 and elsewhere (e.g. Mo 2024). 
- Output variables are: integrated_vapor_transport, atmospheric_river_mask, and - atmospheric_river_land_intersection. Users must declare at least one of the output - variables they want when calling the derived variable. + Output variables: integrated_vapor_transport, atmospheric_river_mask, + atmospheric_river_land_intersection. Users must declare at least one + output variable when calling the derived variable. - The Laplacian of IVT is calculated using a Gaussian blurring kernel with a - sigma of 3 grid points, meant to smooth out 0.25 degree grid scale features. + The Laplacian of IVT uses a Gaussian blurring kernel with sigma of 3 + grid points to smooth 0.25 degree grid scale features. """ variables = [ diff --git a/src/extremeweatherbench/evaluate.py b/src/extremeweatherbench/evaluate.py index ad7953eb..64c0ba96 100644 --- a/src/extremeweatherbench/evaluate.py +++ b/src/extremeweatherbench/evaluate.py @@ -45,7 +45,7 @@ class ExtremeWeatherBench: results. Attributes: - case_metadata: A dictionary of cases or a list of IndividualCase objects to run. + case_metadata: A list of case dicts or IndividualCase objects to run. evaluation_objects: A list of evaluation objects to run. cache_dir: An optional directory to cache the mid-flight outputs of the workflow for serial runs. @@ -61,6 +61,16 @@ def __init__( cache_dir: Optional[Union[str, pathlib.Path]] = None, region_subsetter: Optional["regions.RegionSubsetter"] = None, ): + """Initialize the ExtremeWeatherBench workflow. + + Args: + case_metadata: List of case dicts or IndividualCase objects. + evaluation_objects: List of evaluation objects to run. + cache_dir: Optional directory for caching mid-flight outputs in + serial runs. + region_subsetter: Optional RegionSubsetter to filter cases by + spatial region. 
+ """ # Load the case metadata from the input self.case_metadata = cases.load_individual_cases(case_metadata) self.evaluation_objects = evaluation_objects @@ -138,7 +148,8 @@ def _run_case_operators( Args: case_operators: List of case operators to run. cache_dir: Optional directory for caching (serial mode only). - **kwargs: Additional arguments, may include 'parallel_config' dict. + parallel_config: Optional dict of joblib parallel configuration. + **kwargs: Additional keyword arguments passed to case operators. Returns: List of result DataFrames. diff --git a/src/extremeweatherbench/evaluate_cli.py b/src/extremeweatherbench/evaluate_cli.py index 3216a1ef..9b978269 100644 --- a/src/extremeweatherbench/evaluate_cli.py +++ b/src/extremeweatherbench/evaluate_cli.py @@ -79,18 +79,17 @@ def cli_runner( save CaseOperator objects for later use or inspection. Args: - default: Use default Brightband evaluation objects with current directory as - output - config_file: Path to a config.py file containing evaluation objects - output_dir: Directory for analysis outputs (default: current directory) + default: Use default Brightband evaluation objects with current directory + as output. + config_file: Path to a config.py file containing evaluation objects. + output_dir: Directory for analysis outputs (default: current directory). cache_dir: Optional directory for caching intermediate data. When set, datasets or dataarrays are computed and cached as zarrs. - parallel_config: Parallel configuration using joblib (default: {'backend': - 'threading', 'n_jobs': 8}) - save_case_operators: Save CaseOperator objects to a pickle file at this path - n_jobs: Number of parallel jobs to run (default: 1 for serial execution) - parallel_config: Advanced parallel configuration using joblib. Takes precedence - over --n-jobs if provided. + n_jobs: Number of parallel jobs to run (default: 1 for serial execution). + parallel_config: Advanced parallel configuration using joblib. 
Takes + precedence over n_jobs if provided. + save_case_operators: Save CaseOperator objects to a pickle file at this + path. Examples: # Use default evaluation objects $ ewb --default diff --git a/src/extremeweatherbench/inputs.py b/src/extremeweatherbench/inputs.py index 86f99ff0..cd0f52cd 100644 --- a/src/extremeweatherbench/inputs.py +++ b/src/extremeweatherbench/inputs.py @@ -153,15 +153,29 @@ def _default_preprocess(input_data: IncomingDataInput) -> IncomingDataInput: @dataclasses.dataclass class InputBase(abc.ABC): - """An abstract base dataclass for target and forecast data. + """Abstract base dataclass for target and forecast data. + + This class provides the foundational interface for loading and processing + forecast and target datasets in ExtremeWeatherBench. Attributes: - source: The source of the data, which can be a local path or a remote URL/URI. + source: The source of the data, which can be a local path or a + remote URL/URI. name: The name of the input data source. variables: A list of variables to select from the data. variable_mapping: A dictionary of variable names to map to the data. storage_options: Storage/access options for the data. preprocess: A function to preprocess the data. + + Public methods: + open_and_maybe_preprocess_data_from_source: Open and preprocess data + maybe_convert_to_dataset: Convert input data to xarray Dataset + add_source_to_dataset_attrs: Add source name to dataset attributes + maybe_map_variable_names: Map variable names if mapping provided + + Abstract methods: + _open_data_from_source: Open the input data from source + subset_data_to_case: Subset data to case metadata """ source: str @@ -304,7 +318,17 @@ def maybe_map_variable_names(self, data: IncomingDataInput) -> IncomingDataInput @dataclasses.dataclass class ForecastBase(InputBase): - """A class defining the interface for ExtremeWeatherBench forecast data.""" + """Forecast data interface for ExtremeWeatherBench. 
+ + Extends InputBase to provide functionality for forecast datasets with + init_time and lead_time dimensions. + + Attributes: + chunks: Chunking strategy for dask arrays. Defaults to "auto". + + Public methods: + subset_data_to_case: Subset forecast data to case (overrides parent) + """ chunks: Optional[Union[dict, str]] = "auto" @@ -401,7 +425,11 @@ class EvaluationObject: @dataclasses.dataclass class KerchunkForecast(ForecastBase): - """Forecast class for kerchunked forecast data.""" + """Forecast class for kerchunk-referenced forecast data. + + Extends ForecastBase for forecast data accessed via kerchunk references, + enabling efficient access to cloud-optimized datasets. + """ chunks: Optional[Union[dict, str]] = "auto" storage_options: dict = dataclasses.field(default_factory=dict) @@ -416,7 +444,10 @@ def _open_data_from_source(self) -> IncomingDataInput: @dataclasses.dataclass class ZarrForecast(ForecastBase): - """Forecast class for zarr forecast data.""" + """Forecast class for zarr-format forecast data. + + Extends ForecastBase for forecast data stored in zarr format. + """ chunks: Optional[Union[dict, str]] = "auto" @@ -431,11 +462,11 @@ def _open_data_from_source(self) -> IncomingDataInput: @dataclasses.dataclass class XarrayForecast(ForecastBase): - """Forecast class for datasets that were previously constructed and opened using xarray. + """Forecast class for pre-opened xarray datasets. - This class is intended for situations where the user has to manually prepare a dataset to - use in their evaluation. This can happen when the user is manually constructed such a - dataset from a collection of NetCDF or Zarr archives which need to be assembled into a + Extends ForecastBase for datasets previously constructed and opened using + xarray. Intended for situations where users manually prepare datasets from + collections of NetCDF or Zarr archives that need assembly into a single, master dataset. 
Attributes: @@ -480,12 +511,15 @@ def _open_data_from_source(self) -> xr.Dataset: @dataclasses.dataclass class TargetBase(InputBase): - """An abstract base class for target data. + """Target (truth) data interface for ExtremeWeatherBench. + + Extends InputBase to provide functionality for target datasets that serve + as ground truth for evaluation. Target data can be gridded datasets, point + observations, or any reference dataset. Targets need not match forecast + variables but must share a compatible coordinate system for evaluation. - A TargetBase is data that acts as the "truth" for a case. It can be a gridded - dataset, a point observation dataset, or any other reference dataset. Targets in EWB - are not required to be the same variable as the forecast dataset, but they must be - in the same coordinate system for evaluation. + Public methods: + maybe_align_forecast_to_target: Align forecast to target coordinates """ def maybe_align_forecast_to_target( @@ -513,8 +547,10 @@ def maybe_align_forecast_to_target( @dataclasses.dataclass class ERA5(TargetBase): - """Target class for ERA5 gridded data, ideally using the ARCO ERA5 dataset provided - by Google. Otherwise, either a different zarr source for ERA5. + """Target class for ERA5 gridded reanalysis data. + + Extends TargetBase for ERA5 data, optimized for the ARCO ERA5 dataset + provided by Google or other zarr-based ERA5 sources. """ name: str = "ERA5" @@ -569,10 +605,10 @@ def maybe_align_forecast_to_target( @dataclasses.dataclass class GHCN(TargetBase): - """Target class for GHCN tabular data. + """Target class for GHCN (Global Historical Climatology Network) data. - Data is processed using polars to maintain the lazy loading paradigm in - open_data_from_source and to separate the subsetting into subset_data_to_case. + Extends TargetBase for tabular GHCN station observation data. Uses polars + for lazy loading and efficient subsetting of large tabular datasets. 
""" name: str = "GHCN" @@ -643,10 +679,11 @@ def maybe_align_forecast_to_target( @dataclasses.dataclass class LSR(TargetBase): - """Target class for local storm report (LSR) tabular data. + """Target class for Local Storm Report (LSR) tabular data. - run_pipeline() returns a dataset with LSRs as mapped to numeric values (1=wind, 2=hail, 3=tor). IndividualCase date ranges for LSRs should be 12 UTC to - the next day at 12 UTC (exclusive) to match SPC's reporting window. + Extends TargetBase for SPC local storm reports. Returns dataset with LSRs + mapped to numeric values (1=wind, 2=hail, 3=tornado). IndividualCase date + ranges should be 12 UTC to next day 12 UTC to match SPC reporting window. """ name: str = "local_storm_reports" @@ -745,7 +782,10 @@ def maybe_align_forecast_to_target( # TODO: get PPH connector working properly @dataclasses.dataclass class PPH(TargetBase): - """Target class for practically perfect hindcast data.""" + """Target class for Practically Perfect Hindcast (PPH) data. + + Extends TargetBase for practically perfect hindcast datasets. + """ name: str = "practically_perfect_hindcast" source: str = PPH_URI @@ -877,7 +917,11 @@ def _ibtracs_preprocess(data: IncomingDataInput) -> IncomingDataInput: @dataclasses.dataclass class IBTrACS(TargetBase): - """Target class for IBTrACS data.""" + """Target class for IBTrACS tropical cyclone best track data. + + Extends TargetBase for International Best Track Archive for Climate + Stewardship (IBTrACS) tropical cyclone track and intensity data. + """ name: str = "IBTrACS" preprocess: Callable = _ibtracs_preprocess @@ -1062,6 +1106,7 @@ def open_icechunk_dataset_from_datatree( group: The group within the datatree to open. branch: The icechunk branch to open. Defaults to "main". chunks: The chunk pattern for the datatree. defaults to "auto". + Returns: The dataset for the specified group. """ @@ -1086,6 +1131,7 @@ def zarr_target_subsetter( data: The dataset to subset. 
case_metadata: The case metadata to subset the dataset to. time_variable: The time variable to use; defaults to "valid_time". + drop: Whether to drop masked values. Defaults to False. Returns: The subset dataset. diff --git a/src/extremeweatherbench/metrics.py b/src/extremeweatherbench/metrics.py index 21bed2b5..873009af 100644 --- a/src/extremeweatherbench/metrics.py +++ b/src/extremeweatherbench/metrics.py @@ -50,19 +50,22 @@ def _compute_metric_with_docstring(self, *args, **kwargs): class BaseMetric(abc.ABC, metaclass=ComputeDocstringMetaclass): - """A BaseMetric class is an abstract class that defines the foundational interface - for all metrics. - - Metrics are general operations applied between a forecast and analysis xarray - DataArray. EWB metrics prioritize the use of any arbitrary sets of forecasts and - analyses, so long as the spatiotemporal dimensions are the same. - - Args: - name: The name of the metric. - preserve_dims: The dimensions to preserve in the computation. Defaults to - "lead_time". - forecast_variable: The forecast variable to use in the computation. - target_variable: The target variable to use in the computation. + """Abstract base class defining the foundational interface for all metrics. + + Metrics are general operations applied between forecast and analysis xarray + DataArrays. EWB metrics prioritize the use of any arbitrary sets of + forecasts and analyses, so long as the spatiotemporal dimensions are the + same. 
+ + Public methods: + compute_metric: Public interface to compute the metric + maybe_expand_composite: Expand composite metrics into individual metrics + is_composite: Check if this is a composite metric + __repr__: String representation of the metric + __eq__: Check equality with another metric + + Abstract methods: + _compute_metric: Logic to compute the metric (must be implemented) """ def __init__( @@ -72,6 +75,16 @@ def __init__( forecast_variable: Optional[str | derived.DerivedVariable] = None, target_variable: Optional[str | derived.DerivedVariable] = None, ): + """Initialize the base metric. + + Args: + name: The name of the metric. + preserve_dims: The dimensions to preserve in the computation. + Defaults to "lead_time". + forecast_variable: The forecast variable to use in the + computation. + target_variable: The target variable to use in the computation. + """ # Store the original variables (str or DerivedVariable instances) # Do NOT convert to string to preserve output_variables info self.name = name @@ -179,13 +192,27 @@ def maybe_prepare_composite_kwargs( class CompositeMetric(BaseMetric): - """Base class for composite metrics. + """Base class for composite metrics that can contain multiple sub-metrics. + + Extends BaseMetric to provide functionality for composite metrics that + aggregate multiple individual metrics for efficient evaluation. - This class provides common functionality for composite metrics. - Accepts the same arguments as BaseMetric. + Public methods: + maybe_expand_composite: Expand into individual metrics (overrides base) + is_composite: Check if has sub-metrics (overrides base) + + Abstract methods: + maybe_prepare_composite_kwargs: Prepare kwargs for composite evaluation + _compute_metric: Compute the metric (must be implemented by subclasses) """ def __init__(self, *args, **kwargs): + """Initialize the composite metric. 
+ + Args: + *args: Positional arguments passed to BaseMetric.__init__ + **kwargs: Keyword arguments passed to BaseMetric.__init__ + """ super().__init__(*args, **kwargs) self._metric_instances: list["BaseMetric"] = [] @@ -242,36 +269,31 @@ def _compute_metric( class ThresholdMetric(CompositeMetric): - """Base class for threshold-based metrics. - - This class provides common functionality for metrics that require - forecast and target thresholds for binarization. - - Args: - name: The name of the metric. Defaults to "threshold_metrics". - preserve_dims: The dimensions to preserve in the computation. Defaults to - "lead_time". - forecast_variable: The forecast variable to use in the computation. - target_variable: The target variable to use in the computation. - forecast_threshold: The threshold for binarizing the forecast. Defaults to 0.5. - target_threshold: The threshold for binarizing the target. Defaults to 0.5. - metrics: A list of metrics to use as a composite. Defaults to None. - - Can be used in two ways: - 1. As a base class for specific threshold metrics (CriticalSuccessIndex, - FalseAlarmRatio, etc.) - 2. As a composite metric to compute multiple threshold metrics - efficiently by reusing the transformed contingency manager. - - Example of composite usage: + """Base class for threshold-based metrics with binary classification. + + Extends CompositeMetric to provide functionality for metrics that require + forecast and target thresholds for binarization. Can be used as a base + class for specific threshold metrics or as a composite metric. + + Public methods: + transformed_contingency_manager: Create contingency manager + maybe_prepare_composite_kwargs: Prepare kwargs (overrides parent) + __call__: Make instances callable with configured thresholds + + Abstract methods: + _compute_metric: Compute the metric (must be implemented by subclasses) + + Usage patterns: + 1. As a base class for specific metrics (CriticalSuccessIndex, etc.) + 2. 
As a composite metric to compute multiple threshold metrics + efficiently by reusing the transformed contingency manager + + Example: composite = ThresholdMetric( metrics=[CriticalSuccessIndex, FalseAlarmRatio, Accuracy], forecast_threshold=0.7, target_threshold=0.5 ) - results = composite.compute_metric(forecast, target) - # Returns: {"critical_success_index": ..., - # "false_alarm_ratio": ..., "accuracy": ...} """ def __init__( @@ -285,6 +307,23 @@ def __init__( metrics: Optional[list[Type["ThresholdMetric"]]] = None, **kwargs, ): + """Initialize the threshold metric. + + Args: + name: The name of the metric. Defaults to "threshold_metrics". + preserve_dims: The dimensions to preserve in the computation. + Defaults to "lead_time". + forecast_variable: The forecast variable to use in the + computation. + target_variable: The target variable to use in the computation. + forecast_threshold: The threshold for binarizing the forecast. + Defaults to 0.5. + target_threshold: The threshold for binarizing the target. + Defaults to 0.5. + metrics: A list of metrics to use as a composite. Defaults to + None. + **kwargs: Additional keyword arguments passed to parent. + """ super().__init__( name, preserve_dims=preserve_dims, @@ -430,16 +469,22 @@ def _compute_metric( class CriticalSuccessIndex(ThresholdMetric): - """Critical Success Index metric. + """Compute Critical Success Index (CSI) from binary classifications. - The Critical Success Index is computed between the forecast and target using the - preserve_dims dimensions. - - Args: - name: The name of the metric. Defaults to "CriticalSuccessIndex". + Extends ThresholdMetric to compute CSI between forecast and target using + the preserve_dims dimensions. CSI measures the fraction of correctly + predicted events. """ def __init__(self, name: str = "CriticalSuccessIndex", *args, **kwargs): + """Initialize the Critical Success Index metric. + + Args: + name: The name of the metric. Defaults to + "CriticalSuccessIndex". 
+ *args: Additional positional arguments passed to ThresholdMetric. + **kwargs: Additional keyword arguments passed to ThresholdMetric. + """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -462,16 +507,21 @@ def _compute_metric( class FalseAlarmRatio(ThresholdMetric): - """False Alarm Ratio metric. - - The False Alarm Ratio is computed between the forecast and target using the - preserve_dims dimensions. Note that this is not the same as the False Alarm Rate. + """Compute False Alarm Ratio (FAR) from binary classifications. - Args: - name: The name of the metric. Defaults to "FalseAlarmRatio". + Extends ThresholdMetric to compute FAR between forecast and target using + the preserve_dims dimensions. FAR measures the fraction of predicted + events that did not occur. Note: FAR is not the same as False Alarm Rate. """ def __init__(self, name: str = "FalseAlarmRatio", *args, **kwargs): + """Initialize the False Alarm Ratio metric. + + Args: + name: The name of the metric. Defaults to "FalseAlarmRatio". + *args: Additional positional arguments passed to ThresholdMetric. + **kwargs: Additional keyword arguments passed to ThresholdMetric. + """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -494,16 +544,21 @@ def _compute_metric( class TruePositives(ThresholdMetric): - """True Positive ratio. - - The True Positive is the number of times the forecast is a true positive (top right - cell in the contingency table) divided by the total number of observations. + """Compute True Positive ratio from binary classifications. - Args: - name: The name of the metric. Defaults to "TruePositives". + Extends ThresholdMetric to compute the ratio of true positives (correctly + predicted events) to the total number of observations. Corresponds to the + top right cell in the contingency table. """ def __init__(self, name: str = "TruePositives", *args, **kwargs): + """Initialize the True Positives metric. + + Args: + name: The name of the metric. 
Defaults to "TruePositives". + *args: Additional positional arguments passed to ThresholdMetric. + **kwargs: Additional keyword arguments passed to ThresholdMetric. + """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -527,16 +582,20 @@ def _compute_metric( class FalsePositives(ThresholdMetric): - """False Positive ratio. + """Compute False Positive ratio from binary classifications. - The False Positive is the number of times the forecast is a false positive divided - by the total number of observations. - - Args: - name: The name of the metric. Defaults to "FalsePositives". + Extends ThresholdMetric to compute the ratio of false positives + (incorrectly predicted events) to the total number of observations. """ def __init__(self, name: str = "FalsePositives", *args, **kwargs): + """Initialize the False Positives metric. + + Args: + name: The name of the metric. Defaults to "FalsePositives". + *args: Additional positional arguments passed to ThresholdMetric. + **kwargs: Additional keyword arguments passed to ThresholdMetric. + """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -560,16 +619,20 @@ def _compute_metric( class TrueNegatives(ThresholdMetric): - """True Negative ratio. - - The True Negative is the number of times the forecast is a true negative divided by - the total number of observations. + """Compute True Negative ratio from binary classifications. - Args: - name: The name of the metric. Defaults to "TrueNegatives". + Extends ThresholdMetric to compute the ratio of true negatives (correctly + predicted non-events) to the total number of observations. """ def __init__(self, name: str = "TrueNegatives", *args, **kwargs): + """Initialize the True Negatives metric. + + Args: + name: The name of the metric. Defaults to "TrueNegatives". + *args: Additional positional arguments passed to ThresholdMetric. + **kwargs: Additional keyword arguments passed to ThresholdMetric. 
+ """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -593,16 +656,21 @@ def _compute_metric( class FalseNegatives(ThresholdMetric): - """False Negative ratio. - - The False Negative is the number of times the forecast is a false negative (top left - cell in the contingency table) divided by the total number of observations. + """Compute False Negative ratio from binary classifications. - Args: - name: The name of the metric. Defaults to "FalseNegatives". + Extends ThresholdMetric to compute the ratio of false negatives (missed + events) to the total number of observations. Corresponds to the top left + cell in the contingency table. """ def __init__(self, name: str = "FalseNegatives", *args, **kwargs): + """Initialize the False Negatives metric. + + Args: + name: The name of the metric. Defaults to "FalseNegatives". + *args: Additional positional arguments passed to ThresholdMetric. + **kwargs: Additional keyword arguments passed to ThresholdMetric. + """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -626,17 +694,21 @@ def _compute_metric( class Accuracy(ThresholdMetric): - """Accuracy metric. + """Compute classification accuracy from binary classifications. - The Accuracy is the number of times the forecast is correct (top right or bottom - right cell in the contingency table) divided by the total number of observations, or - (true positives + true negatives) / (total number of samples). - - Args: - name: The name of the metric. Defaults to "Accuracy". + Extends ThresholdMetric to compute the ratio of correct predictions (true + positives + true negatives) to the total number of observations. Measures + overall correctness of the forecast. """ def __init__(self, name: str = "Accuracy", *args, **kwargs): + """Initialize the Accuracy metric. + + Args: + name: The name of the metric. Defaults to "Accuracy". + *args: Additional positional arguments passed to ThresholdMetric. 
+ **kwargs: Additional keyword arguments passed to ThresholdMetric. + """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -659,19 +731,10 @@ def _compute_metric( class MeanSquaredError(BaseMetric): - """Mean Squared Error metric. - - Args: - name: The name of the metric. Defaults to "MeanSquaredError". - interval_where_one: From scores, endpoints of the interval where the threshold - weights are 1. Must be increasing. Infinite endpoints are permissible. By - supplying a tuple of arrays, endpoints can vary with dimension. - interval_where_positive: From scores, endpoints of the interval where the - threshold weights are positive. Must be increasing. Infinite endpoints are - only permissible when the corresponding interval_where_one endpoint is - infinite. By supplying a tuple of arrays, endpoints can vary with dimension. - weights: From scores, an array of weights to apply to the score (e.g., weighting - a grid by latitude). If None, no weights are applied. + """Compute Mean Squared Error between forecast and target. + + Extends BaseMetric to calculate MSE with optional interval-based + weighting and custom weights for spatial/temporal averaging. """ def __init__( @@ -687,6 +750,20 @@ def __init__( *args, **kwargs, ): + """Initialize the Mean Squared Error metric. + + Args: + name: The name of the metric. Defaults to "MeanSquaredError". + interval_where_one: Endpoints of the interval where threshold + weights are 1. Must be increasing. Infinite endpoints + permissible. + interval_where_positive: Endpoints of the interval where threshold + weights are positive. Must be increasing. + weights: Array of weights to apply to the score (e.g., latitude + weighting). If None, no weights are applied. + *args: Additional positional arguments passed to BaseMetric. + **kwargs: Additional keyword arguments passed to BaseMetric. 
+ """ super().__init__(name, *args, **kwargs) self.interval_where_one = interval_where_one self.interval_where_positive = interval_where_positive @@ -711,19 +788,10 @@ def _compute_metric( class MeanAbsoluteError(BaseMetric): - """Mean Absolute Error metric. - - Args: - name: The name of the metric. Defaults to "MeanAbsoluteError". - interval_where_one: From scores, endpoints of the interval where the threshold - weights are 1. Must be increasing. Infinite endpoints are permissible. By - supplying a tuple of arrays, endpoints can vary with dimension. - interval_where_positive: From scores, endpoints of the interval where the - threshold weights are positive. Must be increasing. Infinite endpoints are - only permissible when the corresponding interval_where_one endpoint is - infinite. By supplying a tuple of arrays, endpoints can vary with dimension. - weights: From scores, an array of weights to apply to the score (e.g., weighting - a grid by latitude). If None, no weights are applied. + """Compute Mean Absolute Error between forecast and target. + + Extends BaseMetric to calculate MAE with optional interval-based + weighting and custom weights for spatial/temporal averaging. """ def __init__( @@ -739,6 +807,20 @@ def __init__( *args, **kwargs, ): + """Initialize the Mean Absolute Error metric. + + Args: + name: The name of the metric. Defaults to "MeanAbsoluteError". + interval_where_one: Endpoints of the interval where threshold + weights are 1. Must be increasing. Infinite endpoints + permissible. + interval_where_positive: Endpoints of the interval where threshold + weights are positive. Must be increasing. + weights: Array of weights to apply to the score (e.g., latitude + weighting). If None, no weights are applied. + *args: Additional positional arguments passed to BaseMetric. + **kwargs: Additional keyword arguments passed to BaseMetric. 
+ """ self.interval_where_one = interval_where_one self.interval_where_positive = interval_where_positive self.weights = weights @@ -772,16 +854,20 @@ def _compute_metric( class MeanError(BaseMetric): - """Mean Error (bias) metric. + """Compute Mean Error (bias) between forecast and target. - The mean error (or mean bias error) is computed between the forecast and target - using the preserve_dims dimensions. - - Args: - name: The name of the metric. Defaults to "MeanError". + Extends BaseMetric to calculate mean error (bias) using the preserve_dims + dimensions. Positive values indicate forecast exceeds target. """ def __init__(self, name: str = "MeanError", *args, **kwargs): + """Initialize the Mean Error metric. + + Args: + name: The name of the metric. Defaults to "MeanError". + *args: Additional positional arguments passed to BaseMetric. + **kwargs: Additional keyword arguments passed to BaseMetric. + """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -805,16 +891,20 @@ def _compute_metric( class RootMeanSquaredError(BaseMetric): - """Root Mean Square Error metric. - - The Root Mean Square Error is computed between the forecast and target using the - preserve_dims dimensions. + """Compute Root Mean Squared Error between forecast and target. - Args: - name: The name of the metric. Defaults to "RootMeanSquaredError". + Extends BaseMetric to calculate RMSE using the preserve_dims dimensions. + RMSE is the square root of the mean squared error. """ def __init__(self, name: str = "RootMeanSquaredError", *args, **kwargs): + """Initialize the Root Mean Squared Error metric. + + Args: + name: The name of the metric. Defaults to "RootMeanSquaredError". + *args: Additional positional arguments passed to BaseMetric. + **kwargs: Additional keyword arguments passed to BaseMetric. + """ super().__init__(name, *args, **kwargs) def _compute_metric( @@ -838,20 +928,11 @@ def _compute_metric( class EarlySignal(BaseMetric): - """Early Signal detection metric. 
- - This metric finds the first occurrence where a signal is detected based on - threshold criteria and returns the corresponding init_time, lead_time, and - valid_time information. The metric is designed to be flexible for different - signal detection criteria that can be specified in applied metrics downstream. - - Args: - name: The name of the metric. - comparison_operator: The comparison operator to use for signal detection. - threshold: The threshold value for signal detection. - spatial_aggregation: The spatial aggregation method to use for signal detection. - Options are "any" (any gridpoint meets criteria), "all" (all gridpoints - meet criteria), or "half" (at least half of gridpoints meet criteria). + """Detect first occurrence of signal exceeding threshold criteria. + + Extends BaseMetric to find the earliest time when a signal is detected + based on threshold criteria, returning init_time, lead_time, and + valid_time information. Flexible for different signal detection criteria. """ def __init__( @@ -864,6 +945,17 @@ def __init__( spatial_aggregation: Literal["any", "all", "half"] = "any", **kwargs, ): + """Initialize the Early Signal detection metric. + + Args: + name: The name of the metric. Defaults to "EarlySignal". + comparison_operator: The comparison operator for signal detection. + threshold: The threshold value for signal detection. + spatial_aggregation: Spatial aggregation method. Options: "any" + (any gridpoint meets criteria), "all" (all gridpoints meet + criteria), or "half" (at least half meet criteria). + **kwargs: Additional keyword arguments passed to BaseMetric. + """ # Extract threshold params before passing to super self.comparison_operator = utils.maybe_get_operator(comparison_operator) self.threshold = threshold @@ -929,19 +1021,11 @@ def _compute_metric( class MaximumMeanAbsoluteError(MeanAbsoluteError): - """Computes the mean absolute error between the forecast and target maximum values. 
- - The forecast is filtered to a time window around the target's maximum using - tolerance_range_hours (in the event of variation between the timing between the - target and forecast maximum values). The mean absolute error is computed between the - filtered forecast and target maximum value. - - Args: - tolerance_range_hours: The time window (hours) around the target's maximum - value to search for forecast minimum. Defaults to 24 hours. - reduce_spatial_dims: The spatial dimensions to reduce. Defaults to - ["latitude", "longitude"]. - name: The name of the metric. Defaults to "MaximumMeanAbsoluteError". + """Compute MAE between forecast and target maximum values. + + Extends MeanAbsoluteError to filter forecast to a time window around the + target's maximum using tolerance_range_hours. Useful for evaluating peak + value timing and magnitude. """ def __init__( @@ -952,6 +1036,20 @@ def __init__( *args, **kwargs, ): + """Initialize the Maximum Mean Absolute Error metric. + + Args: + tolerance_range_hours: Time window (hours) around target's + maximum to search for forecast maximum. Defaults to 24. + reduce_spatial_dims: Spatial dimensions to reduce. Defaults to + ["latitude", "longitude"]. + name: The name of the metric. Defaults to + "MaximumMeanAbsoluteError". + *args: Additional positional arguments passed to + MeanAbsoluteError. + **kwargs: Additional keyword arguments passed to + MeanAbsoluteError. + """ self.tolerance_range_hours = tolerance_range_hours self.reduce_spatial_dims = reduce_spatial_dims super().__init__(name, *args, **kwargs) @@ -1007,19 +1105,11 @@ def _compute_metric( class MinimumMeanAbsoluteError(MeanAbsoluteError): - """Computes the mean absolute error between the forecast and target minimum values. - - The forecast is filtered to a time window around the target's minimum using - tolerance_range_hours (in the event of variation between the timing between the - target and forecast minimum values). 
The mean absolute error is computed between the - filtered forecast and target minimum value. - - Args: - tolerance_range_hours: The time window (hours) around the target's minimum - value to search for forecast minimum. Defaults to 24 hours. - reduce_spatial_dims: The spatial dimensions to reduce. Defaults to - ["latitude", "longitude"]. - name: The name of the metric. Defaults to "MinimumMeanAbsoluteError". + """Compute MAE between forecast and target minimum values. + + Extends MeanAbsoluteError to filter forecast to a time window around the + target's minimum using tolerance_range_hours. Useful for evaluating + minimum value timing and magnitude. """ def __init__( @@ -1030,6 +1120,20 @@ def __init__( *args, **kwargs, ): + """Initialize the Minimum Mean Absolute Error metric. + + Args: + tolerance_range_hours: Time window (hours) around target's + minimum to search for forecast minimum. Defaults to 24. + reduce_spatial_dims: Spatial dimensions to reduce. Defaults to + ["latitude", "longitude"]. + name: The name of the metric. Defaults to + "MinimumMeanAbsoluteError". + *args: Additional positional arguments passed to + MeanAbsoluteError. + **kwargs: Additional keyword arguments passed to + MeanAbsoluteError. + """ self.tolerance_range_hours = tolerance_range_hours self.reduce_spatial_dims = reduce_spatial_dims super().__init__(name, *args, **kwargs) @@ -1082,16 +1186,11 @@ def _compute_metric( class MaximumLowestMeanAbsoluteError(MeanAbsoluteError): - """Mean Absolute Error of the maximum of aggregated minimum values. - - Meant for heatwave evaluation by aggregating the minimum values over a day and then - computing the MeanAbsoluteError between the warmest nighttime (daily minimum) - temperature in the target and forecast. + """Compute MAE of maximum aggregated minimum values for heatwaves. - Args: - tolerance_range_hours: The time window (hours) around the target's max-min - value to search for forecast max-min. Defaults to 24 hours. 
- name: The name of the metric. Defaults to "MaximumLowestMeanAbsoluteError". + Extends MeanAbsoluteError for heatwave evaluation by aggregating daily + minimum values and computing MAE between the warmest nighttime (daily + minimum) temperature in target and forecast. """ def __init__( @@ -1101,6 +1200,18 @@ def __init__( *args, **kwargs, ): + """Initialize the Maximum Lowest Mean Absolute Error metric. + + Args: + tolerance_range_hours: Time window (hours) around target's + max-min value to search for forecast max-min. Defaults to 24. + name: The name of the metric. Defaults to + "MaximumLowestMeanAbsoluteError". + *args: Additional positional arguments passed to + MeanAbsoluteError. + **kwargs: Additional keyword arguments passed to + MeanAbsoluteError. + """ self.tolerance_range_hours = tolerance_range_hours super().__init__(name, *args, **kwargs) @@ -1184,22 +1295,10 @@ def _compute_metric( class DurationMeanError(MeanError): - """Compute the duration of a case's event. - - This metric computes the mean error between the forecast and target durations. - - Args: - threshold_criteria: The criteria for event detection. Can be either a DataArray - of a climatology with dimensions (dayofyear, hour, latitude, longitude) or a - float value representing a fixed threshold. - reduce_spatial_dims: The spatial dimensions to reduce prior to applying threshold - criteria. Defaults to ["latitude", "longitude"]. - op_func: Comparison operator or string (e.g., operator.ge for >=). - name: Name of the metric. - preserve_dims: Dimensions to preserve during aggregation. Defaults to - "init_time". - product_time_resolution_hours: Whether to product the duration by the time - resolution of the forecast (in hours). Defaults to False. + """Compute mean error of event duration between forecast and target. + + Extends MeanError to compute the mean error between forecast and target + event durations based on threshold criteria and spatial aggregation. 
""" def __init__( @@ -1211,6 +1310,23 @@ def __init__( preserve_dims: str = "init_time", product_time_resolution_hours: bool = False, ): + """Initialize the Duration Mean Error metric. + + Args: + threshold_criteria: Criteria for event detection. Either a + DataArray of climatology with dimensions (dayofyear, hour, + latitude, longitude) or a float fixed threshold. + reduce_spatial_dims: Spatial dimensions to reduce prior to + applying threshold criteria. Defaults to ["latitude", + "longitude"]. + op_func: Comparison operator or string (e.g., operator.ge for + >=). + name: Name of the metric. Defaults to "DurationMeanError". + preserve_dims: Dimensions to preserve during aggregation. + Defaults to "init_time". + product_time_resolution_hours: Whether to multiply duration by + time resolution of forecast (in hours). Defaults to False. + """ super().__init__(name=name, preserve_dims=preserve_dims) self.reduce_spatial_dims = reduce_spatial_dims self.threshold_criteria = threshold_criteria @@ -1307,15 +1423,17 @@ def _compute_metric( class LandfallMetric(CompositeMetric): - """Base class for landfall metrics. + """Base class for tropical cyclone landfall metrics. + + Extends CompositeMetric to compute landfalls using calc.find_landfalls, + which utilizes land geometry and line segments based on track data to + determine intersections. - Landfall metrics compute landfalls using the calc.find_landfalls function, which - utilizes a land geometry and line segments based on track data to determine - intersections. + Can be used as a base class for custom landfall metrics, as a mixin with + other metrics, or as a composite metric for multiple landfall metrics. - Can be used as a base class for custom landfall metrics, as a mixin with other - metrics, or as a composite metric for multiple landfall metrics (which utilize - identical landfalling locations). 
+ Public methods: + maybe_prepare_composite_kwargs: Prepare kwargs for landfall composites """ def __init__( @@ -1521,13 +1639,11 @@ def _compute_metric( class SpatialDisplacement(BaseMetric): - """Spatial displacement error metric for atmospheric rivers and similar events. - - Computes the great circle distance between the center of mass of forecast - and target spatial patterns. + """Compute spatial displacement between forecast and target patterns. - Args: - name: The name of the metric. Defaults to "spatial_displacement". + Extends BaseMetric to compute great circle distance between centers of + mass of forecast and target spatial patterns. Useful for atmospheric + rivers and similar spatial features. """ def __init__( @@ -1535,6 +1651,13 @@ def __init__( name: str = "spatial_displacement", **kwargs: Any, ): + """Initialize the Spatial Displacement metric. + + Args: + name: The name of the metric. Defaults to + "spatial_displacement". + **kwargs: Additional keyword arguments passed to BaseMetric. + """ super().__init__(name, **kwargs) def _compute_metric( @@ -1621,13 +1744,10 @@ def center_of_mass_ufunc(data): class LandfallDisplacement(LandfallMetric): - """Calculate the distance between forecast and target landfall positions. + """Compute distance between forecast and target landfall positions. - This metric computes the distance between the forecast and target - landfall positions, defaulting to kilometers. - - Args: - name: The name of the metric. Defaults to "landfall_displacement". + Extends LandfallMetric to calculate the spatial distance between forecast + and target landfall positions, defaulting to kilometers. """ def __init__( @@ -1636,6 +1756,14 @@ def __init__( *args, **kwargs, ): + """Initialize the Landfall Displacement metric. + + Args: + name: The name of the metric. Defaults to + "landfall_displacement". + *args: Additional positional arguments passed to LandfallMetric. + **kwargs: Additional keyword arguments passed to LandfallMetric. 
+ """ super().__init__(name, *args, **kwargs) self.units = kwargs.get("units", "km") @@ -1717,15 +1845,11 @@ def _compute_metric( class LandfallTimeMeanError(LandfallMetric): - """Landfall time mean error. - - This metric computes the mean error between the forecast and target landfall times. - A positive value indicates the forecast landfall time is later than the target - landfall time, a negative value indicates the forecast landfall time is earlier than - the target landfall time. + """Compute mean error between forecast and target landfall times. - Args: - name: The name of the metric. Defaults to "landfall_time_me". + Extends LandfallMetric to calculate timing difference. Positive values + indicate forecast landfall is later than target; negative values indicate + forecast landfall is earlier than target. """ def __init__( @@ -1734,6 +1858,13 @@ def __init__( *args, **kwargs, ): + """Initialize the Landfall Time Mean Error metric. + + Args: + name: The name of the metric. Defaults to "landfall_time_me". + *args: Additional positional arguments passed to LandfallMetric. + **kwargs: Additional keyword arguments passed to LandfallMetric. + """ super().__init__(name, *args, **kwargs) def calculate_time_difference( @@ -1791,18 +1922,14 @@ def _compute_metric( class LandfallIntensityMeanAbsoluteError(LandfallMetric, MeanAbsoluteError): - """Compute the MeanAbsoluteError between forecast and target. + """Compute MAE of forecast and target intensity at landfall. - This metric computes the mean absolute error between forecast and target - intensity at landfall. + Extends both LandfallMetric and MeanAbsoluteError to calculate mean + absolute error between forecast and target intensity at landfall time. The intensity variable is determined by forecast_variable and - target_variable. To evaluate multiple intensity variables (e.g., - surface_wind_speed and air_pressure_at_mean_sea_level), create - separate metric instances for each variable. 
- - Args: - name: The name of the metric. Defaults to "landfall_intensity_mae". + target_variable. For multiple intensity variables, create separate metric + instances for each variable. """ def __init__( @@ -1811,6 +1938,14 @@ def __init__( *args, **kwargs, ): + """Initialize the Landfall Intensity Mean Absolute Error metric. + + Args: + name: The name of the metric. Defaults to + "landfall_intensity_mae". + *args: Additional positional arguments passed to parent classes. + **kwargs: Additional keyword arguments passed to parent classes. + """ super().__init__(name, *args, **kwargs) def _compute_metric( diff --git a/src/extremeweatherbench/regions.py b/src/extremeweatherbench/regions.py index f2580ccb..5a36dc0e 100644 --- a/src/extremeweatherbench/regions.py +++ b/src/extremeweatherbench/regions.py @@ -22,7 +22,21 @@ class Region(abc.ABC): - """Base class for different region representations.""" + """Base class for different region representations. + + This abstract class defines the interface for geographic regions used in + ExtremeWeatherBench. Regions can be centered, bounding boxes, or defined + by shapefiles. + + Public methods: + create_region: Abstract factory method to create a region + as_geopandas: Convert region to GeoDataFrame representation + get_adjusted_bounds: Get region bounds adjusted to dataset convention + mask: Mask a dataset to this region + intersects: Check if this region intersects another region + contains: Check if this region contains another region + area_overlap_fraction: Calculate area overlap with another region + """ @classmethod @abc.abstractmethod @@ -159,17 +173,11 @@ def area_overlap_fraction(self, other: "Region") -> float: class CenteredRegion(Region): - """A region defined by a center point and a bounding box. - - bounding_box_degrees is the width (length) of one or all sides, not half size; - e.g., bounding_box_degrees=10.0 means a 10 degree by 10 degree box around - the center point. 
+ """Region defined by center point and bounding box. - Attributes: - latitude: Center latitude - longitude: Center longitude - bounding_box_degrees: Size of bounding box in degrees or tuple of - (lat_degrees, lon_degrees) + Extends Region to define a region using a center point and bounding box + dimensions. The bounding_box_degrees is the full width/height (not half + size); e.g., 10.0 means a 10x10 degree box around the center. """ def __repr__(self): @@ -182,6 +190,15 @@ def __repr__(self): def __init__( self, latitude: float, longitude: float, bounding_box_degrees: float | tuple ): + """Initialize the CenteredRegion. + + Args: + latitude: Center latitude in degrees. + longitude: Center longitude in degrees. + bounding_box_degrees: Size of bounding box in degrees. Either a + single float (square box) or tuple of (lat_degrees, + lon_degrees). + """ self.latitude = latitude self.longitude = longitude self.bounding_box_degrees = bounding_box_degrees @@ -229,13 +246,9 @@ def as_geopandas(self) -> gpd.GeoDataFrame: class BoundingBoxRegion(Region): - """A region defined by explicit latitude and longitude bounds. + """Region defined by explicit latitude and longitude bounds. - Attributes: - latitude_min: Minimum latitude bound - latitude_max: Maximum latitude bound - longitude_min: Minimum longitude bound - longitude_max: Maximum longitude bound + Extends Region to define a region using explicit bounding box coordinates. """ def __repr__(self): @@ -253,6 +266,14 @@ def __init__( longitude_min: float, longitude_max: float, ): + """Initialize the BoundingBoxRegion. + + Args: + latitude_min: Minimum latitude bound in degrees. + latitude_max: Maximum latitude bound in degrees. + longitude_min: Minimum longitude bound in degrees. + longitude_max: Maximum longitude bound in degrees. 
+ """ self.latitude_min = latitude_min self.latitude_max = latitude_max self.longitude_min = longitude_min @@ -286,19 +307,21 @@ def as_geopandas(self) -> gpd.GeoDataFrame: class ShapefileRegion(Region): - """A region defined by a shapefile. - - A geopandas object shapefile is read in and stored as an attribute - on instantiation. + """Region defined by a shapefile. - Attributes: - shapefile_path: Local or remote path to the .shp shapefile + Extends Region to define a region using a shapefile. The shapefile is read + using geopandas on instantiation. """ def __repr__(self): return f"{self.__class__.__name__}(shapefile_path={self.shapefile_path})" def __init__(self, shapefile_path: str | pathlib.Path): + """Initialize the ShapefileRegion. + + Args: + shapefile_path: Local or remote path to the .shp shapefile. + """ self.shapefile_path = pathlib.Path(shapefile_path) @classmethod @@ -465,16 +488,17 @@ def _create_geopandas_from_bounds( class RegionSubsetter: - """A utility class for subsetting ExtremeWeatherBench objects by region. + """Utility class for subsetting ExtremeWeatherBench objects by region. - Attributes: - region: The region to subset to. Can be a Region object or a - dictionary of bounds with keys "latitude_min", "latitude_max", - "longitude_min", and "longitude_max". - method: The method to use for subsetting. Options: - - "intersects": Include cases where ANY part of a case intersects region - - "percent": Include cases where percent of case area overlaps with region. - - "all": Only include cases where entirety of a case is within region + Provides methods for filtering case collections based on spatial overlap + with a specified region using various inclusion criteria. 
+ + Public methods: + subset: Subset a case collection based on region overlap + + Instance attributes: + region: The region to subset to + method: The subsetting method used percent_threshold: Threshold for percent overlap (0.0 to 1.0) """ @@ -491,10 +515,10 @@ def __init__( """Initialize the RegionSubsetter. Args: - region: The region to subset to. Can be a Region object or a - dictionary of bounds with keys "latitude_min", "latitude_max", - "longitude_min", and "longitude_max". - method: The method to use for subsetting. Options: + region: The region to subset to. Can be a Region object or + dictionary with keys "latitude_min", "latitude_max", + "longitude_min", "longitude_max". + method: The subsetting method. Options: - "intersects": Include cases where ANY part of a case intersects region - "percent": Include cases where percent of case area overlaps with region diff --git a/src/extremeweatherbench/sources/base.py b/src/extremeweatherbench/sources/base.py index 0abed12c..dd58641c 100644 --- a/src/extremeweatherbench/sources/base.py +++ b/src/extremeweatherbench/sources/base.py @@ -6,7 +6,17 @@ @runtime_checkable class Source(Protocol): - """A protocol for input sources.""" + """Protocol defining the interface for input data sources. + + This protocol specifies the methods that input source implementations must + provide for variable extraction, temporal validation, and spatial data + checking. 
+ + Required methods: + safely_pull_variables: Extract specified variables from data + check_for_valid_times: Check if data has valid times in date range + check_for_spatial_data: Check if data has spatial coverage for region + """ def safely_pull_variables( self, From 79cf15fbb5ff9ac0645200af0b23c6abc3d42619 Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Fri, 23 Jan 2026 21:35:16 -0500 Subject: [PATCH 10/15] add explanation for dim reqs (#320) --- docs/usage.md | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/docs/usage.md b/docs/usage.md index a5e67b8c..5190a704 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -8,11 +8,6 @@ To run the Brightband-based evaluation on an existing AIWP model (FCN v2), which includes the default 337 cases for heat waves, freezes, severe convective days, tropical cyclones, and atmospheric rivers: -```bash -ewb --default -``` - -or: ```python from extremeweatherbench import evaluate, defaults, cases @@ -28,12 +23,21 @@ outputs = ewb.run() outputs.to_csv('your_outputs.csv') ``` +or: + +```bash +ewb --default +``` ## Running an Evaluation for a Single Event Type ExtremeWeatherBench has default event types and cases for heat waves, freezes, severe convection, tropical cyclones, and atmospheric rivers. To run an evaluation, there are three components required: a forecast, a target, and an evaluation object. +ExtremeWeatherBench requires forecasts to have `init_time`, `lead_time`, `latitude`, and `longitude` dimensions at minimum. If not already in that naming convention, initializing a `ForecastBase` object with a `variable_mapping` to map to those names is required. Other dimensions such as pressure level (`level`) can be included. + +Targets require at least a `valid_time` with at least one spatial dimension. Examples include `location`, `station`, or (`latitude`, `longitude`). Forecasts are aligned to targets during the steps immediately prior to evaluating a metric. 
+ ```python from extremeweatherbench import inputs ``` From d4c8cdedbd27d6dfd7b8438a77662e4448f34f9d Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Fri, 23 Jan 2026 23:07:22 -0500 Subject: [PATCH 11/15] Update `defaults` and `inputs` to include new CIRA icechunk store (#319) * more explicit naming, add func and model names var * add test coverage, ruff, linting * update readme for new cira approach * move cira func and model ref to inputs * update docs * module wasnt called for moved func * update tests for moving func and var * ruff * fix mock typos --- README.md | 49 +---- docs/recipes/cira_forecast.md | 39 ++-- docs/usage.md | 28 ++- src/extremeweatherbench/defaults.py | 140 ++++++++------- src/extremeweatherbench/inputs.py | 59 ++++++ tests/test_defaults.py | 95 ++++++++-- tests/test_inputs.py | 267 +++++++++++++++++++++++++++- 7 files changed, 520 insertions(+), 157 deletions(-) diff --git a/README.md b/README.md index f0249f50..d90aff16 100644 --- a/README.md +++ b/README.md @@ -67,48 +67,11 @@ $ ewb --default ```python from extremeweatherbench import cases, inputs, metrics, evaluate, utils -# Select model -model = 'FOUR_v200_GFS' - -# Set up path to directory of file - zarr or kerchunk/virtualizarr json/parquet -forecast_dir = f'gs://extremeweatherbench/{model}.parq' - -# Preprocessing function exclusive to handling the CIRA parquets -def preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: - """Preprocess CIRA kerchunk (parquet) data in the ExtremeWeatherBench bucket. - A preprocess function that renames the time coordinate to lead_time, - creates a valid_time coordinate, and sets the lead time range and resolution not - present in the original dataset. - Args: - ds: The forecast dataset to rename. - Returns: - The renamed forecast dataset. - """ - ds = ds.rename({"time": "lead_time"}) - - # The evaluation configuration is used to set the lead time range and resolution. 
- ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") - - return ds - -# Define a forecast object; in this case, a KerchunkForecast -fcnv2_forecast = inputs.KerchunkForecast( - name="fcnv2_forecast", # identifier for this forecast in results - source=forecast_dir, # source path - variables=["surface_air_temperature"], # variables to use in the evaluation - variable_mapping=inputs.CIRA_metadata_variable_mapping, # mapping to use for variables in forecast dataset to EWB variable names - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, # storage options for access - preprocess=preprocess_bb_cira_forecast_dataset # required preprocessing function for CIRA references -) +# Load in a forecast; here, we load in GFS initialized FCNv2 from the CIRA MLWP archive with a default variable built-in for convenience +fcnv2_heatwave_forecast = defaults.cira_fcnv2_heatwave_forecast -# Load in ERA5; source defaults to the ARCO ERA5 dataset from Google and variable mapping is provided by default as well -era5_heatwave_target = inputs.ERA5( - variables=["surface_air_temperature"], # variable to use in the evaluation - storage_options={"remote_options": {"anon": True}}, # storage options for access - chunks=None, # define chunks for the ERA5 data -) +# Load in ERA5 with another default convenience variable +era5_heatwave_target = defaults.era5_heatwave_target # EvaluationObjects are used to evaluate a single forecast source against a single target source with a defined event type. Event types are declared with each case. One or more metrics can be evaluated with each EvaluationObject. 
heatwave_evaluation_list = [
@@ -120,7 +83,7 @@ heatwave_evaluation_list = [
         metrics.MaximumLowestMeanAbsoluteError(),
         ],
         target=era5_heatwave_target,
-        forecast=fcnv2_forecast,
+        forecast=fcnv2_heatwave_forecast,
     ),
 ]
 # Load in the EWB default list of event cases
@@ -134,7 +97,7 @@ ewb_instance = evaluate.ExtremeWeatherBench(
 
 # Execute a parallel run and return the evaluation results as a pandas DataFrame
 heatwave_outputs = ewb_instance.run(
-    parallel_config={'backend':'loky','n_jobs':16} # Uses 16 jobs with the loky backend
+    parallel_config={'n_jobs':16} # Uses 16 jobs with the loky backend as default
 )
 
 # Save the results
diff --git a/docs/recipes/cira_forecast.md b/docs/recipes/cira_forecast.md
index 43a7a6ac..9cf57b90 100644
--- a/docs/recipes/cira_forecast.md
+++ b/docs/recipes/cira_forecast.md
@@ -2,22 +2,15 @@
 
-We have a dedicated virtual reference icechunk store for CIRA data **up to May 26th, 2025** available at `gs://extremeweatherbench/cira-icechunk`. Compared to using parquet virtual references, we have seen a speed improvements of around 2x with ~25% more memory usage.
+We have a dedicated virtual reference icechunk store for CIRA data **up to May 26th, 2025** available at `gs://extremeweatherbench/cira-icechunk`. Compared to using parquet virtual references, we have seen speed improvements of around 2x with ~25% more memory usage.
 
-## Loading the store
-
-```python
-
-from extremeweatherbench import cases, inputs, metrics, evaluate, defaults
-import datetime
-import icechunk
-
-storage = icechunk.gcs_storage(
-    bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True
-)
-```
-
 ## Accessing a CIRA Model from the store
 
 ```python
+from extremeweatherbench import inputs
+import icechunk
+
+storage = icechunk.gcs_storage(
+    bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True
+)
 
 group_list = inputs.list_groups_in_icechunk_datatree(storage)
 ```
@@ -39,22 +32,33 @@ group_list = inputs.list_groups_in_icechunk_datatree(storage)
 
 ```python
-# Find FCNv2's name in the group list
-fcnv2_group = [n for n in group_list if 'FOUR_v200_GFS' in n][0]
-
 # Helper function to access the virtual dataset
-fcnv2 = inputs.open_icechunk_dataset_from_datatree(
+fcnv2 = inputs.get_cira_icechunk(model_name='FOUR_v200_IFS')
+```
+
+`fcnv2` is a `ForecastBase` object ready to be used within EWB's evaluation framework.
+
+> **Detailed Explanation**: `inputs.get_cira_icechunk` is syntactic sugar for this:
```python
import icechunk

storage = icechunk.gcs_storage(
    bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True
)

fcnv2_icechunk_ds = inputs.open_icechunk_dataset_from_datatree(
     storage=storage,
-    group=fcnv2_group,
+    group="FOUR_v200_IFS",
     authorize_virtual_chunk_access=inputs.CIRA_CREDENTIALS
 )
-fcnv2_icechunk_forecast_object = inputs.XarrayForecast(
-    ds=fcnv2,
+
+fcnv2 = inputs.XarrayForecast(
+    ds=fcnv2_icechunk_ds,
     variable_mapping=inputs.CIRA_metadata_variable_mapping
 )
 ```

-`fcnv2_icechunk_forecast_object` is a `ForecastBase` object ready to be used within EWB's evaluation framework.
+This is a three-step process: accessing the icechunk storage, loading the dataset from the datatree/zarr group format, and finally wrapping that `Dataset` in a `ForecastBase` object.

## Set up metrics and target for evaluation

diff --git a/docs/usage.md b/docs/usage.md
index 5190a704..d5e37778 100644
--- a/docs/usage.md
+++ b/docs/usage.md
@@ -41,7 +41,7 @@
 ```python
 from extremeweatherbench import inputs
 ```
-There are two built-in `ForecastBase` classes to set up a forecast: `ZarrForecast` and `KerchunkForecast`. Here is an example of a `ZarrForecast`, using Weatherbench2's HRES zarr store:
+There are three built-in `ForecastBase` classes to set up a forecast: `ZarrForecast`, `XarrayForecast`, and `KerchunkForecast`. Here is an example of a `ZarrForecast`, using Weatherbench2's HRES zarr store:
 
 ```python
 hres_forecast = inputs.ZarrForecast(
@@ -60,9 +60,9 @@ There are required arguments, namely:
 - `variables`*
 - `variable_mapping`
 
-* `variables` can be defined within one or more metrics instead of in a `ForecastBase` object.
+* `variables` can alternatively be defined within one or more metrics, instead of in a `ForecastBase` object.
-A forecast needs a `source`, which is a link to the zarr store in this case. A `name` is required to identify the outputs. It also needs `variables` defined, which are based on CF Conventions. A list of variable namings exists in `defaults.py` as `DEFAULT_VARIABLE_NAMES`. Each forecast will likely have different names for their variables, so a `variable_mapping` dictionary is also essential to process the variables, as well as the coordinates and dimensions. EWB uses `lead_time`, `init_time`, and `valid_time` as time coordinates. The HRES data is mapped from `prediction_timedelta` to `lead_time`, as an example. `storage_options` define access patterns for the data if needed. These are passed to the opening function, e.g. `xarray.open_zarr`. +> **Detailed Explanation**: A forecast needs a `source`, which is a link to the zarr store in this case. A `name` is required to identify the outputs. It also needs `variables` defined, which are based on CF Conventions. A list of variable namings exists in `defaults.py` as `DEFAULT_VARIABLE_NAMES`. Each forecast will likely have different names for their variables, so a `variable_mapping` dictionary is also essential to process the variables, as well as the coordinates and dimensions. EWB uses `lead_time`, `init_time`, and `valid_time` as time coordinates. The HRES data is mapped from `prediction_timedelta` to `lead_time`, as an example. `storage_options` define access patterns for the data if needed. These are passed to the opening function, e.g. `xarray.open_zarr`. Next, a target dataset must be defined as well to evaluate against. For this evaluation, we'll use ERA5: @@ -75,7 +75,19 @@ era5_heatwave_target = inputs.ERA5( ) ``` -Similarly to forecasts, we need to define the `source`, which here is the ARCO ERA5 provided by Google. 
`variables` are again required to be set for the `inputs.ERA5` class; `variable_mapping` defaults to `inputs.ERA5_metadata_variable_mapping` for many existing variables and likely is not required to be set unless your use case is for less common variables. Both forecasts and targets, if relevant, have an optional `chunks` parameter which defaults to what should be the most efficient value - usually `None` or `'auto'`, but can be changed as seen above. +Note that EWB provides defaults for arguments, so most users will be able to instead write this (if defining variables with the intent of it applying to all metrics): + +```python +era5_heatwave_target = inputs.ERA5(variables=['surface_air_temperature']) +``` + +Or (if defining variables as arguments to the metrics): + +```python +era5_heatwave_target = inputs.ERA5() +``` + +> **Detailed Explanation**: Similarly to forecasts, we need to define the `source`, which here is the ARCO ERA5 provided by Google. `variables` are used to subset `inputs.ERA5` in an evaluation; `variable_mapping` defaults to `inputs.ERA5_metadata_variable_mapping` for many existing variables and likely is not required to be set unless your use case is for less common variables. Both forecasts and targets, if relevant, have an optional `chunks` parameter which defaults to what should be the most efficient value - usually `None` or `'auto'`, but can be changed as seen above. *If using the ARCO ERA5 and setting `chunks=None`, it is critical to order your subsetting by variables -> time -> `.sel` or `.isel` latitude & longitude -> rechunk. [See this Github comment](https://github.com/pydata/xarray/issues/8902#issuecomment-2036435045). 
We then set up an `EvaluationObject` list:

@@ -102,11 +114,11 @@ Plugging these all in:

```python
from extremeweatherbench import cases, evaluate

-case_yaml = cases.load_ewb_events_yaml_into_case_list()
+case_list = cases.load_ewb_events_yaml_into_case_list()


ewb_instance = evaluate.ExtremeWeatherBench(
-    cases=case_yaml,
+    cases=case_list,
    evaluation_objects=heatwave_evaluation_list,
)

@@ -115,6 +127,8 @@ outputs = ewb_instance.run()
outputs.to_csv('your_file_name.csv')
```

-Where the EWB default events YAML file is loaded in using a built-in utility helper function, then applied to an instance of `evaluate.ExtremeWeatherBench` along with the `EvaluationObject` list. Finally, we run the evaluation with the `.run()` method, where defaults are typically sufficient to run with a small to moderate-sized virtual machine. after subsetting and prior to metric calculation.
+Where the EWB default events YAML file is loaded in using a built-in utility helper function, then applied to an instance of `evaluate.ExtremeWeatherBench` along with the `EvaluationObject` list. Finally, we trigger the evaluation with the `.run()` method, where defaults are typically sufficient to run with a small to moderate-sized virtual machine, after subsetting and prior to metric calculation.
+
+Running locally is feasible but is typically bottlenecked heavily by IO and network bandwidth. Even on a gigabit connection, the rate of data access is significantly slower compared to within a cloud provider VM.

The outputs are returned as a pandas DataFrame and can be manipulated in the script, a notebook, or post-hoc after saving it.
diff --git a/src/extremeweatherbench/defaults.py b/src/extremeweatherbench/defaults.py index a78cdcca..7dcc68d6 100644 --- a/src/extremeweatherbench/defaults.py +++ b/src/extremeweatherbench/defaults.py @@ -58,28 +58,36 @@ ] -def _preprocess_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_cira_forecast_dataset( + ds: xr.Dataset, kerchunk: bool = True +) -> xr.Dataset: """A preprocess function for CIRA data that renames the time coordinate to lead_time, creates a valid_time coordinate, and sets the lead time range and resolution not present in the original dataset. Args: ds: The forecast dataset to preprocess. - + kerchunk: Whether the dataset is a kerchunk reference. Defaults to True. Returns: The preprocessed forecast dataset. """ - ds = ds.rename({"time": "lead_time"}) - # The evaluation configuration is used to set the lead time range and resolution. - ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") + + # If the dataset is a kerchunk, we need to rename the time coordinate to lead_time + # and set the lead time range and resolution. Otherwise, pass through the dataset. + if kerchunk: + ds = ds.rename({"time": "lead_time"}) + # The evaluation configuration is used to set the lead time range and resolution. + ds["lead_time"] = np.array( + [i for i in range(0, 241, 6)], dtype="timedelta64[h]" + ).astype("timedelta64[ns]") return ds # Preprocessing function for CIRA data that includes geopotential thickness calculation # required for tropical cyclone tracks -def _preprocess_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_cira_tc_forecast_dataset( + ds: xr.Dataset, kerchunk: bool = True +) -> xr.Dataset: """A preprocess function for CIRA data that includes geopotential thickness calculation required for tropical cyclone tracks. 
@@ -89,16 +97,18 @@ def _preprocess_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: Args: ds: The forecast dataset to rename. - + kerchunk: Whether the dataset is a kerchunk reference. Defaults to True. Returns: The renamed forecast dataset. """ - ds = ds.rename({"time": "lead_time"}) - - # The evaluation configuration is used to set the lead time range and resolution. - ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") + # If the dataset is a kerchunk, we need to rename the time coordinate to lead_time + # and set the lead time range and resolution. Otherwise, pass through the dataset. + if kerchunk: + ds = ds.rename({"time": "lead_time"}) + # The evaluation configuration is used to set the lead time range and resolution. + ds["lead_time"] = np.array( + [i for i in range(0, 241, 6)], dtype="timedelta64[h]" + ).astype("timedelta64[ns]") # Calculate the geopotential thickness required for tropical cyclone tracks ds["geopotential_thickness"] = ( @@ -133,23 +143,27 @@ def _preprocess_hres_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: # Preprocess function for CIRA data using Brightband kerchunk parquets -def _preprocess_ar_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_cira_ar_forecast_dataset( + ds: xr.Dataset, kerchunk: bool = True +) -> xr.Dataset: """An example preprocess function that renames the time coordinate to lead_time, creates a valid_time coordinate, and sets the lead time range and resolution not present in the original dataset. Args: ds: The forecast dataset to rename. - + kerchunk: Whether the dataset is a kerchunk reference. Defaults to True. Returns: The renamed forecast dataset. """ - ds = ds.rename({"time": "lead_time"}) - - # The evaluation configuration is used to set the lead time range and resolution. 
- ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") + # If the dataset is a kerchunk, we need to rename the time coordinate to lead_time + # and set the lead time range and resolution. Otherwise, pass through the dataset. + if kerchunk: + ds = ds.rename({"time": "lead_time"}) + # The evaluation configuration is used to set the lead time range and resolution. + ds["lead_time"] = np.array( + [i for i in range(0, 241, 6)], dtype="timedelta64[h]" + ).astype("timedelta64[ns]") if "q" not in ds.variables: # Calculate specific humidity from relative humidity and air temperature ds["specific_humidity"] = calc.specific_humidity_from_relative_humidity( @@ -161,23 +175,27 @@ def _preprocess_ar_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: # Preprocess function for CIRA data using Brightband kerchunk parquets -def _preprocess_severe_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_severe_cira_forecast_dataset( + ds: xr.Dataset, kerchunk: bool = True +) -> xr.Dataset: """An example preprocess function that renames the time coordinate to lead_time, creates a valid_time coordinate, and sets the lead time range and resolution not present in the original dataset. Args: ds: The forecast dataset to rename. - + kerchunk: Whether the dataset is a kerchunk reference. Defaults to True. Returns: The renamed forecast dataset. """ - ds = ds.rename({"time": "lead_time"}) - - # The evaluation configuration is used to set the lead time range and resolution. - ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") + # If the dataset is a kerchunk, we need to rename the time coordinate to lead_time + # and set the lead time range and resolution. Otherwise, pass through the dataset. + if kerchunk: + ds = ds.rename({"time": "lead_time"}) + # The evaluation configuration is used to set the lead time range and resolution. 
+ ds["lead_time"] = np.array( + [i for i in range(0, 241, 6)], dtype="timedelta64[h]" + ).astype("timedelta64[ns]") if "q" not in ds.variables: # Calculate specific humidity from relative humidity and air temperature ds["specific_humidity"] = calc.specific_humidity_from_relative_humidity( @@ -243,51 +261,39 @@ def get_climatology(quantile: float = 0.85) -> xr.DataArray: ibtracs_target = inputs.IBTrACS() # Forecasts -cira_heatwave_forecast = inputs.KerchunkForecast( - name="FourCastNetv2", - source="gs://extremeweatherbench/FOUR_v200_GFS.parq", +cira_fcnv2_heatwave_forecast = inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", variables=["surface_air_temperature"], - variable_mapping=inputs.CIRA_metadata_variable_mapping, - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_cira_forecast_dataset, + name="FourCastNetv2", ) -cira_freeze_forecast = inputs.KerchunkForecast( - name="FourCastNetv2", - source="gs://extremeweatherbench/FOUR_v200_GFS.parq", +cira_fcnv2_freeze_forecast = inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", variables=["surface_air_temperature"], - variable_mapping=inputs.CIRA_metadata_variable_mapping, - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_cira_forecast_dataset, + name="FourCastNetv2", ) -cira_tropical_cyclone_forecast = inputs.KerchunkForecast( - name="FourCastNetv2", - source="gs://extremeweatherbench/FOUR_v200_GFS.parq", +cira_fcnv2_tropical_cyclone_forecast = inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", variables=[derived.TropicalCycloneTrackVariables()], - variable_mapping=inputs.CIRA_metadata_variable_mapping, - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, + name="FourCastNetv2", preprocess=_preprocess_cira_tc_forecast_dataset, ) -cira_atmospheric_river_forecast = inputs.KerchunkForecast( - name="FourCastNetv2", - source="gs://extremeweatherbench/FOUR_v200_GFS.parq", 
+cira_fcnv2_atmospheric_river_forecast = inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", variables=[ derived.AtmosphericRiverVariables( output_variables=["atmospheric_river_land_intersection"] ) ], - variable_mapping=inputs.CIRA_metadata_variable_mapping, - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_ar_cira_forecast_dataset, + name="FourCastNetv2", + preprocess=_preprocess_cira_ar_forecast_dataset, ) -cira_severe_convection_forecast = inputs.KerchunkForecast( - name="FourCastNetv2", - source="gs://extremeweatherbench/FOUR_v200_GFS.parq", +cira_fcnv2_severe_convection_forecast = inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", variables=[derived.CravenBrooksSignificantSevere()], - variable_mapping=inputs.CIRA_metadata_variable_mapping, - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, + name="FourCastNetv2", preprocess=_preprocess_severe_cira_forecast_dataset, ) @@ -363,37 +369,37 @@ def get_brightband_evaluation_objects() -> list[inputs.EvaluationObject]: event_type="heat_wave", metric_list=heatwave_metric_list, target=era5_heatwave_target, - forecast=cira_heatwave_forecast, + forecast=cira_fcnv2_heatwave_forecast, ), inputs.EvaluationObject( event_type="heat_wave", metric_list=heatwave_metric_list, target=ghcn_heatwave_target, - forecast=cira_heatwave_forecast, + forecast=cira_fcnv2_heatwave_forecast, ), inputs.EvaluationObject( event_type="freeze", metric_list=freeze_metric_list, target=era5_freeze_target, - forecast=cira_freeze_forecast, + forecast=cira_fcnv2_freeze_forecast, ), inputs.EvaluationObject( event_type="freeze", metric_list=freeze_metric_list, target=ghcn_freeze_target, - forecast=cira_freeze_forecast, + forecast=cira_fcnv2_freeze_forecast, ), inputs.EvaluationObject( event_type="severe_convection", metric_list=pph_metric_list, target=pph_target, - forecast=cira_severe_convection_forecast, + forecast=cira_fcnv2_severe_convection_forecast, ), 
inputs.EvaluationObject( event_type="severe_convection", metric_list=lsr_metric_list, target=lsr_target, - forecast=cira_severe_convection_forecast, + forecast=cira_fcnv2_severe_convection_forecast, ), inputs.EvaluationObject( event_type="atmospheric_river", @@ -403,12 +409,12 @@ def get_brightband_evaluation_objects() -> list[inputs.EvaluationObject]: metrics.EarlySignal(), ], target=era5_atmospheric_river_target, - forecast=cira_atmospheric_river_forecast, + forecast=cira_fcnv2_atmospheric_river_forecast, ), inputs.EvaluationObject( event_type="tropical_cyclone", metric_list=composite_landfall_metrics, target=ibtracs_target, - forecast=cira_tropical_cyclone_forecast, + forecast=cira_fcnv2_tropical_cyclone_forecast, ), ] diff --git a/src/extremeweatherbench/inputs.py b/src/extremeweatherbench/inputs.py index cd0f52cd..8708dde9 100644 --- a/src/extremeweatherbench/inputs.py +++ b/src/extremeweatherbench/inputs.py @@ -145,6 +145,17 @@ {"s3://noaa-oar-mlwp-data/": icechunk.s3_credentials(anonymous=True)} ) +CIRA_MODEL_NAMES = [ + "AURO_v100_GFS", + "FOUR_v200_IFS", + "PANG_v100_IFS", + "FOUR_v200_GFS", + "GRAP_v100_GFS", + "AURO_v100_IFS", + "PANG_v100_GFS", + "GRAP_v100_IFS", +] + def _default_preprocess(input_data: IncomingDataInput) -> IncomingDataInput: """Default forecast preprocess function that does nothing.""" @@ -1268,3 +1279,51 @@ def check_for_missing_data( return False else: return True + + +def get_cira_icechunk( + model_name: str, + variables: list[Union[str, derived.DerivedVariable]] = [], + preprocess: Callable = _default_preprocess, + name: Optional[str] = None, +) -> XarrayForecast: + """Get a CIRA icechunk forecast object for a given model name. + + Args: + model_name: The name of the model from CIRA to get the forecast object for. For + example, "FOUR_v200_GFS". For a list of available models, see + `extremeweatherbench.defaults.CIRA_MODEL_NAMES`. + variables: The variables to select from the model. Defaults to all variables. 
+ preprocess: The preprocessing function to apply to the model. Defaults to the + default passthrough preprocess function. + name: The name of the forecast object. Defaults to model_name by default unless + `name` is provided. + Returns: + An XarrayForecast object for the given model. + """ + # Check if the model name is valid + if model_name not in CIRA_MODEL_NAMES: + raise ValueError( + f"Model name {model_name} not found in CIRA_MODEL_NAMES. Model names must be one of: {CIRA_MODEL_NAMES}" + ) + + # Get the CIRA icechunkstorage + cira_storage = icechunk.gcs_storage( + bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True + ) + + # The models are distinct groups within the icechunk store; open the group + # corresponding to the model name + cira_model_ds = open_icechunk_dataset_from_datatree( + cira_storage, model_name, authorize_virtual_chunk_access=CIRA_CREDENTIALS + ) + + # Create the XarrayForecast object for the given model + cira_model_forecast = XarrayForecast( + ds=cira_model_ds, + variables=variables, + variable_mapping=CIRA_metadata_variable_mapping, + name=name if name else model_name, + preprocess=preprocess, + ) + return cira_model_forecast diff --git a/tests/test_defaults.py b/tests/test_defaults.py index cbbea8bd..ff452c91 100644 --- a/tests/test_defaults.py +++ b/tests/test_defaults.py @@ -120,10 +120,22 @@ def test_target_objects_exist(self): def test_forecast_objects_exist(self): """Test that forecast objects are properly defined.""" - assert hasattr(defaults, "cira_heatwave_forecast") - assert hasattr(defaults, "cira_freeze_forecast") - assert isinstance(defaults.cira_heatwave_forecast, inputs.KerchunkForecast) - assert isinstance(defaults.cira_freeze_forecast, inputs.KerchunkForecast) + assert hasattr(defaults, "cira_fcnv2_heatwave_forecast") + assert hasattr(defaults, "cira_fcnv2_freeze_forecast") + assert hasattr(defaults, "cira_fcnv2_tropical_cyclone_forecast") + assert hasattr(defaults, 
"cira_fcnv2_atmospheric_river_forecast") + assert hasattr(defaults, "cira_fcnv2_severe_convection_forecast") + assert isinstance(defaults.cira_fcnv2_heatwave_forecast, inputs.XarrayForecast) + assert isinstance(defaults.cira_fcnv2_freeze_forecast, inputs.XarrayForecast) + assert isinstance( + defaults.cira_fcnv2_tropical_cyclone_forecast, inputs.XarrayForecast + ) + assert isinstance( + defaults.cira_fcnv2_atmospheric_river_forecast, inputs.XarrayForecast + ) + assert isinstance( + defaults.cira_fcnv2_severe_convection_forecast, inputs.XarrayForecast + ) def test_era5_heatwave_target_configuration(self): """Test ERA5 heatwave target configuration.""" @@ -149,21 +161,6 @@ def test_era5_freeze_target_configuration(self): for key, value in expected_mapping.items(): assert target.variable_mapping[key] == value - def test_cira_forecasts_have_preprocess_function(self): - """Test that CIRA forecasts have the preprocess function set.""" - assert defaults.cira_heatwave_forecast.preprocess is not None - assert defaults.cira_freeze_forecast.preprocess is not None - - # Test that the preprocess function is the expected one - assert ( - defaults.cira_heatwave_forecast.preprocess - == defaults._preprocess_cira_forecast_dataset - ) - assert ( - defaults.cira_freeze_forecast.preprocess - == defaults._preprocess_cira_forecast_dataset - ) - def test_get_brightband_evaluation_objects_no_exceptions(self): """Test that get_brightband_evaluation_objects runs without exceptions.""" try: @@ -173,3 +170,63 @@ def test_get_brightband_evaluation_objects_no_exceptions(self): assert len(result) > 0 except Exception as e: pytest.fail(f"get_brightband_evaluation_objects raised an exception: {e}") + + +class TestCiraFcnv2PreprocessFunctions: + """Tests that each cira_fcnv2 forecast has the correct preprocessing function.""" + + def test_heatwave_forecast_has_default_preprocess(self): + """Test that cira_fcnv2_heatwave_forecast uses default preprocess.""" + forecast = 
defaults.cira_fcnv2_heatwave_forecast + assert forecast.preprocess == inputs._default_preprocess + + def test_freeze_forecast_has_default_preprocess(self): + """Test that cira_fcnv2_freeze_forecast uses default preprocess.""" + forecast = defaults.cira_fcnv2_freeze_forecast + assert forecast.preprocess == inputs._default_preprocess + + def test_tropical_cyclone_forecast_has_tc_preprocess(self): + """Test that cira_fcnv2_tropical_cyclone_forecast uses TC preprocess.""" + forecast = defaults.cira_fcnv2_tropical_cyclone_forecast + assert forecast.preprocess == defaults._preprocess_cira_tc_forecast_dataset + + def test_atmospheric_river_forecast_has_ar_preprocess(self): + """Test that cira_fcnv2_atmospheric_river_forecast uses AR preprocess.""" + forecast = defaults.cira_fcnv2_atmospheric_river_forecast + assert forecast.preprocess == defaults._preprocess_cira_ar_forecast_dataset + + def test_severe_convection_forecast_has_severe_preprocess(self): + """Test that cira_fcnv2_severe_convection_forecast uses severe preprocess.""" + forecast = defaults.cira_fcnv2_severe_convection_forecast + assert forecast.preprocess == defaults._preprocess_severe_cira_forecast_dataset + + def test_all_forecasts_have_preprocess_attribute(self): + """Test that all cira_fcnv2 forecasts have a preprocess attribute set.""" + forecasts = [ + defaults.cira_fcnv2_heatwave_forecast, + defaults.cira_fcnv2_freeze_forecast, + defaults.cira_fcnv2_tropical_cyclone_forecast, + defaults.cira_fcnv2_atmospheric_river_forecast, + defaults.cira_fcnv2_severe_convection_forecast, + ] + for forecast in forecasts: + assert hasattr(forecast, "preprocess") + assert forecast.preprocess is not None + assert callable(forecast.preprocess) + + def test_preprocess_functions_are_distinct_where_expected(self): + """Test that different event types use different preprocess functions.""" + # TC, AR, and severe should have distinct preprocess functions + tc_preprocess = defaults.cira_fcnv2_tropical_cyclone_forecast.preprocess 
+ ar_preprocess = defaults.cira_fcnv2_atmospheric_river_forecast.preprocess + severe_preprocess = defaults.cira_fcnv2_severe_convection_forecast.preprocess + + assert tc_preprocess != ar_preprocess + assert tc_preprocess != severe_preprocess + # Note: AR and severe could be the same or different depending on impl + + def test_heatwave_and_freeze_use_same_preprocess(self): + """Test that heatwave and freeze forecasts use the same preprocess.""" + heatwave_preprocess = defaults.cira_fcnv2_heatwave_forecast.preprocess + freeze_preprocess = defaults.cira_fcnv2_freeze_forecast.preprocess + assert heatwave_preprocess == freeze_preprocess diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 4531d150..0725afb0 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -2233,7 +2233,9 @@ def test_xarray_forecast_none_handling_for_optional_params( """Test that None values are properly converted to empty defaults.""" # Explicitly pass None to test the None handling in __init__ forecast = inputs.XarrayForecast( - ds=sample_forecast_with_valid_time, variables=None, variable_mapping=None + ds=sample_forecast_with_valid_time, + variables=None, + variable_mapping=None, # type: ignore ) # Should be converted to empty containers @@ -2323,3 +2325,266 @@ def test_default_preprocess(): df = pd.DataFrame({"a": [1, 2, 3]}) result_df = inputs._default_preprocess(df) assert result_df is df + + +class TestGetCIRAIcechunk: + """Tests for get_cira_icechunk function.""" + + def test_invalid_model_name_raises_value_error(self): + """Test that an invalid model name raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + inputs.get_cira_icechunk(model_name="INVALID_MODEL") + + assert "INVALID_MODEL" in str(exc_info.value) + assert "CIRA_MODEL_NAMES" in str(exc_info.value) + + def test_empty_model_name_raises_value_error(self): + """Test that an empty model name raises ValueError.""" + with pytest.raises(ValueError): + inputs.get_cira_icechunk(model_name="") + + def 
test_none_model_name_raises_error(self): + """Test that None as model name raises appropriate error.""" + with pytest.raises((ValueError, TypeError)): + inputs.get_cira_icechunk(model_name=None) # type: ignore + + def test_case_sensitive_model_name(self): + """Test that model name matching is case-sensitive.""" + # Lowercase version of a valid model name should fail + with pytest.raises(ValueError): + inputs.get_cira_icechunk(model_name="four_v200_gfs") + + # Mixed case should fail + with pytest.raises(ValueError): + inputs.get_cira_icechunk(model_name="Four_V200_GFS") + + def test_partial_model_name_raises_value_error(self): + """Test that partial model names are rejected.""" + with pytest.raises(ValueError): + inputs.get_cira_icechunk(model_name="FOUR") + + with pytest.raises(ValueError): + inputs.get_cira_icechunk(model_name="GFS") + + def test_model_name_with_extra_chars_raises_value_error(self): + """Test that model names with extra characters are rejected.""" + with pytest.raises(ValueError): + inputs.get_cira_icechunk(model_name="FOUR_v200_GFS_extra") + + with pytest.raises(ValueError): + inputs.get_cira_icechunk(model_name=" FOUR_v200_GFS") + + def test_error_message_lists_valid_model_names(self): + """Test that the error message includes the list of valid model names.""" + with pytest.raises(ValueError) as exc_info: + inputs.get_cira_icechunk(model_name="BAD_MODEL") + + error_msg = str(exc_info.value) + # Check that at least some valid model names are shown in the error + assert "FOUR_v200_GFS" in error_msg or "CIRA_MODEL_NAMES" in error_msg + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_valid_model_name_four_v200_gfs( + self, mock_forecast, mock_open, mock_storage + ): + """Test that FOUR_v200_GFS is a valid model name.""" + mock_storage.return_value = mock.MagicMock() + 
mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + result = inputs.get_cira_icechunk(model_name="FOUR_v200_GFS") + + assert result is not None + mock_storage.assert_called_once() + mock_open.assert_called_once() + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_valid_model_name_auro_v100_gfs( + self, mock_forecast, mock_open, mock_storage + ): + """Test that AURO_v100_GFS is a valid model name.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + result = inputs.get_cira_icechunk(model_name="AURO_v100_GFS") + + assert result is not None + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_all_cira_model_names_are_valid( + self, mock_forecast, mock_open, mock_storage + ): + """Test that all model names in CIRA_MODEL_NAMES are accepted.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + for model_name in inputs.CIRA_MODEL_NAMES: + result = inputs.get_cira_icechunk(model_name=model_name) + assert result is not None, f"Model {model_name} should be valid" + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_custom_name_parameter(self, mock_forecast, mock_open, mock_storage): + """Test that a custom name parameter is passed to XarrayForecast.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + 
mock_forecast.return_value = mock.MagicMock() + + inputs.get_cira_icechunk(model_name="FOUR_v200_GFS", name="CustomName") + + # Check that XarrayForecast was called with the custom name + call_kwargs = mock_forecast.call_args[1] + assert call_kwargs["name"] == "CustomName" + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_default_name_uses_model_name(self, mock_forecast, mock_open, mock_storage): + """Test that name inputs to model_name when not provided.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + inputs.get_cira_icechunk(model_name="FOUR_v200_GFS") + + call_kwargs = mock_forecast.call_args[1] + assert call_kwargs["name"] == "FOUR_v200_GFS" + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_empty_variables_list(self, mock_forecast, mock_open, mock_storage): + """Test that empty variables list is valid.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + result = inputs.get_cira_icechunk(model_name="FOUR_v200_GFS", variables=[]) + + assert result is not None + call_kwargs = mock_forecast.call_args[1] + assert call_kwargs["variables"] == [] + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_custom_variables_list(self, mock_forecast, mock_open, mock_storage): + """Test that a custom variables list is passed through.""" + mock_storage.return_value = mock.MagicMock() + 
mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + variables = ["surface_air_temperature", "air_pressure"] + inputs.get_cira_icechunk(model_name="FOUR_v200_GFS", variables=variables) + + call_kwargs = mock_forecast.call_args[1] + assert call_kwargs["variables"] == variables + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_custom_preprocess_function(self, mock_forecast, mock_open, mock_storage): + """Test that a custom preprocess function is passed through.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + def custom_preprocess(ds: xr.Dataset) -> xr.Dataset: + return ds + + inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", preprocess=custom_preprocess + ) + + call_kwargs = mock_forecast.call_args[1] + assert call_kwargs["preprocess"] == custom_preprocess + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_returns_xarray_forecast_object( + self, mock_forecast, mock_open, mock_storage + ): + """Test that the function returns an XarrayForecast object.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + expected_forecast = mock.MagicMock() + mock_forecast.return_value = expected_forecast + + result = inputs.get_cira_icechunk(model_name="FOUR_v200_GFS") + + assert result is expected_forecast + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_gcs_storage_configuration(self, mock_forecast, 
mock_open, mock_storage): + """Test that GCS storage is configured with correct parameters.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + inputs.get_cira_icechunk(model_name="FOUR_v200_GFS") + + mock_storage.assert_called_once_with( + bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True + ) + + @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage") + @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree") + @mock.patch("extremeweatherbench.inputs.XarrayForecast") + def test_uses_cira_variable_mapping(self, mock_forecast, mock_open, mock_storage): + """Test that CIRA metadata variable mapping is used.""" + mock_storage.return_value = mock.MagicMock() + mock_open.return_value = mock.MagicMock() + mock_forecast.return_value = mock.MagicMock() + + inputs.get_cira_icechunk(model_name="FOUR_v200_GFS") + + call_kwargs = mock_forecast.call_args[1] + assert call_kwargs["variable_mapping"] == inputs.CIRA_metadata_variable_mapping + + +class TestCiraModelNames: + """Tests for CIRA_MODEL_NAMES constant.""" + + def test_cira_model_names_is_list(self): + """Test that CIRA_MODEL_NAMES is a list.""" + assert isinstance(inputs.CIRA_MODEL_NAMES, list) + + def test_cira_model_names_not_empty(self): + """Test that CIRA_MODEL_NAMES is not empty.""" + assert len(inputs.CIRA_MODEL_NAMES) > 0 + + def test_cira_model_names_contains_expected_models(self): + """Test that CIRA_MODEL_NAMES contains expected model names.""" + expected_models = [ + "FOUR_v200_GFS", + "FOUR_v200_IFS", + "AURO_v100_GFS", + "AURO_v100_IFS", + "PANG_v100_GFS", + "PANG_v100_IFS", + "GRAP_v100_GFS", + "GRAP_v100_IFS", + ] + for model in expected_models: + assert model in inputs.CIRA_MODEL_NAMES + + def test_cira_model_names_all_strings(self): + """Test that all entries in CIRA_MODEL_NAMES are strings.""" + for model in inputs.CIRA_MODEL_NAMES: + assert isinstance(model, str) 
+ + def test_cira_model_names_no_duplicates(self): + """Test that CIRA_MODEL_NAMES has no duplicate entries.""" + assert len(inputs.CIRA_MODEL_NAMES) == len(set(inputs.CIRA_MODEL_NAMES)) From 2a2f2208e24fe7cfafe04f31784ab218e7c3918e Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Mon, 26 Jan 2026 15:08:04 -0500 Subject: [PATCH 12/15] Bump version from 0.2.0 to 0.3.0 (#324) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1960a539..73eae194 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "extremeweatherbench" -version = "0.2.0" +version = "0.3.0" description = "Benchmarking weather and weather AI models using extreme events" keywords = [ "weather", From 0a524a33a40b1956becaa1cd0dc75183bfcf7c70 Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Mon, 26 Jan 2026 15:18:36 -0500 Subject: [PATCH 13/15] Updated API (#321) * move cache dir creation to init, rename funcs, add parallel/serial check function, update test names * update naming * add run method for backwards compatibility * update tests * add tests and cover if serial and parallel_config is not None * feat: redesign public API with hierarchical namespace submodules - Add ewb.evaluation() as main entry point (alias for ExtremeWeatherBench) - Create namespace submodules: ewb.targets, ewb.forecasts, ewb.metrics, ewb.derived, ewb.regions, ewb.cases, ewb.defaults - Expose all classes at top level for convenience (ewb.ERA5, etc.) - Add ewb.load_cases() convenience alias - Update all example files to use new import pattern - Update usage.md documentation - Maintain backward compatibility with existing imports * ruff/linting. 
add utils to init * add test coverage for module loading patterns * ruff * Cleanup docstrings in repo (#318) * update these docstrings * remove docstring changes markdown * update docstrings * update other docstrings * remove individualcasecollection reference, update based on develop changes * add explanation for dim reqs (#320) * Update `defaults` and `inputs` to include new CIRA icechunk store (#319) * more explicit naming, add func and model names var * add test coverage, ruff, linting * update readme for new cira approach * move cira func and model ref to inputs * update docs * module wasnt called for moved func * update tests for moving func and var * ruff * fix mock typos * update defaults var refs --- data_prep/ar_bounds.py | 6 +- data_prep/ibtracs_bounds.py | 26 +- .../practically_perfect_hindcast_from_lsr.py | 3 +- data_prep/severe_convection_bounds.py | 3 +- data_prep/subset_heat_cold_events.py | 3 +- docs/examples/applied_ar.py | 61 +-- docs/examples/applied_freeze.py | 32 +- docs/examples/applied_heatwave.py | 32 +- docs/examples/applied_severe.py | 40 +- docs/examples/applied_tc.py | 102 +++-- docs/examples/example_config.py | 21 +- docs/parallelism.md | 2 +- docs/usage.md | 86 ++-- scripts/brightband_evaluation.py | 2 +- src/extremeweatherbench/__init__.py | 339 ++++++++++++++++ src/extremeweatherbench/cases.py | 5 +- src/extremeweatherbench/defaults.py | 2 +- src/extremeweatherbench/derived.py | 2 +- src/extremeweatherbench/evaluate.py | 188 +++++---- src/extremeweatherbench/evaluate_cli.py | 6 +- src/extremeweatherbench/inputs.py | 7 +- src/extremeweatherbench/regions.py | 2 +- src/extremeweatherbench/sources/base.py | 2 +- .../sources/pandas_dataframe.py | 4 +- .../sources/polars_lazyframe.py | 4 +- .../sources/xarray_dataarray.py | 3 +- .../sources/xarray_dataset.py | 2 +- tests/test_evaluate.py | 370 ++++++++++-------- tests/test_evaluate_cli.py | 34 +- tests/test_init.py | 259 ++++++++++++ tests/test_integration.py | 4 +- 31 files changed, 1210 
insertions(+), 442 deletions(-) create mode 100644 tests/test_init.py diff --git a/data_prep/ar_bounds.py b/data_prep/ar_bounds.py index b642d334..1c4cf5cf 100644 --- a/data_prep/ar_bounds.py +++ b/data_prep/ar_bounds.py @@ -17,7 +17,11 @@ from dask.distributed import Client from matplotlib.patches import Rectangle -from extremeweatherbench import cases, derived, inputs, regions, utils +import extremeweatherbench.cases as cases +import extremeweatherbench.derived as derived +import extremeweatherbench.inputs as inputs +import extremeweatherbench.regions as regions +import extremeweatherbench.utils as utils from extremeweatherbench.events import atmospheric_river as ar logging.basicConfig() diff --git a/data_prep/ibtracs_bounds.py b/data_prep/ibtracs_bounds.py index 0a2ecc0a..0dc962d9 100644 --- a/data_prep/ibtracs_bounds.py +++ b/data_prep/ibtracs_bounds.py @@ -4,6 +4,7 @@ import logging import re from importlib import resources +from typing import TYPE_CHECKING import cartopy.crs as ccrs import cartopy.feature as cfeature @@ -14,8 +15,11 @@ import yaml from matplotlib.patches import Rectangle +import extremeweatherbench as ewb import extremeweatherbench.data -from extremeweatherbench import cases, inputs, regions, utils + +if TYPE_CHECKING: + from extremeweatherbench.regions import Region logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -67,7 +71,7 @@ def calculate_extent_bounds( bottom_lat: float, top_lat: float, extent_buffer: float = 250, -) -> regions.Region: +) -> Region: """Calculate extent bounds with buffer. 
Args: @@ -94,9 +98,9 @@ def calculate_extent_bounds( calculate_end_point(bottom_lat, right_lon, 90, extent_buffer), 1 ) - new_left_lon = np.round(utils.convert_longitude_to_360(new_left_lon), 1) - new_right_lon = np.round(utils.convert_longitude_to_360(new_right_lon), 1) - new_box = regions.BoundingBoxRegion( + new_left_lon = np.round(ewb.utils.convert_longitude_to_360(new_left_lon), 1) + new_right_lon = np.round(ewb.utils.convert_longitude_to_360(new_right_lon), 1) + new_box = ewb.regions.BoundingBoxRegion( new_bottom_lat, new_top_lat, new_left_lon, new_right_lon ) return new_box @@ -164,10 +168,10 @@ def load_and_process_ibtracs_data(): """ logger.info("Loading IBTrACS data...") - IBTRACS = inputs.IBTrACS( - source=inputs.IBTRACS_URI, + IBTRACS = ewb.inputs.IBTrACS( + source=ewb.inputs.IBTRACS_URI, variables=["vmax", "slp"], - variable_mapping=inputs.IBTrACS_metadata_variable_mapping, + variable_mapping=ewb.inputs.IBTrACS_metadata_variable_mapping, storage_options={}, ) @@ -177,7 +181,7 @@ def load_and_process_ibtracs_data(): # Get all storms from 2020 - 2025 seasons all_storms_2020_2025_lf = IBTRACS_lf.filter( (pl.col("SEASON").cast(pl.Int32) >= 2020) - ).select(inputs.IBTrACS_metadata_variable_mapping.values()) + ).select(ewb.inputs.IBTrACS_metadata_variable_mapping.values()) schema = all_storms_2020_2025_lf.collect_schema() # Convert pressure and surface wind columns to float, replacing " " with null @@ -464,7 +468,7 @@ def find_storm_bounds_for_case(storm_name, storm_bounds, all_storms_df): # If we found both, merge them by taking the bounding box that # encompasses both if bounds1 is not None and bounds2 is not None: - merged_bbox = regions.BoundingBoxRegion( + merged_bbox = ewb.regions.BoundingBoxRegion( latitude_min=min( bounds1.iloc[0].latitude_min, bounds2.iloc[0].latitude_min ), @@ -537,7 +541,7 @@ def update_cases_with_storm_bounds(storm_bounds, all_storms_df): """ logger.info("Updating cases with storm bounds...") - cases_all = 
cases.load_ewb_events_yaml_into_case_list() + cases_all = ewb.cases.load_ewb_events_yaml_into_case_list() cases_new = cases_all.copy() # Update the yaml cases with storm bounds from IBTrACS data diff --git a/data_prep/practically_perfect_hindcast_from_lsr.py b/data_prep/practically_perfect_hindcast_from_lsr.py index db114242..1e96cbda 100644 --- a/data_prep/practically_perfect_hindcast_from_lsr.py +++ b/data_prep/practically_perfect_hindcast_from_lsr.py @@ -11,7 +11,8 @@ from scipy.ndimage import gaussian_filter from tqdm.auto import tqdm -from extremeweatherbench import inputs, utils +import extremeweatherbench.inputs as inputs +import extremeweatherbench.utils as utils def sparse_practically_perfect_hindcast( diff --git a/data_prep/severe_convection_bounds.py b/data_prep/severe_convection_bounds.py index 12e1bcf9..11985632 100644 --- a/data_prep/severe_convection_bounds.py +++ b/data_prep/severe_convection_bounds.py @@ -17,7 +17,8 @@ import yaml from scipy.ndimage import label -from extremeweatherbench import calc, cases +import extremeweatherbench.calc as calc +import extremeweatherbench.cases as cases # Radius of Earth in km (mean radius) EARTH_RADIUS_KM = 6371.0 diff --git a/data_prep/subset_heat_cold_events.py b/data_prep/subset_heat_cold_events.py index e109889b..ee6b4add 100644 --- a/data_prep/subset_heat_cold_events.py +++ b/data_prep/subset_heat_cold_events.py @@ -13,7 +13,8 @@ from matplotlib import dates as mdates from mpl_toolkits.axes_grid1 import make_axes_locatable -from extremeweatherbench import cases, utils +import extremeweatherbench.cases as cases +import extremeweatherbench.utils as utils sns.set_theme(style="whitegrid", context="talk") diff --git a/docs/examples/applied_ar.py b/docs/examples/applied_ar.py index 75239f47..a99e5a0e 100644 --- a/docs/examples/applied_ar.py +++ b/docs/examples/applied_ar.py @@ -4,7 +4,7 @@ import numpy as np import xarray as xr -from extremeweatherbench import cases, derived, evaluate, inputs, metrics +import 
extremeweatherbench as ewb # %% @@ -38,85 +38,86 @@ def _preprocess_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: # Load case data from the default events.yaml -# Users can also define their own cases -case_yaml = cases.load_ewb_events_yaml_into_case_list() -case_yaml = [n for n in case_yaml if n.case_id_number == 114] -case_yaml[0].start_date = datetime.datetime(2022, 12, 27, 11, 0, 0) -case_yaml[0].end_date = datetime.datetime(2022, 12, 27, 13, 0, 0) +# Users can also define their own cases_dict structure +case_yaml = ewb.load_cases() +case_list = [n for n in case_yaml if n.case_id_number == 114] + +case_list[0].start_date = datetime.datetime(2022, 12, 27, 11, 0, 0) +case_list[0].end_date = datetime.datetime(2022, 12, 27, 13, 0, 0) # Define ERA5 target -era5_target = inputs.ERA5( +era5_target = ewb.targets.ERA5( variables=[ - derived.AtmosphericRiverVariables( + ewb.derived.AtmosphericRiverVariables( output_variables=["atmospheric_river_land_intersection"] ) ], ) # Define forecast (HRES) -hres_forecast = inputs.ZarrForecast( +hres_forecast = ewb.forecasts.ZarrForecast( source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", name="HRES", variables=[ - derived.AtmosphericRiverVariables( + ewb.derived.AtmosphericRiverVariables( output_variables=["atmospheric_river_land_intersection"] ) ], - variable_mapping=inputs.HRES_metadata_variable_mapping, + variable_mapping=ewb.HRES_metadata_variable_mapping, ) -grap_forecast = inputs.KerchunkForecast( +grap_forecast = ewb.forecasts.KerchunkForecast( name="Graphcast", source="gs://extremeweatherbench/GRAP_v100_IFS.parq", variables=[ - derived.AtmosphericRiverVariables( + ewb.derived.AtmosphericRiverVariables( output_variables=["atmospheric_river_land_intersection"] ) ], - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, preprocess=_preprocess_cira_forecast_dataset, ) 
-pang_forecast = inputs.KerchunkForecast( +pang_forecast = ewb.forecasts.KerchunkForecast( name="Pangu", source="gs://extremeweatherbench/PANG_v100_IFS.parq", variables=[ - derived.AtmosphericRiverVariables( + ewb.derived.AtmosphericRiverVariables( output_variables=["atmospheric_river_land_intersection"] ) ], - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, preprocess=_preprocess_cira_forecast_dataset, ) # Create a list of evaluation objects for atmospheric river ar_evaluation_objects = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="atmospheric_river", metric_list=[ - metrics.CriticalSuccessIndex(), - metrics.EarlySignal(), - metrics.SpatialDisplacement(), + ewb.metrics.CriticalSuccessIndex(), + ewb.metrics.EarlySignal(), + ewb.metrics.SpatialDisplacement(), ], target=era5_target, forecast=hres_forecast, ), - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="atmospheric_river", metric_list=[ - metrics.CriticalSuccessIndex(), - metrics.EarlySignal(), - metrics.SpatialDisplacement(), + ewb.metrics.CriticalSuccessIndex(), + ewb.metrics.EarlySignal(), + ewb.metrics.SpatialDisplacement(), ], target=era5_target, forecast=grap_forecast, ), - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="atmospheric_river", metric_list=[ - metrics.CriticalSuccessIndex(), - metrics.EarlySignal(), - metrics.SpatialDisplacement(), + ewb.metrics.CriticalSuccessIndex(), + ewb.metrics.EarlySignal(), + ewb.metrics.SpatialDisplacement(), ], target=era5_target, forecast=pang_forecast, @@ -126,7 +127,7 @@ def _preprocess_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: if __name__ == "__main__": # Initialize ExtremeWeatherBench; will only run on cases with event_type # atmospheric_river - ar_ewb = evaluate.ExtremeWeatherBench( + ar_ewb = ewb.evaluation( case_metadata=case_yaml, 
evaluation_objects=ar_evaluation_objects, ) diff --git a/docs/examples/applied_freeze.py b/docs/examples/applied_freeze.py index 8b2325cd..864d76f7 100644 --- a/docs/examples/applied_freeze.py +++ b/docs/examples/applied_freeze.py @@ -1,55 +1,55 @@ import logging import operator -from extremeweatherbench import cases, defaults, evaluate, inputs, metrics +import extremeweatherbench as ewb # Set the logger level to INFO logger = logging.getLogger("extremeweatherbench") logger.setLevel(logging.INFO) # Load case data from the default events.yaml -# Users can also define their own cases -case_yaml = cases.load_ewb_events_yaml_into_case_list() +# Users can also define their own cases_dict structure +case_yaml = ewb.load_cases() # Define targets # ERA5 target -era5_freeze_target = inputs.ERA5( +era5_freeze_target = ewb.targets.ERA5( variables=["surface_air_temperature"], chunks=None, ) # GHCN target -ghcn_freeze_target = inputs.GHCN(variables=["surface_air_temperature"]) +ghcn_freeze_target = ewb.targets.GHCN(variables=["surface_air_temperature"]) # Define forecast (FCNv2 CIRA Virtualizarr) -fcnv2_forecast = inputs.KerchunkForecast( +fcnv2_forecast = ewb.forecasts.KerchunkForecast( name="fcnv2_forecast", source="gs://extremeweatherbench/FOUR_v200_GFS.parq", variables=["surface_air_temperature"], - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=defaults._preprocess_cira_forecast_dataset, + preprocess=ewb.defaults._preprocess_bb_cira_forecast_dataset, ) # Load the climatology for DurationMeanError -climatology = defaults.get_climatology(quantile=0.85) +climatology = ewb.get_climatology(quantile=0.85) # Define the metrics metrics_list = [ - metrics.RootMeanSquaredError(), - metrics.MinimumMeanAbsoluteError(), - metrics.DurationMeanError(threshold_criteria=climatology, op_func=operator.le), + 
ewb.metrics.RootMeanSquaredError(), + ewb.metrics.MinimumMeanAbsoluteError(), + ewb.metrics.DurationMeanError(threshold_criteria=climatology, op_func=operator.le), ] # Create a list of evaluation objects for freeze freeze_evaluation_object = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="freeze", metric_list=metrics_list, target=ghcn_freeze_target, forecast=fcnv2_forecast, ), - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="freeze", metric_list=metrics_list, target=era5_freeze_target, @@ -59,13 +59,13 @@ if __name__ == "__main__": # Initialize ExtremeWeatherBench runner instance - ewb = evaluate.ExtremeWeatherBench( + freeze_ewb = ewb.evaluation( case_metadata=case_yaml, evaluation_objects=freeze_evaluation_object, ) # Run the workflow - outputs = ewb.run(parallel_config={"backend": "loky", "n_jobs": 1}) + outputs = freeze_ewb.run(parallel_config={"backend": "loky", "n_jobs": 1}) # Print the outputs; can be saved if desired outputs.to_csv("freeze_outputs.csv") diff --git a/docs/examples/applied_heatwave.py b/docs/examples/applied_heatwave.py index 7f44b081..22c9b809 100644 --- a/docs/examples/applied_heatwave.py +++ b/docs/examples/applied_heatwave.py @@ -1,56 +1,56 @@ import logging import operator -from extremeweatherbench import cases, defaults, evaluate, inputs, metrics +import extremeweatherbench as ewb # Set the logger level to INFO logger = logging.getLogger("extremeweatherbench") logger.setLevel(logging.INFO) # Load case data from the default events.yaml -# Users can also define their own cases -case_yaml = cases.load_ewb_events_yaml_into_case_list() +# Users can also define their own cases_dict structure +case_yaml = ewb.load_cases() # Define targets # ERA5 target -era5_heatwave_target = inputs.ERA5( +era5_heatwave_target = ewb.targets.ERA5( variables=["surface_air_temperature"], chunks=None, ) # GHCN target -ghcn_heatwave_target = inputs.GHCN( +ghcn_heatwave_target = ewb.targets.GHCN( variables=["surface_air_temperature"], 
) # Define forecast (HRES) -hres_forecast = inputs.ZarrForecast( +hres_forecast = ewb.forecasts.ZarrForecast( name="hres_forecast", source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", variables=["surface_air_temperature"], - variable_mapping=inputs.HRES_metadata_variable_mapping, + variable_mapping=ewb.HRES_metadata_variable_mapping, ) # Load the climatology for DurationMeanError -climatology = defaults.get_climatology(quantile=0.85) +climatology = ewb.get_climatology(quantile=0.85) # Define the metrics metrics_list = [ - metrics.MaximumMeanAbsoluteError(), - metrics.RootMeanSquaredError(), - metrics.DurationMeanError(threshold_criteria=climatology, op_func=operator.ge), - metrics.MaximumLowestMeanAbsoluteError(), + ewb.metrics.MaximumMeanAbsoluteError(), + ewb.metrics.RootMeanSquaredError(), + ewb.metrics.DurationMeanError(threshold_criteria=climatology, op_func=operator.ge), + ewb.metrics.MaximumLowestMeanAbsoluteError(), ] # Create a list of evaluation objects for heatwave heatwave_evaluation_object = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="heat_wave", metric_list=metrics_list, target=ghcn_heatwave_target, forecast=hres_forecast, ), - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="heat_wave", metric_list=metrics_list, target=era5_heatwave_target, @@ -59,11 +59,11 @@ ] if __name__ == "__main__": # Initialize ExtremeWeatherBench - ewb = evaluate.ExtremeWeatherBench( + heatwave_ewb = ewb.evaluation( case_metadata=case_yaml, evaluation_objects=heatwave_evaluation_object, ) # Run the workflow - outputs = ewb.run(parallel_config={"backend": "loky", "n_jobs": 2}) + outputs = heatwave_ewb.run(parallel_config={"backend": "loky", "n_jobs": 2}) outputs.to_csv("applied_heatwave_outputs.csv") diff --git a/docs/examples/applied_severe.py b/docs/examples/applied_severe.py index 6cdba3f9..f6a07003 100644 --- a/docs/examples/applied_severe.py +++ b/docs/examples/applied_severe.py @@ -1,6 +1,6 @@ import logging -from 
extremeweatherbench import cases, derived, evaluate, inputs, metrics +import extremeweatherbench as ewb # Set the logger level to INFO logger = logging.getLogger("extremeweatherbench") @@ -8,45 +8,45 @@ # Load case data from the default events.yaml -case_yaml = cases.load_ewb_events_yaml_into_case_list() -case_yaml = [n for n in case_yaml if n.case_id_number == 305] +case_yaml = ewb.load_cases() +case_list = [n for n in case_yaml if n.case_id_number == 305] # Define PPH target -pph_target = inputs.PPH( +pph_target = ewb.targets.PPH( variables=["practically_perfect_hindcast"], ) # Define LSR target -lsr_target = inputs.LSR() +lsr_target = ewb.targets.LSR() # Define HRES forecast -hres_forecast = inputs.ZarrForecast( +hres_forecast = ewb.forecasts.ZarrForecast( name="hres_forecast", source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", - variables=[derived.CravenBrooksSignificantSevere(layer_depth=100)], - variable_mapping=inputs.HRES_metadata_variable_mapping, + variables=[ewb.derived.CravenBrooksSignificantSevere(layer_depth=100)], + variable_mapping=ewb.HRES_metadata_variable_mapping, storage_options={"remote_options": {"anon": True}}, ) # Define pph metrics as thresholdmetric to share scores contingency table pph_metrics = [ - metrics.ThresholdMetric( + ewb.metrics.ThresholdMetric( metrics=[ - metrics.CriticalSuccessIndex, - metrics.FalseAlarmRatio, + ewb.metrics.CriticalSuccessIndex, + ewb.metrics.FalseAlarmRatio, ], forecast_threshold=15000, target_threshold=0.3, ), - metrics.EarlySignal(threshold=15000), + ewb.metrics.EarlySignal(threshold=15000), ] # Define LSR metrics as thresholdmetric to share scores contingency table lsr_metrics = [ - metrics.ThresholdMetric( + ewb.metrics.ThresholdMetric( metrics=[ - metrics.TruePositives, - metrics.FalseNegatives, + ewb.metrics.TruePositives, + ewb.metrics.FalseNegatives, ], forecast_threshold=15000, target_threshold=0.5, @@ -56,7 +56,7 @@ # Define evaluation objects for severe convection: # One 
evaluation object for PPH pph_evaluation_objects = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="severe_convection", metric_list=pph_metrics, target=pph_target, @@ -66,7 +66,7 @@ # One evaluation object for LSR lsr_evaluation_objects = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="severe_convection", metric_list=lsr_metrics, target=lsr_target, @@ -76,14 +76,14 @@ if __name__ == "__main__": # Initialize ExtremeWeatherBench with both evaluation objects - ewb = evaluate.ExtremeWeatherBench( - case_metadata=case_yaml, + severe_ewb = ewb.evaluation( + case_metadata=case_list, evaluation_objects=lsr_evaluation_objects + pph_evaluation_objects, ) logger.info("Starting EWB run") # Run the workflow with parllel_config backend set to dask - outputs = ewb.run(parallel_config={"backend": "loky", "n_jobs": 3}) + outputs = severe_ewb.run(parallel_config={"backend": "loky", "n_jobs": 3}) # Save the results to a CSV file outputs.to_csv("applied_severe_convection_results.csv") diff --git a/docs/examples/applied_tc.py b/docs/examples/applied_tc.py index 79d4d3b4..e0d17d3c 100644 --- a/docs/examples/applied_tc.py +++ b/docs/examples/applied_tc.py @@ -1,56 +1,102 @@ import logging -from extremeweatherbench import cases, defaults, derived, evaluate, inputs, metrics +import numpy as np +import xarray as xr + +import extremeweatherbench as ewb # Set the logger level to INFO logger = logging.getLogger("extremeweatherbench") logger.setLevel(logging.INFO) -# Load the case list from the YAML file -case_yaml = cases.load_ewb_events_yaml_into_case_list() +# Preprocessing function for CIRA data that includes geopotential thickness calculation +# required for tropical cyclone tracks +def _preprocess_bb_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: + """An example preprocess function that renames the time coordinate to lead_time, + creates a valid_time coordinate, and sets the lead time range and resolution not + present in the original dataset. 
+ + Args: + ds: The forecast dataset to rename. + + Returns: + The renamed forecast dataset. + """ + ds = ds.rename({"time": "lead_time"}) + # The evaluation configuration is used to set the lead time range and resolution. + ds["lead_time"] = np.array( + [i for i in range(0, 241, 6)], dtype="timedelta64[h]" + ).astype("timedelta64[ns]") + ds["geopotential_thickness"] = ewb.calc.geopotential_thickness( + ds["z"], top_level_value=300, bottom_level_value=500 + ) + return ds + + +# Preprocessing function for HRES data that includes geopotential thickness calculation +# required for tropical cyclone tracks +def _preprocess_hres_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: + """An example preprocess function that renames the time coordinate to lead_time, + creates a valid_time coordinate, and sets the lead time range and resolution not + present in the original dataset. + + Args: + ds: The forecast dataset to rename. + """ + ds["geopotential_thickness"] = ewb.calc.geopotential_thickness( + ds["geopotential"], + top_level_value=300, + bottom_level_value=500, + geopotential=True, + ) + return ds + + +# Load the case collection from the YAML file +case_yaml = ewb.load_cases() # Select single case (TC Ida) -case_yaml = [n for n in case_yaml if n.case_id_number == 220] +case_list = [n for n in case_yaml if n.case_id_number == 220] # Define IBTrACS target, no arguments needed as defaults are sufficient -ibtracs_target = inputs.IBTrACS() +ibtracs_target = ewb.targets.IBTrACS() # Define HRES forecast -hres_forecast = inputs.ZarrForecast( +hres_forecast = ewb.forecasts.ZarrForecast( name="hres_forecast", source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", # Define tropical cyclone track derivedvariable to include in the forecast - variables=[derived.TropicalCycloneTrackVariables()], + variables=[ewb.derived.TropicalCycloneTrackVariables()], # Define metadata variable mapping for HRES forecast - variable_mapping=inputs.HRES_metadata_variable_mapping, + 
variable_mapping=ewb.HRES_metadata_variable_mapping, storage_options={"remote_options": {"anon": True}}, # Preprocess the HRES forecast to include geopotential thickness calculation - preprocess=defaults._preprocess_hres_tc_forecast_dataset, + preprocess=ewb.defaults._preprocess_hres_tc_forecast_dataset, ) -# Define FCNv2 forecast -fcnv2_forecast = inputs.KerchunkForecast( +# Define FCNv2 forecast, this is the old version for reference only +fcnv2_forecast = ewb.forecasts.KerchunkForecast( name="fcn_forecast", source="gs://extremeweatherbench/FOUR_v200_GFS.parq", - variables=[derived.TropicalCycloneTrackVariables()], + variables=[ewb.derived.TropicalCycloneTrackVariables()], # Define metadata variable mapping for FCNv2 forecast - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, # Preprocess the FCNv2 forecast to include geopotential thickness calculation - preprocess=defaults._preprocess_cira_tc_forecast_dataset, + preprocess=ewb.defaults._preprocess_cira_tc_forecast_dataset, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, ) # Define Pangu forecast -pangu_forecast = inputs.KerchunkForecast( +pangu_forecast = ewb.forecasts.KerchunkForecast( name="pangu_forecast", source="gs://extremeweatherbench/PANG_v100_GFS.parq", - variables=[derived.TropicalCycloneTrackVariables()], + variables=[ewb.derived.TropicalCycloneTrackVariables()], # Define metadata variable mapping for Pangu forecast - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, # Preprocess the Pangu forecast to include geopotential thickness calculation # which uses the same preprocessing function as the FCNv2 forecast - preprocess=defaults._preprocess_cira_tc_forecast_dataset, + preprocess=ewb.defaults._preprocess_cira_tc_forecast_dataset, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, ) @@ -60,11 +106,11 @@ # the evaluation to 
occur, in the case of multiple landfalls, for the next landfall in # time to be evaluated against composite_landfall_metrics = [ - metrics.LandfallMetric( + ewb.metrics.LandfallMetric( metrics=[ - metrics.LandfallIntensityMeanAbsoluteError, - metrics.LandfallTimeMeanError, - metrics.LandfallDisplacement, + ewb.metrics.LandfallIntensityMeanAbsoluteError, + ewb.metrics.LandfallTimeMeanError, + ewb.metrics.LandfallDisplacement, ], approach="next", # Set the intensity variable to use for the metric @@ -77,21 +123,21 @@ # the relevant cases inside the events YAML file tc_evaluation_object = [ # HRES forecast - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="tropical_cyclone", metric_list=composite_landfall_metrics, target=ibtracs_target, forecast=hres_forecast, ), # Pangu forecast - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="tropical_cyclone", metric_list=composite_landfall_metrics, target=ibtracs_target, forecast=pangu_forecast, ), # FCNv2 forecast - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="tropical_cyclone", metric_list=composite_landfall_metrics, target=ibtracs_target, @@ -101,13 +147,13 @@ if __name__ == "__main__": # Initialize ExtremeWeatherBench - ewb = evaluate.ExtremeWeatherBench( - case_metadata=case_yaml, + tc_ewb = ewb.evaluation( + case_metadata=case_list, evaluation_objects=tc_evaluation_object, ) logger.info("Starting EWB run") # Run the workflow with parallel_config backend set to dask - outputs = ewb.run( + outputs = tc_ewb.run( parallel_config={"backend": "loky", "n_jobs": 3}, ) outputs.to_csv("tc_metric_test_results.csv") diff --git a/docs/examples/example_config.py b/docs/examples/example_config.py index 3aff5967..f41d5e44 100644 --- a/docs/examples/example_config.py +++ b/docs/examples/example_config.py @@ -7,31 +7,29 @@ ewb --config-file example_config.py """ -from extremeweatherbench import cases, inputs, metrics +import extremeweatherbench as ewb # Define targets (observation data) 
-era5_heatwave_target = inputs.ERA5( +era5_heatwave_target = ewb.targets.ERA5( variables=["surface_air_temperature"], chunks=None, ) # Define forecasts -fcnv2_forecast = inputs.KerchunkForecast( +fcnv2_forecast = ewb.forecasts.KerchunkForecast( name="fcnv2_forecast", source="gs://extremeweatherbench/FOUR_v200_GFS.parq", variables=["surface_air_temperature"], - variable_mapping=inputs.CIRA_metadata_variable_mapping, + variable_mapping=ewb.CIRA_metadata_variable_mapping, ) # Define evaluation objects evaluation_objects = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="heat_wave", metric_list=[ - metrics.MaximumMeanAbsoluteError(), - metrics.RootMeanSquaredError(), - metrics.OnsetMeanError(), - metrics.DurationMeanError(), + ewb.metrics.MaximumMeanAbsoluteError(), + ewb.metrics.RootMeanSquaredError(), ], target=era5_heatwave_target, forecast=fcnv2_forecast, @@ -39,8 +37,9 @@ ] # Load case data from the default events.yaml -# Users can also define their own cases -cases_list = cases.load_ewb_events_yaml_into_case_list() +# Users can also define their own cases_dict structure +cases_list = ewb.load_cases() + # Alternatively, users could define custom cases like this: # cases_list = [ # { diff --git a/docs/parallelism.md b/docs/parallelism.md index 98ceb4be..fae25fea 100644 --- a/docs/parallelism.md +++ b/docs/parallelism.md @@ -35,7 +35,7 @@ ewb = evaluate.ExtremeWeatherBench( # The larger the machine, the larger n_jobs can be (a bit of an oversimplification) parallel_config = {"backend":"loky","n_jobs":len(evaluation_objects)} -outputs = ewb.run(parallel_config=parallel_config) +outputs = ewb.run_evaluation(parallel_config=parallel_config) ``` The _safest_ approach is to run EWB in serial, with `n_jobs` set to 1. `Dask` will still be invoked during each `CaseOperator` when the case executes and computes the directed acyclic graph, only one at a time. That said, for evaluations with more cases this approach would likely be too time-consuming. 
\ No newline at end of file diff --git a/docs/usage.md b/docs/usage.md index d5e37778..ddd372fa 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -10,16 +10,17 @@ tropical cyclones, and atmospheric rivers: ```python -from extremeweatherbench import evaluate, defaults, cases +import extremeweatherbench as ewb -eval_objects = defaults.get_brightband_evaluation_objects() +eval_objects = ewb.get_brightband_evaluation_objects() +cases = ewb.load_cases() -cases = cases.load_ewb_events_yaml_into_case_list() -ewb = ExtremeWeatherBench(cases=cases, -evaluation_objects=eval_objects) - -outputs = ewb.run() +runner = ewb.evaluation( + case_metadata=cases, + evaluation_objects=eval_objects +) +outputs = runner.run() outputs.to_csv('your_outputs.csv') ``` @@ -28,6 +29,30 @@ or: ```bash ewb --default ``` + +## API Overview + +ExtremeWeatherBench provides a hierarchical API for accessing its components: + +```python +import extremeweatherbench as ewb + +# Main evaluation entry point +ewb.evaluation(...) # Alias for ExtremeWeatherBench class + +# Hierarchical access via namespaces +ewb.targets.ERA5(...) # Target classes +ewb.forecasts.ZarrForecast(...) # Forecast classes +ewb.metrics.MeanAbsoluteError() # Metric classes +ewb.derived.AtmosphericRiverVariables() # Derived variables +ewb.regions.BoundingBoxRegion(...) # Region classes +ewb.cases.IndividualCase # Case metadata classes + +# Also available at top level for convenience +ewb.ERA5(...) +ewb.ZarrForecast(...) +ewb.load_cases() +``` ## Running an Evaluation for a Single Event Type ExtremeWeatherBench has default event types and cases for heat waves, freezes, severe convection, tropical cyclones, and atmospheric rivers. @@ -39,20 +64,20 @@ ExtremeWeatherBench requires forecasts to have `init_time`, `lead_time`, `latitu Targets require at least a `valid_time` with at least one spatial dimension. Examples include `location`, `station`, or (`latitude`, `longitude`). 
Forecasts are aligned to targets during the steps immediately prior to evaluating a metric. ```python -from extremeweatherbench import inputs +import extremeweatherbench as ewb ``` There are three built-in `ForecastBase` classes to set up a forecast: `ZarrForecast`, `XarrayForecast`, and `KerchunkForecast`. Here is an example of a `ZarrForecast`, using Weatherbench2's HRES zarr store: ```python -hres_forecast = inputs.ZarrForecast( +hres_forecast = ewb.forecasts.ZarrForecast( source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", name="HRES", variables=["surface_air_temperature"], - variable_mapping=inputs.HRES_metadata_variable_mapping, # built-in mapping available + variable_mapping=ewb.HRES_metadata_variable_mapping, # built-in mapping available storage_options={"remote_options": {"anon": True}}, - ) ``` + There are required arguments, namely: - `source` @@ -67,8 +92,8 @@ There are required arguments, namely: Next, a target dataset must be defined as well to evaluate against. For this evaluation, we'll use ERA5: ```python -era5_heatwave_target = inputs.ERA5( - source=inputs.ARCO_ERA5_FULL_URI, +era5_heatwave_target = ewb.targets.ERA5( + source=ewb.ARCO_ERA5_FULL_URI, variables=["surface_air_temperature"], storage_options={"remote_options": {"anon": True}}, chunks=None, @@ -87,48 +112,53 @@ Or (if defining variables as arguments to the metrics): era5_heatwave_target = inputs.ERA5() ``` -> **Detailed Explanation**: Similarly to forecasts, we need to define the `source`, which here is the ARCO ERA5 provided by Google. `variables` are used to subset `inputs.ERA5` in an evaluation; `variable_mapping` defaults to `inputs.ERA5_metadata_variable_mapping` for many existing variables and likely is not required to be set unless your use case is for less common variables. 
Both forecasts and targets, if relevant, have an optional `chunks` parameter which defaults to what should be the most efficient value - usually `None` or `'auto'`, but can be changed as seen above. *If using the ARCO ERA5 and setting `chunks=None`, it is critical to order your subsetting by variables -> time -> `.sel` or `.isel` latitude & longitude -> rechunk. [See this Github comment](https://github.com/pydata/xarray/issues/8902#issuecomment-2036435045). +> **Detailed Explanation**: Similarly to forecasts, we need to define the `source`, which here is the ARCO ERA5 provided by Google. `variables` are used to subset `ewb.inputs.ERA5` in an evaluation; `variable_mapping` defaults to `ewb.inputs.ERA5_metadata_variable_mapping` for many existing variables and likely is not required to be set unless your use case is for less common variables. Both forecasts and targets, if relevant, have an optional `chunks` parameter which defaults to what should be the most efficient value - usually `None` or `'auto'`, but can be changed as seen above. *If using the ARCO ERA5 and setting `chunks=None`, it is critical to order your subsetting by variables -> time -> `.sel` or `.isel` latitude & longitude -> rechunk. [See this Github comment](https://github.com/pydata/xarray/issues/8902#issuecomment-2036435045). We then set up an `EvaluationObject` list: ```python -from extremeweatherbench import metrics - heatwave_evaluation_list = [ - inputs.EvaluationObject( + ewb.EvaluationObject( event_type="heat_wave", metric_list=[ - metrics.MaximumMeanAbsoluteError(), - metrics.RootMeanSquaredError(), - metrics.MaximumLowestMeanAbsoluteError() + ewb.metrics.MaximumMeanAbsoluteError(), + ewb.metrics.RootMeanSquaredError(), + ewb.metrics.MaximumLowestMeanAbsoluteError() ], target=era5_heatwave_target, forecast=hres_forecast, ), ] ``` + Which includes the event_type of interest (as defined in the case dictionary or YAML file used), the list of metrics to run, one target, and one forecast. 
There can be multiple `EvaluationObjects` which are used for an evaluation run. Plugging these all in: ```python -from extremeweatherbench import cases, evaluate -case_list = cases.load_ewb_events_yaml_into_case_list() - +case_yaml = ewb.load_cases() -ewb_instance = evaluate.ExtremeWeatherBench( - cases=case_list, +ewb_instance = ewb.evaluation( + case_metadata=case_yaml, evaluation_objects=heatwave_evaluation_list, ) outputs = ewb_instance.run() - outputs.to_csv('your_file_name.csv') ``` -Where the EWB default events YAML file is loaded in using a built-in utility helper function, then applied to an instance of `evaluate.ExtremeWeatherBench` along with the `EvaluationObject` list. Finally, we trigger the evaluation with the `.run()` method, where defaults are typically sufficient to run with a small to moderate-sized virtual machine. after subsetting and prior to metric calculation. +Where the EWB default events YAML file is loaded in using `ewb.load_cases()`, then applied to an instance of `ewb.evaluation` along with the `EvaluationObject` list. Finally, we run the evaluation with the `.run()` method, where defaults are typically sufficient to run with a small to moderate-sized virtual machine. Running locally is feasible but is typically bottlenecked heavily by IO and network bandwidth. Even on a gigabit connection, the rate of data access is significantly slower compared to within a cloud provider VM. -The outputs are returned as a pandas DataFrame and can be manipulated in the script, a notebook, or post-hoc after saving it. +The outputs are returned as a pandas DataFrame and can be manipulated in the script, a notebook, etc. 
+ +## Backward Compatibility + +All existing import patterns remain functional: + +```python +from extremeweatherbench import evaluate, inputs, cases, metrics # Still works +from extremeweatherbench.evaluate import ExtremeWeatherBench # Still works +``` diff --git a/scripts/brightband_evaluation.py b/scripts/brightband_evaluation.py index 4086ef51..1dc8103b 100644 --- a/scripts/brightband_evaluation.py +++ b/scripts/brightband_evaluation.py @@ -39,5 +39,5 @@ def configure_logger(level=logging.INFO): # Set up parallel configuration parallel_config = {"backend": "loky", "n_jobs": n_processes} - results = ewb.run(parallel_config=parallel_config) + results = ewb.run_evaluation(parallel_config=parallel_config) results.to_csv("brightband_evaluation_results.csv", index=False) diff --git a/src/extremeweatherbench/__init__.py b/src/extremeweatherbench/__init__.py index e69de29b..af479c2f 100644 --- a/src/extremeweatherbench/__init__.py +++ b/src/extremeweatherbench/__init__.py @@ -0,0 +1,339 @@ +"""ExtremeWeatherBench: A benchmarking framework for extreme weather forecasts. + +This module provides the public API for ExtremeWeatherBench. Users can import +the package and access all key functionality: + + import extremeweatherbench as ewb + + # Main entry point for evaluation + ewb.evaluation(case_metadata=..., evaluation_objects=...) + + # Hierarchical access via namespace submodules + ewb.targets.ERA5(...) + ewb.forecasts.ZarrForecast(...) + ewb.metrics.MeanAbsoluteError(...) + + # Also available at top level + ewb.ERA5(...) 
+ ewb.load_cases() +""" + +from types import SimpleNamespace + +# Import actual modules for backwards compatibility +from extremeweatherbench import calc, cases, defaults, derived, metrics, regions, utils + +# Import specific items for top-level access +from extremeweatherbench.calc import ( + convert_from_cartesian_to_latlon, + geopotential_thickness, + great_circle_mask, + haversine_distance, + maybe_calculate_wind_speed, + mixing_ratio, + orography, + pressure_at_surface, + saturation_mixing_ratio, + saturation_vapor_pressure, + specific_humidity_from_relative_humidity, +) +from extremeweatherbench.cases import ( + CaseOperator, + IndividualCase, + build_case_operators, + load_ewb_events_yaml_into_case_list, + load_individual_cases, + load_individual_cases_from_yaml, + read_incoming_yaml, +) +from extremeweatherbench.defaults import ( + DEFAULT_COORDINATE_VARIABLES, + DEFAULT_VARIABLE_NAMES, + cira_fcnv2_atmospheric_river_forecast, + cira_fcnv2_freeze_forecast, + cira_fcnv2_heatwave_forecast, + cira_fcnv2_severe_convection_forecast, + cira_fcnv2_tropical_cyclone_forecast, + era5_atmospheric_river_target, + era5_freeze_target, + era5_heatwave_target, + get_brightband_evaluation_objects, + get_climatology, + ghcn_freeze_target, + ghcn_heatwave_target, + ibtracs_target, + lsr_target, + pph_target, +) +from extremeweatherbench.derived import ( + AtmosphericRiverVariables, + CravenBrooksSignificantSevere, + DerivedVariable, + TropicalCycloneTrackVariables, + maybe_derive_variables, + maybe_include_variables_from_derived_input, +) +from extremeweatherbench.evaluate import ExtremeWeatherBench +from extremeweatherbench.inputs import ( + ARCO_ERA5_FULL_URI, + DEFAULT_GHCN_URI, + ERA5, + GHCN, + IBTRACS_URI, + LSR, + LSR_URI, + PPH, + PPH_URI, + CIRA_metadata_variable_mapping, + ERA5_metadata_variable_mapping, + EvaluationObject, + ForecastBase, + HRES_metadata_variable_mapping, + IBTrACS, + IBTrACS_metadata_variable_mapping, + InputBase, + KerchunkForecast, + TargetBase, 
+ XarrayForecast, + ZarrForecast, + align_forecast_to_target, + check_for_missing_data, + maybe_subset_variables, + open_kerchunk_reference, + zarr_target_subsetter, +) +from extremeweatherbench.metrics import ( + Accuracy, + BaseMetric, + CompositeMetric, + CriticalSuccessIndex, + DurationMeanError, + EarlySignal, + FalseAlarmRatio, + FalseNegatives, + FalsePositives, + LandfallDisplacement, + LandfallIntensityMeanAbsoluteError, + LandfallMetric, + LandfallTimeMeanError, + MaximumLowestMeanAbsoluteError, + MaximumMeanAbsoluteError, + MeanAbsoluteError, + MeanError, + MeanSquaredError, + MinimumMeanAbsoluteError, + RootMeanSquaredError, + SpatialDisplacement, + ThresholdMetric, + TrueNegatives, + TruePositives, +) +from extremeweatherbench.regions import ( + REGION_TYPES, + BoundingBoxRegion, + CenteredRegion, + Region, + RegionSubsetter, + ShapefileRegion, + map_to_create_region, + subset_cases_to_region, + subset_results_to_region, +) +from extremeweatherbench.utils import ( + check_for_vars, + convert_day_yearofday_to_time, + convert_init_time_to_valid_time, + convert_longitude_to_180, + convert_longitude_to_360, + convert_valid_time_to_init_time, + derive_indices_from_init_time_and_lead_time, + determine_temporal_resolution, + extract_tc_names, + filter_kwargs_for_callable, + find_common_init_times, + idx_to_coords, + interp_climatology_to_target, + is_valid_landfall, + load_land_geometry, + maybe_cache_and_compute, + maybe_densify_dataarray, + maybe_get_closest_timestamp_to_center_of_valid_times, + maybe_get_operator, + min_if_all_timesteps_present, + min_if_all_timesteps_present_forecast, + read_event_yaml, + remove_ocean_gridpoints, + stack_dataarray_from_dims, +) + +# Aliases +evaluation = ExtremeWeatherBench +load_cases = load_ewb_events_yaml_into_case_list + +# Namespace submodules for convenient grouping (these don't shadow actual modules) +targets = SimpleNamespace( + TargetBase=TargetBase, + ERA5=ERA5, + GHCN=GHCN, + IBTrACS=IBTrACS, + LSR=LSR, + 
PPH=PPH, +) + +forecasts = SimpleNamespace( + ForecastBase=ForecastBase, + ZarrForecast=ZarrForecast, + KerchunkForecast=KerchunkForecast, + XarrayForecast=XarrayForecast, +) + +__all__ = [ + # Core evaluation + "evaluation", + "ExtremeWeatherBench", + # Modules + "calc", + "cases", + "defaults", + "derived", + "metrics", + "regions", + "utils", + # Namespace submodules + "targets", + "forecasts", + # Aliases + "load_cases", + # calc + "convert_from_cartesian_to_latlon", + "geopotential_thickness", + "great_circle_mask", + "haversine_distance", + "maybe_calculate_wind_speed", + "mixing_ratio", + "orography", + "pressure_at_surface", + "saturation_mixing_ratio", + "saturation_vapor_pressure", + "specific_humidity_from_relative_humidity", + # cases + "CaseOperator", + "IndividualCase", + "build_case_operators", + "load_ewb_events_yaml_into_case_list", + "load_individual_cases", + "load_individual_cases_from_yaml", + "read_incoming_yaml", + # defaults + "DEFAULT_COORDINATE_VARIABLES", + "DEFAULT_VARIABLE_NAMES", + "cira_fcnv2_atmospheric_river_forecast", + "cira_fcnv2_freeze_forecast", + "cira_fcnv2_heatwave_forecast", + "cira_fcnv2_severe_convection_forecast", + "cira_fcnv2_tropical_cyclone_forecast", + "era5_atmospheric_river_target", + "era5_freeze_target", + "era5_heatwave_target", + "get_brightband_evaluation_objects", + "get_climatology", + "ghcn_freeze_target", + "ghcn_heatwave_target", + "ibtracs_target", + "lsr_target", + "pph_target", + # derived + "AtmosphericRiverVariables", + "CravenBrooksSignificantSevere", + "DerivedVariable", + "TropicalCycloneTrackVariables", + "maybe_derive_variables", + "maybe_include_variables_from_derived_input", + # inputs + "ARCO_ERA5_FULL_URI", + "CIRA_metadata_variable_mapping", + "DEFAULT_GHCN_URI", + "ERA5", + "ERA5_metadata_variable_mapping", + "EvaluationObject", + "ForecastBase", + "GHCN", + "HRES_metadata_variable_mapping", + "IBTrACS", + "IBTrACS_metadata_variable_mapping", + "IBTRACS_URI", + "InputBase", + 
"KerchunkForecast", + "LSR", + "LSR_URI", + "PPH", + "PPH_URI", + "TargetBase", + "XarrayForecast", + "ZarrForecast", + "align_forecast_to_target", + "check_for_missing_data", + "maybe_subset_variables", + "open_kerchunk_reference", + "zarr_target_subsetter", + # metrics + "Accuracy", + "BaseMetric", + "CompositeMetric", + "CriticalSuccessIndex", + "DurationMeanError", + "EarlySignal", + "FalseAlarmRatio", + "FalseNegatives", + "FalsePositives", + "LandfallDisplacement", + "LandfallIntensityMeanAbsoluteError", + "LandfallMetric", + "LandfallTimeMeanError", + "MaximumLowestMeanAbsoluteError", + "MaximumMeanAbsoluteError", + "MeanAbsoluteError", + "MeanError", + "MeanSquaredError", + "MinimumMeanAbsoluteError", + "RootMeanSquaredError", + "SpatialDisplacement", + "ThresholdMetric", + "TrueNegatives", + "TruePositives", + # regions + "BoundingBoxRegion", + "CenteredRegion", + "REGION_TYPES", + "Region", + "RegionSubsetter", + "ShapefileRegion", + "map_to_create_region", + "subset_cases_to_region", + "subset_results_to_region", + # utils + "check_for_vars", + "convert_day_yearofday_to_time", + "convert_init_time_to_valid_time", + "convert_longitude_to_180", + "convert_longitude_to_360", + "convert_valid_time_to_init_time", + "derive_indices_from_init_time_and_lead_time", + "determine_temporal_resolution", + "extract_tc_names", + "filter_kwargs_for_callable", + "find_common_init_times", + "idx_to_coords", + "interp_climatology_to_target", + "is_valid_landfall", + "load_land_geometry", + "maybe_cache_and_compute", + "maybe_densify_dataarray", + "maybe_get_closest_timestamp_to_center_of_valid_times", + "maybe_get_operator", + "min_if_all_timesteps_present", + "min_if_all_timesteps_present_forecast", + "read_event_yaml", + "remove_ocean_gridpoints", + "stack_dataarray_from_dims", +] diff --git a/src/extremeweatherbench/cases.py b/src/extremeweatherbench/cases.py index 3f2e858c..236104c9 100644 --- a/src/extremeweatherbench/cases.py +++ b/src/extremeweatherbench/cases.py @@ 
-14,10 +14,11 @@ import dacite import yaml # type: ignore[import] -from extremeweatherbench import regions +import extremeweatherbench.regions as regions if TYPE_CHECKING: - from extremeweatherbench import inputs, metrics + import extremeweatherbench.inputs as inputs + import extremeweatherbench.metrics as metrics logger = logging.getLogger(__name__) diff --git a/src/extremeweatherbench/defaults.py b/src/extremeweatherbench/defaults.py index 7dcc68d6..41d20e16 100644 --- a/src/extremeweatherbench/defaults.py +++ b/src/extremeweatherbench/defaults.py @@ -309,7 +309,7 @@ def get_brightband_evaluation_objects() -> list[inputs.EvaluationObject]: routine. """ # Import metrics here to avoid circular import - from extremeweatherbench import metrics + import extremeweatherbench.metrics as metrics heatwave_metric_list: list[metrics.BaseMetric] = [ metrics.MaximumMeanAbsoluteError(), diff --git a/src/extremeweatherbench/derived.py b/src/extremeweatherbench/derived.py index 0e609fbe..7517d232 100644 --- a/src/extremeweatherbench/derived.py +++ b/src/extremeweatherbench/derived.py @@ -10,7 +10,7 @@ from extremeweatherbench.events import tropical_cyclone if TYPE_CHECKING: - from extremeweatherbench import cases + import extremeweatherbench.cases as cases logger = logging.getLogger(__name__) diff --git a/src/extremeweatherbench/evaluate.py b/src/extremeweatherbench/evaluate.py index 64c0ba96..504d7694 100644 --- a/src/extremeweatherbench/evaluate.py +++ b/src/extremeweatherbench/evaluate.py @@ -15,10 +15,15 @@ from tqdm.contrib.logging import logging_redirect_tqdm from tqdm.dask import TqdmCallback -from extremeweatherbench import cases, derived, inputs, metrics, sources, utils +import extremeweatherbench.cases as cases +import extremeweatherbench.derived as derived +import extremeweatherbench.inputs as inputs +import extremeweatherbench.metrics as metrics +import extremeweatherbench.sources as sources +import extremeweatherbench.utils as utils if TYPE_CHECKING: - from 
extremeweatherbench import regions + import extremeweatherbench.regions as regions logger = logging.getLogger(__name__) @@ -101,11 +106,11 @@ def run( parallel_config: Optional[dict] = None, **kwargs, ) -> pd.DataFrame: - """Runs the ExtremeWeatherBench workflow. + """Runs the ExtremeWeatherBench evaluation workflow. - This method will run the workflow in the order of the case operators, optionally - caching the mid-flight outputs of the workflow if cache_dir was provided for - serial runs. + This method will run the evaluation workflow in the order of the case operators, + optionally caching the mid-flight outputs of the workflow if cache_dir was + provided for serial runs. Args: n_jobs: The number of jobs to run in parallel. If None, defaults to the @@ -113,16 +118,60 @@ def run( Ignored if parallel_config is provided. parallel_config: Optional dictionary of joblib parallel configuration. If provided, this takes precedence over n_jobs. If not provided and - n_jobs is specified, a default config with loky backend is used. + n_jobs is specified, a default config with the loky backend is used. + **kwargs: Additional arguments to pass to compute_case_operator. + Returns: + A concatenated dataframe of the evaluation results. + """ + logger.warning("The run method is deprecated. 
Use run_evaluation instead.") + logger.info("Running ExtremeWeatherBench evaluations...") + + # Check for serial or parallel configuration + parallel_config = _parallel_serial_config_check(n_jobs, parallel_config) + + run_results = _run_evaluation( + self.case_operators, + cache_dir=self.cache_dir, + parallel_config=parallel_config, + **kwargs, + ) + + # If there are results, concatenate them and return, else return an empty + # DataFrame with the expected columns + if run_results: + return _safe_concat(run_results, ignore_index=True) + else: + return pd.DataFrame(columns=OUTPUT_COLUMNS) + + def run_evaluation( + self, + n_jobs: Optional[int] = None, + parallel_config: Optional[dict] = None, + **kwargs, + ) -> pd.DataFrame: + """Runs the ExtremeWeatherBench evaluation workflow. + + This method will run the evaluation workflow in the order of the case operators, + optionally caching the mid-flight outputs of the workflow if cache_dir was + provided for serial runs. + Args: + n_jobs: The number of jobs to run in parallel. If None, defaults to the + joblib backend default value. If 1, the workflow will run serially. + Ignored if parallel_config is provided. + parallel_config: Optional dictionary of joblib parallel configuration. + If provided, this takes precedence over n_jobs. If not provided and + n_jobs is specified, a default config with the loky backend is used. + **kwargs: Additional arguments to pass to compute_case_operator. Returns: A concatenated dataframe of the evaluation results. 
""" - logger.info("Running ExtremeWeatherBench workflow...") + logger.info("Running ExtremeWeatherBench evaluations...") # Check for serial or parallel configuration parallel_config = _parallel_serial_config_check(n_jobs, parallel_config) - run_results = _run_case_operators( + + run_results = _run_evaluation( self.case_operators, cache_dir=self.cache_dir, parallel_config=parallel_config, @@ -137,7 +186,48 @@ def run( return pd.DataFrame(columns=OUTPUT_COLUMNS) -def _run_case_operators( +def _parallel_serial_config_check( + n_jobs: Optional[int] = None, + parallel_config: Optional[dict] = None, +) -> Optional[dict]: + """Check if running in serial or parallel mode. + + Args: + n_jobs: The number of jobs to run in parallel. If None, defaults to the + joblib backend default value. If 1, the workflow will run serially. + parallel_config: Optional dictionary of joblib parallel configuration. If + provided, this takes precedence over n_jobs. If not provided and n_jobs is + specified, a default config with loky backend is used. + Returns: + None if running in serial mode, otherwise a dictionary of joblib parallel + configuration. 
+ """ + # Determine if running in serial or parallel mode + # Serial: n_jobs=1 or (parallel_config with n_jobs=1) + # Parallel: n_jobs>1 or (parallel_config with n_jobs>1) + is_serial = ( + (n_jobs == 1) + or (parallel_config is not None and parallel_config.get("n_jobs") == 1) + or (n_jobs is None and parallel_config is None) + ) + logger.debug("Running in %s mode.", "serial" if is_serial else "parallel") + + if not is_serial: + # Build parallel_config if not provided + if parallel_config is None and n_jobs is not None: + logger.debug( + "No parallel_config provided, using loky backend and %s jobs.", + n_jobs, + ) + parallel_config = {"backend": "loky", "n_jobs": n_jobs} + # If running in serial mode, set parallel_config to None if not already + else: + parallel_config = None + # Return the maybe updated kwargs + return parallel_config + + +def _run_evaluation( case_operators: list["cases.CaseOperator"], cache_dir: Optional[pathlib.Path] = None, parallel_config: Optional[dict] = None, @@ -154,37 +244,29 @@ def _run_case_operators( Returns: List of result DataFrames. 
""" - with logging_redirect_tqdm(): - # Run in parallel if parallel_config exists and n_jobs != 1 - if parallel_config is not None: + if parallel_config is not None: + with logging_redirect_tqdm(): logger.info("Running case operators in parallel...") - return _run_parallel( + run_results = _run_parallel_evaluation( case_operators, cache_dir=cache_dir, parallel_config=parallel_config, **kwargs, ) - else: - logger.info("Running case operators in serial...") - return _run_serial(case_operators, cache_dir=cache_dir, **kwargs) - - -def _run_serial( - case_operators: list["cases.CaseOperator"], - cache_dir: Optional[pathlib.Path] = None, - **kwargs, -) -> list[pd.DataFrame]: - """Run the case operators in serial.""" - run_results = [] + else: + logger.info("Running case operators in serial...") + run_results = [] + for case_operator in tqdm(case_operators): + run_results.append( + compute_case_operator(case_operator, cache_dir, **kwargs) + ) - # Loop over the case operators - for case_operator in tqdm(case_operators): - run_results.append(compute_case_operator(case_operator, cache_dir, **kwargs)) return run_results -def _run_parallel( +def _run_parallel_evaluation( case_operators: list["cases.CaseOperator"], + parallel_config: dict, cache_dir: Optional[pathlib.Path] = None, **kwargs, ) -> list[pd.DataFrame]: @@ -197,11 +279,6 @@ def _run_parallel( Returns: List of result DataFrames. 
""" - parallel_config = kwargs.pop("parallel_config", None) - - if parallel_config is None: - raise ValueError("parallel_config must be provided to _run_parallel") - if parallel_config.get("n_jobs") is None: logger.warning("No number of jobs provided, using joblib backend default.") @@ -900,44 +977,3 @@ def _safe_concat( return pd.concat(valid_dfs, ignore_index=ignore_index) else: return pd.DataFrame(columns=OUTPUT_COLUMNS) - - -def _parallel_serial_config_check( - n_jobs: Optional[int] = None, - parallel_config: Optional[dict] = None, -) -> Optional[dict]: - """Check if running in serial or parallel mode. - - Args: - n_jobs: The number of jobs to run in parallel. If None, defaults to the - joblib backend default value. If 1, the workflow will run serially. - parallel_config: Optional dictionary of joblib parallel configuration. If - provided, this takes precedence over n_jobs. If not provided and n_jobs is - specified, a default config with loky backend is used. - Returns: - None if running in serial mode, otherwise a dictionary of joblib parallel - configuration. 
- """ - # Determine if running in serial or parallel mode - # Serial: n_jobs=1 or (parallel_config with n_jobs=1) - # Parallel: n_jobs>1 or (parallel_config with n_jobs>1) - is_serial = ( - (n_jobs == 1) - or (parallel_config is not None and parallel_config.get("n_jobs") == 1) - or (n_jobs is None and parallel_config is None) - ) - logger.debug("Running in %s mode.", "serial" if is_serial else "parallel") - - if not is_serial: - # Build parallel_config if not provided - if parallel_config is None and n_jobs is not None: - logger.debug( - "No parallel_config provided, using loky backend and %s jobs.", - n_jobs, - ) - parallel_config = {"backend": "loky", "n_jobs": n_jobs} - # If running in serial mode, set parallel_config to None if not already - else: - parallel_config = None - # Return the maybe updated kwargs - return parallel_config diff --git a/src/extremeweatherbench/evaluate_cli.py b/src/extremeweatherbench/evaluate_cli.py index 9b978269..1a96dfc1 100644 --- a/src/extremeweatherbench/evaluate_cli.py +++ b/src/extremeweatherbench/evaluate_cli.py @@ -7,7 +7,9 @@ import click import pandas as pd -from extremeweatherbench import cases, defaults, evaluate +import extremeweatherbench.cases as cases +import extremeweatherbench.defaults as defaults +import extremeweatherbench.evaluate as evaluate @click.command() @@ -152,7 +154,7 @@ def cli_runner( # Run evaluation click.echo("Running evaluation...") - results = ewb.run( + results = ewb.run_evaluation( n_jobs=n_jobs, parallel_config=parallel_config, ) diff --git a/src/extremeweatherbench/inputs.py b/src/extremeweatherbench/inputs.py index 8708dde9..37aba5a5 100644 --- a/src/extremeweatherbench/inputs.py +++ b/src/extremeweatherbench/inputs.py @@ -19,10 +19,13 @@ import polars as pl import xarray as xr -from extremeweatherbench import cases, derived, sources, utils +import extremeweatherbench.cases as cases +import extremeweatherbench.derived as derived +import extremeweatherbench.sources as sources +import 
extremeweatherbench.utils as utils if TYPE_CHECKING: - from extremeweatherbench import metrics + import extremeweatherbench.metrics as metrics logger = logging.getLogger(__name__) diff --git a/src/extremeweatherbench/regions.py b/src/extremeweatherbench/regions.py index 5a36dc0e..4e6c7730 100644 --- a/src/extremeweatherbench/regions.py +++ b/src/extremeweatherbench/regions.py @@ -16,7 +16,7 @@ from extremeweatherbench import utils if TYPE_CHECKING: - from extremeweatherbench import cases + import extremeweatherbench.cases as cases logger = logging.getLogger(__name__) diff --git a/src/extremeweatherbench/sources/base.py b/src/extremeweatherbench/sources/base.py index dd58641c..e7dbda6e 100644 --- a/src/extremeweatherbench/sources/base.py +++ b/src/extremeweatherbench/sources/base.py @@ -1,7 +1,7 @@ import datetime from typing import Any, Protocol, runtime_checkable -from extremeweatherbench import regions +import extremeweatherbench.regions as regions @runtime_checkable diff --git a/src/extremeweatherbench/sources/pandas_dataframe.py b/src/extremeweatherbench/sources/pandas_dataframe.py index 31bc4062..b6eb91a9 100644 --- a/src/extremeweatherbench/sources/pandas_dataframe.py +++ b/src/extremeweatherbench/sources/pandas_dataframe.py @@ -8,7 +8,7 @@ from extremeweatherbench import utils if TYPE_CHECKING: - from extremeweatherbench import regions + import extremeweatherbench.regions as regions def safely_pull_variables( @@ -43,7 +43,7 @@ def safely_pull_variables( >>> list(result.columns) ['temp'] """ - from extremeweatherbench import defaults + import extremeweatherbench.defaults as defaults # Get column names from DataFrame available_columns = list(data.columns) diff --git a/src/extremeweatherbench/sources/polars_lazyframe.py b/src/extremeweatherbench/sources/polars_lazyframe.py index e9e56cf4..f0caa41e 100644 --- a/src/extremeweatherbench/sources/polars_lazyframe.py +++ b/src/extremeweatherbench/sources/polars_lazyframe.py @@ -8,7 +8,7 @@ from extremeweatherbench 
import utils if TYPE_CHECKING: - from extremeweatherbench import regions + import extremeweatherbench.regions as regions def safely_pull_variables( @@ -47,7 +47,7 @@ def safely_pull_variables( >>> result.collect().columns ['temp'] """ - from extremeweatherbench import defaults + import extremeweatherbench.defaults as defaults # Get column names from LazyFrame available_columns = data.collect_schema().names() diff --git a/src/extremeweatherbench/sources/xarray_dataarray.py b/src/extremeweatherbench/sources/xarray_dataarray.py index e58d82d6..f3b4e734 100644 --- a/src/extremeweatherbench/sources/xarray_dataarray.py +++ b/src/extremeweatherbench/sources/xarray_dataarray.py @@ -5,7 +5,8 @@ import pandas as pd import xarray as xr -from extremeweatherbench import regions, utils +import extremeweatherbench.regions as regions +import extremeweatherbench.utils as utils def safely_pull_variables( diff --git a/src/extremeweatherbench/sources/xarray_dataset.py b/src/extremeweatherbench/sources/xarray_dataset.py index 56d52618..ae8e8b89 100644 --- a/src/extremeweatherbench/sources/xarray_dataset.py +++ b/src/extremeweatherbench/sources/xarray_dataset.py @@ -9,7 +9,7 @@ from extremeweatherbench import utils if TYPE_CHECKING: - from extremeweatherbench import regions + import extremeweatherbench.regions as regions def safely_pull_variables( diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 9220093f..18569e6b 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -332,10 +332,10 @@ def test_case_operators_property( # Check that the result is what the mock returned assert result == [sample_case_operator] - @mock.patch("extremeweatherbench.evaluate._run_case_operators") - def test_run_serial( + @mock.patch("extremeweatherbench.evaluate._run_evaluation") + def test_run_serial_evaluation( self, - mock_run_case_operators, + mock_run_evaluation, sample_cases_list, sample_evaluation_object, sample_case_operator, @@ -345,7 +345,7 @@ def test_run_serial( with 
mock.patch.object( evaluate.ExtremeWeatherBench, "case_operators", new=[sample_case_operator] ): - # Mock _run_case_operators to return a list of DataFrames + # Mock _run_evaluation to return a list of DataFrames mock_result = [ pd.DataFrame( { @@ -355,17 +355,17 @@ def test_run_serial( } ) ] - mock_run_case_operators.return_value = mock_result + mock_run_evaluation.return_value = mock_result ewb = evaluate.ExtremeWeatherBench( case_metadata=sample_cases_list, evaluation_objects=[sample_evaluation_object], ) - result = ewb.run(n_jobs=1) + result = ewb.run_evaluation(n_jobs=1) - # Serial mode passes parallel_config=None - mock_run_case_operators.assert_called_once_with( + # Serial mode should pass parallel_config=None + mock_run_evaluation.assert_called_once_with( [sample_case_operator], cache_dir=None, parallel_config=None, @@ -373,10 +373,10 @@ def test_run_serial( assert isinstance(result, pd.DataFrame) assert len(result) == 1 - @mock.patch("extremeweatherbench.evaluate._run_case_operators") - def test_run_parallel( + @mock.patch("extremeweatherbench.evaluate._run_evaluation") + def test_run_parallel_evaluation( self, - mock_run_case_operators, + mock_run_evaluation, sample_cases_list, sample_evaluation_object, sample_case_operator, @@ -394,16 +394,16 @@ def test_run_parallel( } ) ] - mock_run_case_operators.return_value = mock_result + mock_run_evaluation.return_value = mock_result ewb = evaluate.ExtremeWeatherBench( case_metadata=sample_cases_list, evaluation_objects=[sample_evaluation_object], ) - result = ewb.run(n_jobs=2) + result = ewb.run_evaluation(n_jobs=2) - mock_run_case_operators.assert_called_once_with( + mock_run_evaluation.assert_called_once_with( [sample_case_operator], cache_dir=None, parallel_config={"backend": "loky", "n_jobs": 2}, @@ -411,10 +411,10 @@ def test_run_parallel( assert isinstance(result, pd.DataFrame) assert len(result) == 1 - @mock.patch("extremeweatherbench.evaluate._run_case_operators") + 
@mock.patch("extremeweatherbench.evaluate._run_evaluation") def test_run_with_kwargs( self, - mock_run_case_operators, + mock_run_evaluation, sample_cases_list, sample_evaluation_object, sample_case_operator, @@ -424,37 +424,37 @@ def test_run_with_kwargs( evaluate.ExtremeWeatherBench, "case_operators", new=[sample_case_operator] ): mock_result = [pd.DataFrame({"value": [1.0]})] - mock_run_case_operators.return_value = mock_result + mock_run_evaluation.return_value = mock_result ewb = evaluate.ExtremeWeatherBench( case_metadata=sample_cases_list, evaluation_objects=[sample_evaluation_object], ) - result = ewb.run(n_jobs=1, threshold=0.5) + result = ewb.run_evaluation(n_jobs=1, threshold=0.5) # Check that kwargs were passed through - call_args = mock_run_case_operators.call_args + call_args = mock_run_evaluation.call_args assert call_args[1]["threshold"] == 0.5 assert isinstance(result, pd.DataFrame) - @mock.patch("extremeweatherbench.evaluate._run_case_operators") + @mock.patch("extremeweatherbench.evaluate._run_evaluation") def test_run_empty_results( self, - mock_run_case_operators, + mock_run_evaluation, sample_cases_list, sample_evaluation_object, ): """Test the run method handles empty results.""" with mock.patch.object(evaluate.ExtremeWeatherBench, "case_operators", new=[]): - mock_run_case_operators.return_value = [] + mock_run_evaluation.return_value = [] ewb = evaluate.ExtremeWeatherBench( case_metadata=sample_cases_list, evaluation_objects=[sample_evaluation_object], ) - result = ewb.run() + result = ewb.run_evaluation() assert isinstance(result, pd.DataFrame) assert len(result) == 0 @@ -504,7 +504,7 @@ def mock_compute_with_caching(case_operator, cache_dir_arg, **kwargs): cache_dir=cache_dir, ) - ewb.run(n_jobs=1) + ewb.run_evaluation(n_jobs=1) # Check that cache directory was created assert cache_dir.exists() @@ -538,7 +538,7 @@ def test_run_multiple_cases( evaluation_objects=[sample_evaluation_object], ) - result = ewb.run() + result = 
ewb.run_evaluation() assert mock_compute_case_operator.call_count == 2 assert len(result) == 2 @@ -546,107 +546,111 @@ def test_run_multiple_cases( class TestRunCaseOperators: - """Test the _run_case_operators function.""" + """Test the _run_evaluation function.""" - @mock.patch("extremeweatherbench.evaluate._run_serial") - def test_run_case_operators_serial(self, mock_run_serial, sample_case_operator): - """Test _run_case_operators routes to serial execution.""" - mock_results = [pd.DataFrame({"value": [1.0]})] - mock_run_serial.return_value = mock_results + @mock.patch("extremeweatherbench.evaluate.compute_case_operator") + @mock.patch("tqdm.auto.tqdm") + def test_run_evaluation_serial( + self, mock_tqdm, mock_compute_case_operator, sample_case_operator + ): + """Test _run_evaluation executes serially when parallel_config=None.""" + mock_tqdm.return_value = [sample_case_operator] + mock_results = pd.DataFrame({"value": [1.0]}) + mock_compute_case_operator.return_value = mock_results # Serial mode: don't pass parallel_config - result = evaluate._run_case_operators([sample_case_operator], cache_dir=None) + result = evaluate._run_evaluation([sample_case_operator], cache_dir=None) - mock_run_serial.assert_called_once_with([sample_case_operator], cache_dir=None) - assert result == mock_results + mock_compute_case_operator.assert_called_once_with(sample_case_operator, None) + assert len(result) == 1 + assert result[0].equals(mock_results) - @mock.patch("extremeweatherbench.evaluate._run_parallel") - def test_run_case_operators_parallel(self, mock_run_parallel, sample_case_operator): - """Test _run_case_operators routes to parallel execution.""" + @mock.patch("extremeweatherbench.evaluate._run_parallel_evaluation") + def test_run_evaluation_parallel( + self, mock_run_parallel_evaluation, sample_case_operator + ): + """Test _run_evaluation routes to parallel execution.""" mock_results = [pd.DataFrame({"value": [1.0]})] - mock_run_parallel.return_value = mock_results + 
mock_run_parallel_evaluation.return_value = mock_results - result = evaluate._run_case_operators( + result = evaluate._run_evaluation( [sample_case_operator], cache_dir=None, parallel_config={"backend": "threading", "n_jobs": 4}, ) - mock_run_parallel.assert_called_once_with( + mock_run_parallel_evaluation.assert_called_once_with( [sample_case_operator], cache_dir=None, parallel_config={"backend": "threading", "n_jobs": 4}, ) assert result == mock_results - @mock.patch("extremeweatherbench.evaluate._run_serial") - def test_run_case_operators_with_kwargs( - self, mock_run_serial, sample_case_operator + @mock.patch("extremeweatherbench.evaluate.compute_case_operator") + @mock.patch("tqdm.auto.tqdm") + def test_run_evaluation_with_kwargs( + self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test _run_case_operators passes kwargs correctly.""" - mock_results = [pd.DataFrame({"value": [1.0]})] - mock_run_serial.return_value = mock_results + """Test _run_evaluation passes kwargs correctly in serial mode.""" + mock_tqdm.return_value = [sample_case_operator] + mock_results = pd.DataFrame({"value": [1.0]}) + mock_compute_case_operator.return_value = mock_results # Serial mode: don't pass parallel_config - result = evaluate._run_case_operators( + result = evaluate._run_evaluation( [sample_case_operator], cache_dir=None, threshold=0.5, ) - call_args = mock_run_serial.call_args - assert call_args[0][0] == [sample_case_operator] - assert call_args[1]["cache_dir"] is None + call_args = mock_compute_case_operator.call_args + assert call_args[0][0] == sample_case_operator + assert call_args[0][1] is None # cache_dir assert call_args[1]["threshold"] == 0.5 assert isinstance(result, list) - @mock.patch("extremeweatherbench.evaluate._run_parallel") - def test_run_case_operators_parallel_with_kwargs( - self, mock_run_parallel, sample_case_operator + @mock.patch("extremeweatherbench.evaluate._run_parallel_evaluation") + def 
test_run_evaluation_parallel_with_kwargs( + self, mock_run_parallel_evaluation, sample_case_operator ): - """Test _run_case_operators passes kwargs to parallel execution.""" + """Test _run_evaluation passes kwargs to parallel execution.""" mock_results = [pd.DataFrame({"value": [1.0]})] - mock_run_parallel.return_value = mock_results + mock_run_parallel_evaluation.return_value = mock_results - result = evaluate._run_case_operators( + result = evaluate._run_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, custom_param="test_value", ) - call_args = mock_run_parallel.call_args + call_args = mock_run_parallel_evaluation.call_args assert call_args[0][0] == [sample_case_operator] assert call_args[1]["parallel_config"] == {"backend": "threading", "n_jobs": 2} assert call_args[1]["custom_param"] == "test_value" assert isinstance(result, list) - def test_run_case_operators_empty_list(self): - """Test _run_case_operators with empty case operator list.""" - with mock.patch("extremeweatherbench.evaluate._run_serial") as mock_serial: - mock_serial.return_value = [] - - # Serial mode: don't pass parallel_config - result = evaluate._run_case_operators([], cache_dir=None) - - mock_serial.assert_called_once_with([], cache_dir=None) - assert result == [] + def test_run_evaluation_empty_list(self): + """Test _run_evaluation with empty case operator list.""" + # Serial mode: don't pass parallel_config + result = evaluate._run_evaluation([], cache_dir=None) + assert result == [] class TestRunSerial: - """Test the _run_serial function.""" + """Test the serial execution path of _run_evaluation.""" @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_run_serial_basic( + def test_run_serial_evaluation_basic( self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test basic _run_serial functionality.""" + """Test basic serial execution functionality.""" # Setup mocks 
mock_tqdm.return_value = [sample_case_operator] # tqdm returns iterable mock_result = pd.DataFrame({"value": [1.0], "case_id_number": [1]}) mock_compute_case_operator.return_value = mock_result - result = evaluate._run_serial([sample_case_operator]) + result = evaluate._run_evaluation([sample_case_operator], parallel_config=None) mock_compute_case_operator.assert_called_once_with(sample_case_operator, None) assert len(result) == 1 @@ -654,8 +658,10 @@ def test_run_serial_basic( @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_run_serial_multiple_cases(self, mock_tqdm, mock_compute_case_operator): - """Test _run_serial with multiple case operators.""" + def test_run_serial_evaluation_multiple_cases( + self, mock_tqdm, mock_compute_case_operator + ): + """Test serial execution with multiple case operators.""" case_op_1 = mock.Mock() case_op_2 = mock.Mock() case_operators = [case_op_1, case_op_2] @@ -666,7 +672,7 @@ def test_run_serial_multiple_cases(self, mock_tqdm, mock_compute_case_operator): pd.DataFrame({"value": [2.0], "case_id_number": [2]}), ] - result = evaluate._run_serial(case_operators) + result = evaluate._run_evaluation(case_operators, parallel_config=None) assert mock_compute_case_operator.call_count == 2 assert len(result) == 2 @@ -675,16 +681,19 @@ def test_run_serial_multiple_cases(self, mock_tqdm, mock_compute_case_operator): @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_run_serial_with_kwargs( + def test_run_serial_evaluation_with_kwargs( self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test _run_serial passes kwargs to compute_case_operator.""" + """Test serial execution passes kwargs to compute_case_operator.""" mock_tqdm.return_value = [sample_case_operator] mock_result = pd.DataFrame({"value": [1.0]}) mock_compute_case_operator.return_value = mock_result - result = evaluate._run_serial( - 
[sample_case_operator], threshold=0.7, custom_param="test" + result = evaluate._run_evaluation( + [sample_case_operator], + parallel_config=None, + threshold=0.7, + custom_param="test", ) call_args = mock_compute_case_operator.call_args @@ -693,22 +702,22 @@ def test_run_serial_with_kwargs( assert call_args[1]["custom_param"] == "test" assert isinstance(result, list) - def test_run_serial_empty_list(self): - """Test _run_serial with empty case operator list.""" - result = evaluate._run_serial([]) + def test_run_serial_evaluation_empty_list(self): + """Test serial execution with empty case operator list.""" + result = evaluate._run_evaluation([], parallel_config=None) assert result == [] class TestRunParallel: - """Test the _run_parallel function.""" + """Test the _run_parallel_evaluation function.""" @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_basic( + def test_run_parallel_evaluation_basic( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test basic _run_parallel functionality.""" + """Test basic _run_parallel_evaluation functionality.""" # Setup mocks mock_tqdm.return_value = [sample_case_operator] mock_delayed_func = mock.Mock() @@ -719,7 +728,7 @@ def test_run_parallel_basic( mock_result = [pd.DataFrame({"value": [1.0], "case_id_number": [1]})] mock_parallel_instance.return_value = mock_result - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, ) @@ -735,10 +744,10 @@ def test_run_parallel_basic( @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_with_none_n_jobs( + def test_run_parallel_evaluation_with_none_n_jobs( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test _run_parallel with n_jobs=None (should use all 
CPUs).""" + """Test _run_parallel_evaluation with n_jobs=None (should use all CPUs).""" mock_tqdm.return_value = [sample_case_operator] mock_delayed_func = mock.Mock() mock_delayed.return_value = mock_delayed_func @@ -749,7 +758,7 @@ def test_run_parallel_with_none_n_jobs( mock_parallel_instance.return_value = mock_result with mock.patch("extremeweatherbench.evaluate.logger.warning") as mock_warning: - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": None}, ) @@ -765,7 +774,7 @@ def test_run_parallel_with_none_n_jobs( @mock.patch("joblib.parallel_config") @mock.patch("extremeweatherbench.utils.ParallelTqdm") - def test_run_parallel_n_jobs_in_config( + def test_run_parallel_evaluation_n_jobs_in_config( self, mock_parallel_class, mock_parallel_config ): """Test that n_jobs is passed through parallel_config, not directly.""" @@ -782,7 +791,7 @@ def test_run_parallel_n_jobs_in_config( ) mock_parallel_config.return_value.__exit__ = mock.Mock(return_value=False) - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 4}, ) @@ -800,10 +809,10 @@ def test_run_parallel_n_jobs_in_config( @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_multiple_cases( + def test_run_parallel_evaluation_multiple_cases( self, mock_tqdm, mock_delayed, mock_parallel_class ): - """Test _run_parallel with multiple case operators.""" + """Test _run_parallel_evaluation with multiple case operators.""" case_op_1 = mock.Mock() case_op_2 = mock.Mock() case_operators = [case_op_1, case_op_2] @@ -820,7 +829,7 @@ def test_run_parallel_multiple_cases( ] mock_parallel_instance.return_value = mock_result - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( case_operators, 
parallel_config={"backend": "threading", "n_jobs": 4} ) @@ -831,10 +840,10 @@ def test_run_parallel_multiple_cases( @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_with_kwargs( + def test_run_parallel_evaluation_with_kwargs( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test _run_parallel passes kwargs correctly.""" + """Test _run_parallel_evaluation passes kwargs correctly.""" mock_tqdm.return_value = [sample_case_operator] mock_delayed_func = mock.Mock() mock_delayed.return_value = mock_delayed_func @@ -844,7 +853,7 @@ def test_run_parallel_with_kwargs( mock_result = [pd.DataFrame({"value": [1.0]})] mock_parallel_instance.return_value = mock_result - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, threshold=0.8, @@ -861,8 +870,8 @@ def test_run_parallel_with_kwargs( assert len(delayed_calls) == 1 assert isinstance(result, list) - def test_run_parallel_empty_list(self): - """Test _run_parallel with empty case operator list.""" + def test_run_parallel_evaluation_empty_list(self): + """Test _run_parallel_evaluation with empty case operator list.""" with mock.patch( "extremeweatherbench.utils.ParallelTqdm" ) as mock_parallel_class: @@ -872,7 +881,7 @@ def test_run_parallel_empty_list(self): mock_parallel_class.return_value = mock_parallel_instance mock_parallel_instance.return_value = [] - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [], parallel_config={"backend": "threading", "n_jobs": 2} ) @@ -883,10 +892,10 @@ def test_run_parallel_empty_list(self): ) @mock.patch("dask.distributed.Client") @mock.patch("dask.distributed.LocalCluster") - def test_run_parallel_dask_backend_auto_client( + def test_run_parallel_evaluation_dask_backend_auto_client( self, mock_local_cluster, mock_client_class, 
sample_case_operator ): - """Test _run_parallel with dask backend automatically creates client.""" + """Test _run_parallel_evaluation with dask backend automatically creates client.""" # Mock Client.current() to raise ValueError (no existing client) mock_client_class.current.side_effect = ValueError("No client found") @@ -905,7 +914,7 @@ def test_run_parallel_dask_backend_auto_client( mock_parallel_instance.return_value = [pd.DataFrame({"test": [1]})] with mock.patch("joblib.parallel_config"): - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "dask", "n_jobs": 2}, ) @@ -919,10 +928,10 @@ def test_run_parallel_dask_backend_auto_client( not HAS_DASK_DISTRIBUTED, reason="dask.distributed not installed" ) @mock.patch("dask.distributed.Client") - def test_run_parallel_dask_backend_existing_client( + def test_run_parallel_evaluation_dask_backend_existing_client( self, mock_client_class, sample_case_operator ): - """Test _run_parallel with dask backend uses existing client.""" + """Test _run_parallel_evaluation with dask backend uses existing client.""" # Mock existing client mock_existing_client = mock.Mock() mock_client_class.current.return_value = mock_existing_client @@ -934,7 +943,7 @@ def test_run_parallel_dask_backend_existing_client( mock_parallel_instance.return_value = [pd.DataFrame({"test": [1]})] with mock.patch("joblib.parallel_config"): - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "dask", "n_jobs": 2}, ) @@ -1385,7 +1394,7 @@ def test_run_pipeline_target( def test_run_pipeline_invalid_source(self, sample_case_operator): """Test run_pipeline function with invalid input source.""" with pytest.raises(AttributeError, match="'str' object has no attribute"): - evaluate.run_pipeline(sample_case_operator.case_metadata, "invalid") + evaluate.run_pipeline(sample_case_operator.case_metadata, 
"invalid") # type: ignore def test_maybe_cache_and_compute_with_cache_dir( self, sample_forecast_dataset, sample_target_dataset, sample_individual_case @@ -1600,9 +1609,14 @@ def test_extremeweatherbench_empty_cases(self, sample_evaluation_object): evaluation_objects=[sample_evaluation_object], ) - result = ewb.run() - assert isinstance(result, pd.DataFrame) - assert len(result) == 0 + with mock.patch("extremeweatherbench.cases.build_case_operators") as mock_build: + mock_build.return_value = [] + + result = ewb.run_evaluation() + + # Should return empty DataFrame when no cases + assert isinstance(result, pd.DataFrame) + assert len(result) == 0 def test_compute_case_operator_exception_handling(self, sample_case_operator): """Test exception handling in compute_case_operator.""" @@ -1645,49 +1659,53 @@ def test_evaluate_metric_computation_failure( case_operator=sample_case_operator, ) - @mock.patch("extremeweatherbench.evaluate._run_serial") - def test_run_case_operators_serial_exception( - self, mock_run_serial, sample_case_operator + @mock.patch("extremeweatherbench.evaluate.compute_case_operator") + @mock.patch("tqdm.auto.tqdm") + def test_run_evaluation_serial_exception( + self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test _run_case_operators handles exceptions in serial execution.""" - mock_run_serial.side_effect = Exception("Serial execution failed") + """Test _run_evaluation handles exceptions in serial execution.""" + mock_tqdm.return_value = [sample_case_operator] + mock_compute_case_operator.side_effect = Exception("Serial execution failed") with pytest.raises(Exception, match="Serial execution failed"): # Serial mode: don't pass parallel_config - evaluate._run_case_operators([sample_case_operator], None) + evaluate._run_evaluation([sample_case_operator], parallel_config=None) - @mock.patch("extremeweatherbench.evaluate._run_parallel") - def test_run_case_operators_parallel_exception( - self, mock_run_parallel, 
sample_case_operator + @mock.patch("extremeweatherbench.evaluate._run_parallel_evaluation") + def test_run_evaluation_parallel_exception( + self, mock_run_parallel_evaluation, sample_case_operator ): - """Test _run_case_operators handles exceptions in parallel execution.""" - mock_run_parallel.side_effect = Exception("Parallel execution failed") + """Test _run_evaluation handles exceptions in parallel execution.""" + mock_run_parallel_evaluation.side_effect = Exception( + "Parallel execution failed" + ) with pytest.raises(Exception, match="Parallel execution failed"): - evaluate._run_case_operators( + evaluate._run_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, ) @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_run_serial_case_operator_exception( + def test_run_serial_evaluation_case_operator_exception( self, mock_tqdm, mock_compute_case_operator, sample_case_operator ): - """Test _run_serial handles exceptions from individual case operators.""" + """Test serial execution handles exceptions from individual case operators.""" mock_tqdm.return_value = [sample_case_operator] mock_compute_case_operator.side_effect = Exception("Case operator failed") with pytest.raises(Exception, match="Case operator failed"): - evaluate._run_serial([sample_case_operator]) + evaluate._run_evaluation([sample_case_operator], parallel_config=None) @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_joblib_exception( + def test_run_parallel_evaluation_joblib_exception( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test _run_parallel handles joblib Parallel exceptions.""" + """Test _run_parallel_evaluation handles joblib Parallel exceptions.""" mock_tqdm.return_value = [sample_case_operator] mock_delayed_func = mock.Mock() mock_delayed.return_value = mock_delayed_func @@ 
-1697,7 +1715,7 @@ def test_run_parallel_joblib_exception( mock_parallel_instance.side_effect = Exception("Joblib parallel failed") with pytest.raises(Exception, match="Joblib parallel failed"): - evaluate._run_parallel( + evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, ) @@ -1705,10 +1723,10 @@ def test_run_parallel_joblib_exception( @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_delayed_function_exception( + def test_run_parallel_evaluation_delayed_function_exception( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test _run_parallel handles exceptions in delayed functions.""" + """Test _run_parallel_evaluation handles exceptions in delayed functions.""" mock_tqdm.return_value = [sample_case_operator] # Mock delayed to raise an exception @@ -1724,12 +1742,12 @@ def consume_generator(generator): mock_parallel_instance.side_effect = consume_generator with pytest.raises(Exception, match="Delayed function creation failed"): - evaluate._run_parallel( + evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, ) - @mock.patch("extremeweatherbench.evaluate._run_case_operators") + @mock.patch("extremeweatherbench.evaluate._run_evaluation") def test_run_method_exception_propagation( self, mock_run_case_operators, sample_cases_list, sample_evaluation_object ): @@ -1742,12 +1760,14 @@ def test_run_method_exception_propagation( ) with pytest.raises(Exception, match="Execution failed"): - ewb.run() + ewb.run_evaluation() @mock.patch("extremeweatherbench.evaluate.compute_case_operator") @mock.patch("tqdm.auto.tqdm") - def test_run_serial_partial_failure(self, mock_tqdm, mock_compute_case_operator): - """Test _run_serial behavior when some case operators fail.""" + def test_run_serial_evaluation_partial_failure( + self, mock_tqdm, 
mock_compute_case_operator + ): + """Test serial execution behavior when some case operators fail.""" case_op_1 = mock.Mock() case_op_2 = mock.Mock() case_op_3 = mock.Mock() @@ -1764,7 +1784,7 @@ def test_run_serial_partial_failure(self, mock_tqdm, mock_compute_case_operator) # Should fail on the second case operator with pytest.raises(Exception, match="Case operator 2 failed"): - evaluate._run_serial(case_operators) + evaluate._run_evaluation(case_operators, parallel_config=None) # Should have tried only the first two assert mock_compute_case_operator.call_count == 2 @@ -1772,10 +1792,10 @@ def test_run_serial_partial_failure(self, mock_tqdm, mock_compute_case_operator) @mock.patch("extremeweatherbench.utils.ParallelTqdm") @mock.patch("joblib.delayed") @mock.patch("tqdm.auto.tqdm") - def test_run_parallel_invalid_n_jobs( + def test_run_parallel_evaluation_invalid_n_jobs( self, mock_tqdm, mock_delayed, mock_parallel_class, sample_case_operator ): - """Test _run_parallel with invalid n_jobs parameter.""" + """Test _run_parallel_evaluation with invalid n_jobs parameter.""" mock_tqdm.return_value = [sample_case_operator] mock_delayed_func = mock.Mock() mock_delayed.return_value = mock_delayed_func @@ -1784,7 +1804,7 @@ def test_run_parallel_invalid_n_jobs( mock_parallel_class.side_effect = ValueError("Invalid n_jobs parameter") with pytest.raises(ValueError, match="Invalid n_jobs parameter"): - evaluate._run_parallel( + evaluate._run_parallel_evaluation( [sample_case_operator], parallel_config={"backend": "threading", "n_jobs": -5}, ) @@ -1869,7 +1889,7 @@ def test_end_to_end_workflow( evaluation_objects=[sample_evaluation_object], ) - result = ewb.run() + result = ewb.run_evaluation() # Verify the result structure assert isinstance(result, pd.DataFrame) @@ -1973,7 +1993,7 @@ def test_multiple_variables_and_metrics( evaluation_objects=[eval_obj], ) - result = ewb.run() + result = ewb.run_evaluation() # Should have results for each metric combination assert len(result) 
>= 2 # At least 2 metrics * 1 case @@ -2016,12 +2036,12 @@ def test_serial_vs_parallel_results_consistency( # Test serial execution mock_compute_case_operator.side_effect = [result_1, result_2] - serial_result = ewb.run(n_jobs=1) + serial_result = ewb.run_evaluation(n_jobs=1) # Reset mock and test parallel execution mock_compute_case_operator.reset_mock() mock_compute_case_operator.side_effect = [result_1, result_2] - parallel_result = ewb.run(n_jobs=2) + parallel_result = ewb.run_evaluation(n_jobs=2) # Both should produce valid DataFrames with same structure assert isinstance(serial_result, pd.DataFrame) @@ -2031,12 +2051,16 @@ def test_serial_vs_parallel_results_consistency( assert list(serial_result.columns) == list(parallel_result.columns) @mock.patch("extremeweatherbench.evaluate.compute_case_operator") - def test_execution_method_performance_comparison(self, mock_compute_case_operator): + @mock.patch("tqdm.auto.tqdm") + def test_execution_method_performance_comparison( + self, mock_tqdm, mock_compute_case_operator + ): """Test that both execution methods handle the same workload.""" import time # Create many case operators to simulate realistic workload case_operators = [mock.Mock() for _ in range(10)] + mock_tqdm.return_value = case_operators # Mock results mock_results = [ @@ -2051,13 +2075,13 @@ def test_execution_method_performance_comparison(self, mock_compute_case_operato for i in range(10) ] - # Test serial execution timing - call _run_serial directly + # Test serial execution timing - call _run_evaluation in serial mode mock_compute_case_operator.side_effect = mock_results start_time = time.time() - serial_result = evaluate._run_serial(case_operators) + serial_result = evaluate._run_evaluation(case_operators, parallel_config=None) serial_time = time.time() - start_time - # Test parallel execution timing - call _run_parallel directly with mocked + # Test parallel execution timing - call _run_parallel_evaluation directly with mocked # Parallel 
serial_call_count = mock_compute_case_operator.call_count mock_compute_case_operator.side_effect = mock_results @@ -2070,7 +2094,7 @@ def test_execution_method_performance_comparison(self, mock_compute_case_operato mock_parallel_instance.return_value = mock_results start_time = time.time() - parallel_result = evaluate._run_parallel( + parallel_result = evaluate._run_parallel_evaluation( case_operators, parallel_config={"backend": "threading", "n_jobs": 2} ) parallel_time = time.time() - start_time @@ -2087,9 +2111,11 @@ def test_execution_method_performance_comparison(self, mock_compute_case_operato assert parallel_time >= 0 @mock.patch("extremeweatherbench.evaluate.compute_case_operator") - def test_mixed_execution_parameters(self, mock_compute_case_operator): + @mock.patch("tqdm.auto.tqdm") + def test_mixed_execution_parameters(self, mock_tqdm, mock_compute_case_operator): """Test various parameter combinations for execution methods.""" case_operators = [mock.Mock(), mock.Mock()] + mock_tqdm.return_value = case_operators mock_results = [ pd.DataFrame({"value": [1.0], "case_id_number": [1]}), pd.DataFrame({"value": [2.0], "case_id_number": [2]}), @@ -2112,7 +2138,7 @@ def test_mixed_execution_parameters(self, mock_compute_case_operator): mock_compute_case_operator.side_effect = mock_results if config["method"] == "serial": - result = evaluate._run_serial(*config["args"]) + result = evaluate._run_evaluation(*config["args"], parallel_config=None) # All configurations should produce valid results assert isinstance(result, list) assert len(result) == 2 @@ -2135,7 +2161,9 @@ def test_mixed_execution_parameters(self, mock_compute_case_operator): "n_jobs": n_jobs, } - result = evaluate._run_parallel(*config["args"], **kwargs) + result = evaluate._run_parallel_evaluation( + *config["args"], **kwargs + ) # All configurations should produce valid results assert isinstance(result, list) @@ -2156,13 +2184,19 @@ def mock_compute_with_kwargs(case_op, cache_dir, **kwargs): 
mock_compute_with_kwargs.captured_kwargs = {} - with mock.patch( - "extremeweatherbench.evaluate.compute_case_operator", - side_effect=mock_compute_with_kwargs, + with ( + mock.patch( + "extremeweatherbench.evaluate.compute_case_operator", + side_effect=mock_compute_with_kwargs, + ), + mock.patch("tqdm.auto.tqdm", return_value=[case_operator]), ): # Test serial kwargs propagation - result = evaluate._run_serial( - [case_operator], custom_param="serial_test", threshold=0.9 + result = evaluate._run_evaluation( + [case_operator], + parallel_config=None, + custom_param="serial_test", + threshold=0.9, ) captured = mock_compute_with_kwargs.captured_kwargs @@ -2185,7 +2219,7 @@ def mock_compute_with_kwargs(case_op, cache_dir, **kwargs): # Reset captured kwargs mock_compute_with_kwargs.captured_kwargs = {} - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [case_operator], parallel_config={"backend": "threading", "n_jobs": 2}, custom_param="parallel_test", @@ -2198,20 +2232,20 @@ def mock_compute_with_kwargs(case_op, cache_dir, **kwargs): def test_empty_case_operators_all_methods(self): """Test that all execution methods handle empty case operator lists.""" - # Test _run_case_operators - result = evaluate._run_case_operators([], parallel_config={"n_jobs": 1}) + # Test _run_evaluation with parallel config + result = evaluate._run_evaluation([], parallel_config={"n_jobs": 1}) assert result == [] - result = evaluate._run_case_operators( + result = evaluate._run_evaluation( [], parallel_config={"backend": "threading", "n_jobs": 2} ) assert result == [] - # Test _run_serial - result = evaluate._run_serial([]) + # Test _run_evaluation in serial mode + result = evaluate._run_evaluation([], parallel_config=None) assert result == [] - # Test _run_parallel + # Test _run_parallel_evaluation with mock.patch( "extremeweatherbench.utils.ParallelTqdm" ) as mock_parallel_class: @@ -2219,17 +2253,21 @@ def test_empty_case_operators_all_methods(self): 
mock_parallel_class.return_value = mock_parallel_instance mock_parallel_instance.return_value = [] - result = evaluate._run_parallel( + result = evaluate._run_parallel_evaluation( [], parallel_config={"backend": "threading", "n_jobs": 2} ) assert result == [] @mock.patch("extremeweatherbench.evaluate.compute_case_operator") - def test_large_case_operator_list_handling(self, mock_compute_case_operator): + @mock.patch("tqdm.auto.tqdm") + def test_large_case_operator_list_handling( + self, mock_tqdm, mock_compute_case_operator + ): """Test handling of large numbers of case operators.""" # Create a large list of case operators num_cases = 100 case_operators = [mock.Mock() for _ in range(num_cases)] + mock_tqdm.return_value = case_operators # Create mock results mock_results = [ @@ -2241,7 +2279,7 @@ def test_large_case_operator_list_handling(self, mock_compute_case_operator): # Test serial execution mock_compute_case_operator.side_effect = mock_results - serial_results = evaluate._run_serial(case_operators) + serial_results = evaluate._run_evaluation(case_operators, parallel_config=None) assert len(serial_results) == num_cases assert mock_compute_case_operator.call_count == num_cases @@ -2257,7 +2295,7 @@ def test_large_case_operator_list_handling(self, mock_compute_case_operator): mock_parallel_class.return_value = mock_parallel_instance mock_parallel_instance.return_value = mock_results - parallel_results = evaluate._run_parallel( + parallel_results = evaluate._run_parallel_evaluation( case_operators, parallel_config={"backend": "threading", "n_jobs": 4} ) diff --git a/tests/test_evaluate_cli.py b/tests/test_evaluate_cli.py index e5797712..2bda6879 100644 --- a/tests/test_evaluate_cli.py +++ b/tests/test_evaluate_cli.py @@ -88,7 +88,7 @@ def test_default_mode_basic( # Mock the ExtremeWeatherBench class and its methods mock_ewb = mock.Mock() mock_ewb.case_operators = [mock.Mock(), mock.Mock()] # Mock 2 case operators - mock_ewb.run.return_value = pd.DataFrame({"test": 
[1, 2]}) + mock_ewb.run_evaluation.return_value = pd.DataFrame({"test": [1, 2]}) mock_ewb_class.return_value = mock_ewb # Mock loading default cases @@ -100,7 +100,7 @@ def test_default_mode_basic( assert result.exit_code == 0 mock_ewb_class.assert_called_once() - mock_ewb.run.assert_called_once() + mock_ewb.run_evaluation.assert_called_once() @mock.patch( "extremeweatherbench.defaults.get_brightband_evaluation_objects", @@ -119,7 +119,7 @@ def test_default_mode_with_cache_dir( """Test default mode with cache directory.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] @@ -145,7 +145,7 @@ def test_config_file_mode_basic( """Test basic config file mode execution.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [mock.Mock()] - mock_ewb.run.return_value = pd.DataFrame({"test": [1]}) + mock_ewb.run_evaluation.return_value = pd.DataFrame({"test": [1]}) mock_ewb_class.return_value = mock_ewb result = runner.invoke( @@ -219,15 +219,15 @@ def test_parallel_execution( """Test parallel execution mode.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [mock.Mock(), mock.Mock(), mock.Mock()] - mock_ewb.run.return_value = pd.DataFrame({"test": [1, 2, 3]}) + mock_ewb.run_evaluation.return_value = pd.DataFrame({"test": [1, 2, 3]}) mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] result = runner.invoke(evaluate_cli.cli_runner, ["--default", "--n-jobs", "3"]) assert result.exit_code == 0 - # Verify ewb.run was called with parallel config - mock_ewb.run.assert_called_once_with( + # Verify ewb.run_evaluation was called with parallel config + mock_ewb.run_evaluation.assert_called_once_with( n_jobs=3, parallel_config=None, ) @@ -244,7 +244,7 @@ def test_serial_execution_default( """Test that serial execution is default (parallel=1).""" mock_ewb = mock.Mock() 
mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] @@ -252,7 +252,7 @@ def test_serial_execution_default( assert result.exit_code == 0 # Output suppressed - only check exit code - mock_ewb.run.assert_called_once() + mock_ewb.run_evaluation.assert_called_once() class TestCaseOperatorSaving: @@ -278,7 +278,7 @@ def test_save_case_operators( mock_case_op2 = {"id": 2, "type": "test_case_op"} mock_ewb = mock.Mock() mock_ewb.case_operators = [mock_case_op1, mock_case_op2] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] @@ -317,7 +317,7 @@ def test_save_case_operators_creates_directory( """Test that saving case operators creates parent directories.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] @@ -369,7 +369,7 @@ def test_output_directory_creation( """Test that output directory is created if it doesn't exist.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] @@ -395,7 +395,7 @@ def test_default_output_directory( """Test that default output directory is current working directory.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() + mock_ewb.run_evaluation.return_value = pd.DataFrame() mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] @@ -435,7 +435,7 @@ def test_results_saved_to_csv( mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = 
mock_results + mock_ewb.run_evaluation.return_value = mock_results mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] @@ -456,7 +456,7 @@ def test_empty_results_handling(self, mock_ewb_class, mock_load_cases, runner): """Test handling when no results are returned.""" mock_ewb = mock.Mock() mock_ewb.case_operators = [] - mock_ewb.run.return_value = pd.DataFrame() # Empty results + mock_ewb.run_evaluation.return_value = pd.DataFrame() # Empty results mock_ewb_class.return_value = mock_ewb mock_load_cases.return_value = [] @@ -469,7 +469,9 @@ def test_empty_results_handling(self, mock_ewb_class, mock_load_cases, runner): class TestHelperFunctions: """Test helper function functionality.""" - @mock.patch("extremeweatherbench.cases.load_ewb_events_yaml_into_case_list") + @mock.patch( + "extremeweatherbench.evaluate_cli.cases.load_ewb_events_yaml_into_case_list" + ) def test_load_default_cases(self, mock_load_yaml): """Test _load_default_cases function.""" mock_cases = [{"id": 1}] diff --git a/tests/test_init.py b/tests/test_init.py new file mode 100644 index 00000000..00517ae8 --- /dev/null +++ b/tests/test_init.py @@ -0,0 +1,259 @@ +"""Tests for the extremeweatherbench package __init__.py API.""" + +import types + + +class TestModuleImports: + """Test that submodules are importable and are actual modules.""" + + def test_calc_is_module(self): + """Test that calc is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import calc + + assert isinstance(calc, types.ModuleType) + + def test_utils_is_module(self): + """Test that utils is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import utils + + assert isinstance(utils, types.ModuleType) + + def test_metrics_is_module(self): + """Test that metrics is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import metrics + + assert isinstance(metrics, types.ModuleType) + + def test_regions_is_module(self): + """Test that regions is an actual 
module, not a SimpleNamespace.""" + from extremeweatherbench import regions + + assert isinstance(regions, types.ModuleType) + + def test_derived_is_module(self): + """Test that derived is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import derived + + assert isinstance(derived, types.ModuleType) + + def test_defaults_is_module(self): + """Test that defaults is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import defaults + + assert isinstance(defaults, types.ModuleType) + + def test_cases_is_module(self): + """Test that cases is an actual module, not a SimpleNamespace.""" + from extremeweatherbench import cases + + assert isinstance(cases, types.ModuleType) + + +class TestModuleAccessPatterns: + """Test both import patterns work identically.""" + + def test_ewb_dot_notation_equals_direct_import_calc(self): + """Test ewb.calc is the same object as direct import.""" + import extremeweatherbench as ewb + from extremeweatherbench import calc + + assert ewb.calc is calc + + def test_ewb_dot_notation_equals_direct_import_metrics(self): + """Test ewb.metrics is the same object as direct import.""" + import extremeweatherbench as ewb + from extremeweatherbench import metrics + + assert ewb.metrics is metrics + + def test_ewb_dot_notation_equals_direct_import_utils(self): + """Test ewb.utils is the same object as direct import.""" + import extremeweatherbench as ewb + from extremeweatherbench import utils + + assert ewb.utils is utils + + +class TestModuleLevelConstants: + """Test that module-level constants are accessible.""" + + def test_calc_g0_accessible(self): + """Test that calc.g0 constant is accessible.""" + from extremeweatherbench import calc + + assert hasattr(calc, "g0") + assert calc.g0 == 9.80665 + + def test_calc_epsilon_accessible(self): + """Test that calc.epsilon constant is accessible.""" + from extremeweatherbench import calc + + assert hasattr(calc, "epsilon") + assert isinstance(calc.epsilon, float) + 
+ +class TestPrivateFunctionAccess: + """Test that private functions are accessible for testing purposes.""" + + def test_calc_private_functions_accessible(self): + """Test that private functions in calc are accessible.""" + from extremeweatherbench import calc + + assert hasattr(calc, "_is_true_landfall") + assert hasattr(calc, "_detect_landfalls_wrapper") + assert hasattr(calc, "_mask_init_time_boundaries") + assert hasattr(calc, "_interpolate_and_format_landfalls") + + def test_utils_private_functions_accessible(self): + """Test that private functions in utils are accessible.""" + from extremeweatherbench import utils + + assert hasattr(utils, "_create_nan_dataarray") + assert hasattr(utils, "_cache_maybe_densify_helper") + + def test_derived_private_functions_accessible(self): + """Test that private functions in derived are accessible.""" + from extremeweatherbench import derived + + assert hasattr(derived, "_maybe_convert_variable_to_string") + + def test_defaults_private_functions_accessible(self): + """Test that private functions in defaults are accessible.""" + from extremeweatherbench import defaults + + assert hasattr(defaults, "_preprocess_cira_forecast_dataset") + + def test_regions_private_functions_accessible(self): + """Test that private functions in regions are accessible.""" + from extremeweatherbench import regions + + assert hasattr(regions, "_adjust_bounds_to_dataset_convention") + + +class TestPublicFunctionAccess: + """Test that all public functions are accessible via module.""" + + def test_calc_public_functions(self): + """Test public functions in calc are accessible.""" + from extremeweatherbench import calc + + assert hasattr(calc, "find_landfalls") + assert hasattr(calc, "nantrapezoid") + assert hasattr(calc, "dewpoint_from_specific_humidity") + assert hasattr(calc, "find_land_intersection") + assert hasattr(calc, "haversine_distance") + + def test_utils_public_functions(self): + """Test public functions in utils are accessible.""" + from 
extremeweatherbench import utils + + assert hasattr(utils, "reduce_dataarray") + assert hasattr(utils, "stack_dataarray_from_dims") + assert hasattr(utils, "convert_longitude_to_360") + + +class TestTopLevelImports: + """Test that top-level imports work for commonly used items.""" + + def test_top_level_metric_imports(self): + """Test that metrics can be imported at top level.""" + from extremeweatherbench import ( + MeanAbsoluteError, + MeanError, + MeanSquaredError, + RootMeanSquaredError, + ) + + assert MeanAbsoluteError is not None + assert MeanError is not None + assert MeanSquaredError is not None + assert RootMeanSquaredError is not None + + def test_top_level_input_imports(self): + """Test that input classes can be imported at top level.""" + from extremeweatherbench import ERA5, GHCN, IBTrACS, ZarrForecast + + assert ERA5 is not None + assert GHCN is not None + assert IBTrACS is not None + assert ZarrForecast is not None + + def test_top_level_region_imports(self): + """Test that region classes can be imported at top level.""" + from extremeweatherbench import BoundingBoxRegion, CenteredRegion, Region + + assert Region is not None + assert BoundingBoxRegion is not None + assert CenteredRegion is not None + + def test_top_level_case_imports(self): + """Test that case classes can be imported at top level.""" + from extremeweatherbench import CaseOperator, IndividualCase + + assert IndividualCase is not None + assert CaseOperator is not None + + def test_evaluation_alias(self): + """Test that evaluation alias works.""" + from extremeweatherbench import ExtremeWeatherBench, evaluation + + assert evaluation is ExtremeWeatherBench + + def test_load_cases_alias(self): + """Test that load_cases alias works.""" + from extremeweatherbench import ( + load_cases, + load_ewb_events_yaml_into_case_list, + ) + + assert load_cases is load_ewb_events_yaml_into_case_list + + +class TestNamespaceSubmodules: + """Test the convenience namespace submodules.""" + + def 
test_targets_namespace(self): + """Test targets SimpleNamespace contains expected items.""" + from extremeweatherbench import targets + + assert isinstance(targets, types.SimpleNamespace) + assert hasattr(targets, "ERA5") + assert hasattr(targets, "GHCN") + assert hasattr(targets, "IBTrACS") + assert hasattr(targets, "TargetBase") + + def test_forecasts_namespace(self): + """Test forecasts SimpleNamespace contains expected items.""" + from extremeweatherbench import forecasts + + assert isinstance(forecasts, types.SimpleNamespace) + assert hasattr(forecasts, "ZarrForecast") + assert hasattr(forecasts, "KerchunkForecast") + assert hasattr(forecasts, "ForecastBase") + + +class TestMockPatching: + """Test that mock.patch.object works with module imports.""" + + def test_mock_patch_object_on_calc(self): + """Test that mock.patch.object works on calc module.""" + from unittest import mock + + from extremeweatherbench import calc + + with mock.patch.object(calc, "haversine_distance") as mock_func: + mock_func.return_value = 42.0 + result = calc.haversine_distance([0, 0], [1, 1]) + assert result == 42.0 + mock_func.assert_called_once() + + def test_mock_patch_string_on_calc(self): + """Test that mock.patch with string path works on calc module.""" + from unittest import mock + + with mock.patch("extremeweatherbench.calc.haversine_distance") as mock_func: + mock_func.return_value = 100.0 + from extremeweatherbench import calc + + result = calc.haversine_distance([0, 0], [1, 1]) + assert result == 100.0 diff --git a/tests/test_integration.py b/tests/test_integration.py index 07fba348..43ba0d30 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -600,7 +600,7 @@ def test_full_workflow_single_variable( evaluation_objects=[evaluation_obj], ) - result = ewb.run() + result = ewb.run_evaluation() # Verify results assert isinstance(result, pd.DataFrame) @@ -679,7 +679,7 @@ def test_full_workflow_multiple_variables( evaluation_objects=[evaluation_obj], ) - 
result = ewb.run() + result = ewb.run_evaluation() # Verify results assert isinstance(result, pd.DataFrame) From aca8518d5821ce25e75ea699f6091a5cb3077a6d Mon Sep 17 00:00:00 2001 From: Taylor Mandelbaum Date: Mon, 26 Jan 2026 15:34:04 -0500 Subject: [PATCH 14/15] Golden tests (#323) * first pass for gt test infra + yaml * use shapefile for severe convection and catch latitude swap * add ignore for golden test when running pytest by default * ruff * move pytest addopts and markers to pyproject.toml * Remove `IndividualCaseCollection` (#317) * update all references to IndividualCaseCollection and convert dicts/ "cases": keys to lists * update template * make questions bold * add whitespace * remove indent error and typo from evaluate_cli * make load_individual_cases include passthrough for existing dataclasses * ruff * add comment for clarification on list comp * ruff (again) * remove all references to collection, replace with list * ruff * rename collection -> list * ruff * Cleanup docstrings in repo (#318) * update these docstrings * remove docstring changes markdown * update docstrings * update other docstrings * remove individualcasecollection reference, update based on develop changes * add explanation for dim reqs (#320) * Update `defaults` and `inputs` to include new CIRA icechunk store (#319) * more explicit naming, add func and model names var * add test coverage, ruff, linting * update readme for new cira approach * move cira func and model ref to inputs * update docs * module wasnt called for moved func * update tests for moving func and var * ruff * fix mock typos * Bump version from 0.2.0 to 0.3.0 (#324) * Updated API (#321) * move cache dir creation to init, rename funcs, add parallel/serial check function, update test names * update naming * add run method for backwards compatibility * update tests * add tests and cover if serial and parallel_config is not None * feat: redesign public API with hierarchical namespace submodules - Add ewb.evaluation() as 
main entry point (alias for ExtremeWeatherBench) - Create namespace submodules: ewb.targets, ewb.forecasts, ewb.metrics, ewb.derived, ewb.regions, ewb.cases, ewb.defaults - Expose all classes at top level for convenience (ewb.ERA5, etc.) - Add ewb.load_cases() convenience alias - Update all example files to use new import pattern - Update usage.md documentation - Maintain backward compatibility with existing imports * ruff/linting. add utils to init * add test coverage for module loading patterns * ruff * Cleanup docstrings in repo (#318) * update these docstrings * remove docstring changes markdown * update docstrings * update other docstrings * remove individualcasecollection reference, update based on develop changes * add explanation for dim reqs (#320) * Update `defaults` and `inputs` to include new CIRA icechunk store (#319) * more explicit naming, add func and model names var * add test coverage, ruff, linting * update readme for new cira approach * move cira func and model ref to inputs * update docs * module wasnt called for moved func * update tests for moving func and var * ruff * fix mock typos * update defaults var refs * remove to_csv --- pyproject.toml | 7 + pytest.ini | 5 - src/extremeweatherbench/regions.py | 8 +- tests/data/golden_tests.yaml | 57 ++++++ tests/data/south_carolina_110m.dbf | Bin 0 -> 11881 bytes tests/data/south_carolina_110m.prj | 1 + tests/data/south_carolina_110m.shp | Bin 0 -> 700 bytes tests/data/south_carolina_110m.shx | Bin 0 -> 108 bytes tests/test_golden.py | 282 +++++++++++++++++++++++++++++ 9 files changed, 354 insertions(+), 6 deletions(-) delete mode 100644 pytest.ini create mode 100644 tests/data/golden_tests.yaml create mode 100644 tests/data/south_carolina_110m.dbf create mode 100644 tests/data/south_carolina_110m.prj create mode 100644 tests/data/south_carolina_110m.shp create mode 100644 tests/data/south_carolina_110m.shx create mode 100644 tests/test_golden.py diff --git a/pyproject.toml b/pyproject.toml index 
73eae194..be0ee124 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -174,3 +174,10 @@ build_command = """ github_release_mode = "latest" # Ensure assets are only uploaded for the current release, not past ones upload_assets_for_all_releases = false + +[tool.pytest] +addopts = ["--ignore=tests/test_golden.py", "--cov=extremeweatherbench"] +markers = [ + "integration: marks tests as integration tests (may be slow)", + "slow: marks tests as slow (may take longer to complete)", +] diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index 29aab096..00000000 --- a/pytest.ini +++ /dev/null @@ -1,5 +0,0 @@ -[pytest] -addopts = --cov=extremeweatherbench -markers = - integration: marks tests as integration tests (may be slow) - slow: marks tests as slow (may take longer to complete) \ No newline at end of file diff --git a/src/extremeweatherbench/regions.py b/src/extremeweatherbench/regions.py index 4e6c7730..7f86189e 100644 --- a/src/extremeweatherbench/regions.py +++ b/src/extremeweatherbench/regions.py @@ -361,8 +361,14 @@ def mask(self, dataset: xr.Dataset, drop: bool = False) -> xr.Dataset: # Note: ShapefileRegion.mask uses slice which doesn't support # prime/antimeridian crossing with OR logic, but regionmask handles it + # Check if latitude is ascending or descending to handle slice correctly + lat_ascending = dataset.latitude[0] < dataset.latitude[-1] + if lat_ascending: + lat_slice = slice(latitude_min, latitude_max) + else: + lat_slice = slice(latitude_max, latitude_min) dataset = dataset.sel( - latitude=slice(latitude_max, latitude_min), + latitude=lat_slice, longitude=slice(longitude_min, longitude_max), drop=drop, ) diff --git a/tests/data/golden_tests.yaml b/tests/data/golden_tests.yaml new file mode 100644 index 00000000..35036e6b --- /dev/null +++ b/tests/data/golden_tests.yaml @@ -0,0 +1,57 @@ +- case_id_number: 1 + title: NYC Heat Wave + start_date: 2022-06-19 12:00:00 + end_date: 2022-06-24 00:00:00 + location: + type: bounded_region + 
parameters: + latitude_min: 40.5 + latitude_max: 41.5 + longitude_min: -75 + longitude_max: -73.5 + event_type: heat_wave +- case_id_number: 2 + title: Europe Freeze + start_date: 2022-12-14 06:00:00 + end_date: 2022-12-18 00:00:00 + location: + type: bounded_region + parameters: + latitude_min: 50 + latitude_max: 55 + longitude_min: -5 + longitude_max: 5 + event_type: freeze +- case_id_number: 3 + title: April 2022 South Carolina + start_date: 2022-04-05 12:00:00 + end_date: 2022-04-06 12:00:00 + location: + type: shapefile_region + parameters: + shapefile_path: tests/data/south_carolina_110m.shp + event_type: severe_convection +- case_id_number: 4 + title: Atmospheric River Alaska + start_date: 2021-06-24 00:00:00 + end_date: 2021-06-27 00:00:00 + location: + type: bounded_region + parameters: + latitude_min: 50 + latitude_max: 55 + longitude_min: 185 + longitude_max: 200 + event_type: atmospheric_river +- case_id_number: 5 + title: Tropical Cyclone Max + start_date: 2023-10-06 00:00:00 + end_date: 2023-10-12 00:00:00 + location: + type: bounded_region + parameters: + latitude_min: 11.6 + latitude_max: 20.1 + longitude_min: 256.1 + longitude_max: 262.4 + event_type: tropical_cyclone diff --git a/tests/data/south_carolina_110m.dbf b/tests/data/south_carolina_110m.dbf new file mode 100644 index 0000000000000000000000000000000000000000..cdc66c520dd0d799c6d75f416679d514e0d98dc2 GIT binary patch literal 11881 zcmeHM&2Q6Y7!TTMhY4wiY1(DJ?Xsr%Y{+gQK-HC%4maD;dYib3SI5p`r>tn2SjvEv zRzn($4XtcwY1(dV6b9qF9MG;IoRYxBQq(bt{tVCKN4`XOB_{p-rw{5p3nEg z=__CCcw>jhb zeG?aetW$AWg;~Pu>(e<^eiNuS2sm%+FwQ|1rqFBvreHoh##B~@9?M+x>F z+#kf-pOA7IijEM!iPQS~kSYpN3Nc{)GjS9nf#1aKIfm_T8R(K;FqaR@|2KEq-r6@= zRnoQqvcDbI)F}Z3kua7Y&3{XPj9UR<v|@DDETr2HfPjKshTblh5?$o`?}Ns0dxc)B3; z5AttPrQ-@Cz8niTmk)}6c}Dyi+Th6X>x}wLG4g*nL-wz=Jy`wyx+(e0Wz`oQLCF3U z;OUYUe}Uou1fzb3GsN|IUwnUz=Y=8u06zaVp5O`o+4#PLbo^j~j_+mg$Km3G&7W<5 z>|;7U#Lyps=YAUh=rA4MvzLw!?xy3r8TlL9PuCx1W{{!vE$2U+}#{ERU8c}9QjX850A)NgEm9N+)1aWV4**yrca 
zn=oe{yFDoYAz%UXhOORQrdKk|j^$QP;+KK?oHjeDsw|0l@nKo?zfz-_(+?{0S|c(1FC z?PBV>LS^-HSEyYj^*L;Q`ysG7;`RjURs+}`_eG;lFis5qkTc?h$NAmvY{!Co_dqZd z2}h$_m&A9u1P}kF{?hT}pN-3olRGy$Zd&Z(U^wdc(dgmu)RV80P{X`a|IX7kz5X)R*d)8z&k^ z>r0K<#E7j6TsEIgQdaI|u>a+#}~ zt6YBYW#wMwcI8s#E*yA~%P5tN(ivlM-Y8u+O81P?FUH~#4hXJg^&)8h9S(YA)%o*%8!j;enb z3n@iX554d~#tdMy9x9w4a$7OcgsOj3|Ek`tu0H&RtDb|mBh|I)?dlD4h^9uw+|J+N z{X2>{&9%} literal 0 HcmV?d00001 diff --git a/tests/data/south_carolina_110m.prj b/tests/data/south_carolina_110m.prj new file mode 100644 index 00000000..f45cbadf --- /dev/null +++ b/tests/data/south_carolina_110m.prj @@ -0,0 +1 @@ +GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]] \ No newline at end of file diff --git a/tests/data/south_carolina_110m.shp b/tests/data/south_carolina_110m.shp new file mode 100644 index 0000000000000000000000000000000000000000..b08f9227eb7296ab1815dc482ff5bc286e5062d9 GIT binary patch literal 700 zcmZQzQ0HR63K-*FFf%Z)0_9>@Ieuzf3pvp6HgusKi-Uv5G?qDKi-Qjs2;c7Eo$ly> zRXeI6BLf4Y1~Fy<6*4G6_2!(v;u>R)uy@0Qr(&T_AqPB6&1EidIXG-Fn)lVfBji9p z`+KWCeg}sR%Nt7KApUddEg#e!92|J>&q?+RIUrzL$obmd!6C=YI>$dMlsgTo%SrTUc#AqPHOxUls|lY@hW z=JAVfOF|AvIPp15p5x%~XC_NbVhzx|521>DyB!>6ypJsG?G8CmAnX79)p=kzscA)j z=?^(@fR*Lt)%y+(GASlTw#z~e{MfKgSnnUueE+srx7Pyg``OQTU&PVjO~s>~$9IMt z*kSxvFTvK)VMgx+$HzxP4miw87S(cfbb!T!!$sr&c~OoID{ATuo}CUka6|X-!F5Fl z^&Q*1PX*LEI&3&__)E<`ko$JLl5TW#P`P-e_2cXic)V3~d_MIO=)VBjO&fkpb#xFY zSmE;#7`_SzmEH+YaCEqFbM?zSAU{C8Gvr8%qr(;^j-CIlf$pEjaK^e9XkXc*%AbZn z_r^x<>nnG3aA*k=Q8fbc7tU>3k>Tj@g|Fr^-|OH56?s|yap{imbhF^ZPeJE*4h}wL zIV_J32Os$H_VoeF{lM^(DfePO8+_md<9vCyDGm-VWS7725C+=UmfLLvG#))Y004;R B8lwOJ literal 0 HcmV?d00001 diff --git a/tests/data/south_carolina_110m.shx b/tests/data/south_carolina_110m.shx new file mode 100644 index 0000000000000000000000000000000000000000..5070f67fc67363f8ae6897e29a746c2650175ea0 GIT binary patch literal 108 zcmZQzQ0HR64$NLKGcd3MR literal 0 HcmV?d00001 
diff --git a/tests/test_golden.py b/tests/test_golden.py new file mode 100644 index 00000000..8476f0ba --- /dev/null +++ b/tests/test_golden.py @@ -0,0 +1,282 @@ +"""Tests which use the full end-to-end EWB workflow. + +These tests are likely incompatible with Github Actions and will be used on a VM +or other virtual environment. These are intended to be fairly lightweight marquee +examples of each event type and core metrics. If the values deviate from expected +for a release, it will be flagged as a failure.""" + + +# Load case data from the default events.yaml + +import pathlib + +import pytest + +from extremeweatherbench import cases, defaults, derived, evaluate, inputs, metrics + + +@pytest.fixture(scope="module") +def reference_data_dir(): + """Path to reference data directory.""" + path = pathlib.Path(__file__).parent / "data" + if not path.exists(): + pytest.skip( + "Reference data not found. Run 'uv run data/generate_cape_reference_data.py' first." + ) + return path + + +@pytest.fixture(scope="module") +def golden_tests_event_data(reference_data_dir): + """Load golden tests event data.""" + ref_file = reference_data_dir / "golden_tests.yaml" + if not ref_file.exists(): + pytest.skip(f"Golden tests event data not found: {ref_file}") + + return cases.load_individual_cases_from_yaml(ref_file) + + +@pytest.mark.integration +class TestGoldenTests: + """Golden tests.""" + + def test_heatwaves(self, golden_tests_event_data): + """Heatwave tests.""" + # Define heatwave objects + era5_heatwave_target = inputs.ERA5() + ghcn_heatwave_target = inputs.GHCN() + + heatwave_metrics = [ + metrics.MaximumMeanAbsoluteError( + forecast_variable="surface_air_temperature", + target_variable="surface_air_temperature", + ), + metrics.RootMeanSquaredError( + forecast_variable="surface_air_temperature", + target_variable="surface_air_temperature", + ), + metrics.DurationMeanError( + threshold_criteria=defaults.get_climatology(quantile=0.85) + ), + 
metrics.MaximumLowestMeanAbsoluteError( + forecast_variable="surface_air_temperature", + target_variable="surface_air_temperature", + ), + ] + + hres_heatwave_forecast = inputs.ZarrForecast( + name="hres_heatwave_forecast", + source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", + variables=["surface_air_temperature"], + variable_mapping=inputs.HRES_metadata_variable_mapping, + ) + + heatwave_evaluation_objects = [ + inputs.EvaluationObject( + event_type="heat_wave", + metric_list=heatwave_metrics, + target=era5_heatwave_target, + forecast=hres_heatwave_forecast, + ), + inputs.EvaluationObject( + event_type="heat_wave", + metric_list=heatwave_metrics, + target=ghcn_heatwave_target, + forecast=hres_heatwave_forecast, + ), + ] + # Run the evaluation + ewb = evaluate.ExtremeWeatherBench( + case_metadata=golden_tests_event_data, + evaluation_objects=heatwave_evaluation_objects, + ) + + ewb.run( + parallel_config={ + "backend": "loky", + "n_jobs": len(heatwave_evaluation_objects) * len(heatwave_metrics), + }, + ) + + def test_freezes(self, golden_tests_event_data): + """Freeze tests.""" + era5_freeze_target = inputs.ERA5() + ghcn_freeze_target = inputs.GHCN() + hres_freeze_forecast = inputs.ZarrForecast( + name="hres_freeze_forecast", + source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", + variables=["surface_air_temperature"], + variable_mapping=inputs.HRES_metadata_variable_mapping, + ) + freeze_metrics = [ + metrics.MinimumMeanAbsoluteError( + forecast_variable="surface_air_temperature", + target_variable="surface_air_temperature", + ), + metrics.RootMeanSquaredError( + forecast_variable="surface_air_temperature", + target_variable="surface_air_temperature", + ), + metrics.DurationMeanError( + threshold_criteria=defaults.get_climatology(quantile=0.15) + ), + ] + freeze_evaluation_objects = [ + inputs.EvaluationObject( + event_type="freeze", + metric_list=freeze_metrics, + target=era5_freeze_target, + 
forecast=hres_freeze_forecast, + ), + inputs.EvaluationObject( + event_type="freeze", + metric_list=freeze_metrics, + target=ghcn_freeze_target, + forecast=hres_freeze_forecast, + ), + ] + # Run the evaluation + ewb = evaluate.ExtremeWeatherBench( + case_metadata=golden_tests_event_data, + evaluation_objects=freeze_evaluation_objects, + ) + ewb.run( + parallel_config={ + "backend": "loky", + "n_jobs": len(freeze_evaluation_objects) * len(freeze_metrics), + }, + ) + + def test_severe_convection(self, golden_tests_event_data): + """Severe convection tests.""" + lsr_severe_convection_target = inputs.LSR() + pph_severe_convection_target = inputs.PPH() + hres_severe_convection_forecast = inputs.ZarrForecast( + name="hres_severe_convection_forecast", + source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", + variables=[derived.CravenBrooksSignificantSevere()], + variable_mapping=inputs.HRES_metadata_variable_mapping, + ) + severe_convection_metrics = [ + metrics.ThresholdMetric( + metrics=[metrics.CriticalSuccessIndex, metrics.FalseAlarmRatio], + forecast_threshold=15000, + target_threshold=0.3, + ), + metrics.EarlySignal(threshold=15000), + ] + severe_convection_evaluation_objects = [ + inputs.EvaluationObject( + event_type="severe_convection", + metric_list=severe_convection_metrics, + target=lsr_severe_convection_target, + forecast=hres_severe_convection_forecast, + ), + inputs.EvaluationObject( + event_type="severe_convection", + metric_list=severe_convection_metrics, + target=pph_severe_convection_target, + forecast=hres_severe_convection_forecast, + ), + ] + # Run the evaluation + ewb = evaluate.ExtremeWeatherBench( + case_metadata=golden_tests_event_data, + evaluation_objects=severe_convection_evaluation_objects, + ) + ewb.run( + parallel_config={ + "backend": "loky", + "n_jobs": len(severe_convection_evaluation_objects) + * len(severe_convection_metrics), + }, + ) + + def test_atmospheric_river(self, golden_tests_event_data): + """Atmospheric 
river tests.""" + era5_atmospheric_river_target = inputs.ERA5() + hres_atmospheric_river_forecast = inputs.ZarrForecast( + name="hres_atmospheric_river_forecast", + source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", + variables=[ + derived.AtmosphericRiverVariables( + output_variables=[ + "atmospheric_river_mask", + "integrated_vapor_transport", + "atmospheric_river_land_intersection", + ] + ) + ], + variable_mapping=inputs.HRES_metadata_variable_mapping, + ) + atmospheric_river_metrics = [ + metrics.CriticalSuccessIndex(), + metrics.EarlySignal(), + metrics.SpatialDisplacement(), + ] + atmospheric_river_evaluation_objects = [ + inputs.EvaluationObject( + event_type="atmospheric_river", + metric_list=atmospheric_river_metrics, + target=era5_atmospheric_river_target, + forecast=hres_atmospheric_river_forecast, + ), + ] + # Run the evaluation + ewb = evaluate.ExtremeWeatherBench( + case_metadata=golden_tests_event_data, + evaluation_objects=atmospheric_river_evaluation_objects, + ) + ewb.run( + parallel_config={ + "backend": "loky", + "n_jobs": len(atmospheric_river_evaluation_objects) + * len(atmospheric_river_metrics), + }, + ) + + def test_tropical_cyclone(self, golden_tests_event_data): + """Tropical cyclone tests.""" + ibtracs_tropical_cyclone_target = inputs.IBTrACS() + hres_tropical_cyclone_forecast = inputs.ZarrForecast( + name="hres_tropical_cyclone_forecast", + source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr", + variables=[derived.TropicalCycloneTrackVariables()], + variable_mapping=inputs.HRES_metadata_variable_mapping, + ) + tropical_cyclone_metrics = [ + metrics.LandfallMetric( + metrics=[ + metrics.LandfallIntensityMeanAbsoluteError, + metrics.LandfallTimeMeanError, + metrics.LandfallDisplacement, + ], + approach="next", + forecast_variable="air_pressure_at_mean_sea_level", + target_variable="air_pressure_at_mean_sea_level", + ), + ] + tropical_cyclone_evaluation_objects = [ + inputs.EvaluationObject( + 
event_type="tropical_cyclone", + metric_list=tropical_cyclone_metrics, + target=ibtracs_tropical_cyclone_target, + forecast=hres_tropical_cyclone_forecast, + ), + ] + # Run the evaluation + ewb = evaluate.ExtremeWeatherBench( + case_metadata=golden_tests_event_data, + evaluation_objects=tropical_cyclone_evaluation_objects, + ) + ewb.run( + parallel_config={ + "backend": "loky", + "n_jobs": len(tropical_cyclone_evaluation_objects) + * len(tropical_cyclone_metrics), + }, + ) + + +if __name__ == "__main__": + pytest.main([__file__]) From ab1fe35d2ba6153667db29fef1207c1128dd5549 Mon Sep 17 00:00:00 2001 From: aaTman Date: Mon, 26 Jan 2026 23:02:00 +0000 Subject: [PATCH 15/15] swap pyproject tools to hatch; add if and packages-dir to publish --- .github/workflows/publish.yaml | 2 ++ pyproject.toml | 10 ++++------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 04dfe2e9..c4167a05 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -61,6 +61,7 @@ jobs: publish-to-testpypi: name: Publish release distribution to TestPyPI runs-on: ubuntu-latest + if: startsWith(github.ref, 'refs/tags/') needs: - release-build @@ -82,3 +83,4 @@ jobs: uses: pypa/gh-action-pypi-publish@release/v1 with: repository-url: https://test.pypi.org/legacy/ + packages-dir: dist/ diff --git a/pyproject.toml b/pyproject.toml index be0ee124..d7dea885 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,13 +98,11 @@ complete = ["extremeweatherbench[data-prep,multiprocessing]"] requires = ["hatchling >= 1.26"] build-backend = "hatchling.build" -[tool.setuptools] -packages = ["extremeweatherbench"] -package-dir = { "" = "src" } -include-package-data = true +[tool.hatch.build.targets.wheel] +packages = ["src/extremeweatherbench"] -[tool.setuptools.package-data] -extremeweatherbench = ["data/**/*", "data/**/.*"] +[tool.hatch.build.targets.sdist] +include = ["src/extremeweatherbench/**/*"] 
[project.urls] Documentation = "https://extremeweatherbench.readthedocs.io/"