Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/examples/applied_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


# Preprocess function for CIRA data using Brightband kerchunk parquets
def _preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
def _preprocess_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
"""An example preprocess function that renames the time coordinate to lead_time,
creates a valid_time coordinate, and sets the lead time range and resolution not
present in the original dataset.
Expand Down Expand Up @@ -75,7 +75,7 @@ def _preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
],
variable_mapping=inputs.CIRA_metadata_variable_mapping,
storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}},
preprocess=_preprocess_bb_cira_forecast_dataset,
preprocess=_preprocess_cira_forecast_dataset,
)

pang_forecast = inputs.KerchunkForecast(
Expand All @@ -88,7 +88,7 @@ def _preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
],
variable_mapping=inputs.CIRA_metadata_variable_mapping,
storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}},
preprocess=_preprocess_bb_cira_forecast_dataset,
preprocess=_preprocess_cira_forecast_dataset,
)
# Create a list of evaluation objects for atmospheric river
ar_evaluation_objects = [
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/applied_freeze.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import operator

from extremeweatherbench import cases, evaluate, inputs, metrics, defaults
from extremeweatherbench import cases, defaults, evaluate, inputs, metrics

# Set the logger level to INFO
logger = logging.getLogger("extremeweatherbench")
Expand All @@ -28,7 +28,7 @@
variables=["surface_air_temperature"],
variable_mapping=inputs.CIRA_metadata_variable_mapping,
storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}},
preprocess=defaults._preprocess_bb_cira_forecast_dataset,
preprocess=defaults._preprocess_cira_forecast_dataset,
)

# Load the climatology for DurationMeanError
Expand Down
54 changes: 4 additions & 50 deletions docs/examples/applied_tc.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,12 @@
import logging

import numpy as np
import xarray as xr

from extremeweatherbench import calc, cases, derived, evaluate, inputs, metrics
from extremeweatherbench import cases, defaults, derived, evaluate, inputs, metrics

# Set the logger level to INFO
logger = logging.getLogger("extremeweatherbench")
logger.setLevel(logging.INFO)


# Preprocessing function for CIRA data that includes geopotential thickness calculation
# required for tropical cyclone tracks
def _preprocess_bb_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
    """Prepare a CIRA forecast dataset for tropical cyclone track evaluation.

    Renames the ``time`` coordinate to ``lead_time``, overwrites it with
    6-hourly lead times from 0 to 240 hours (41 steps), and adds a
    ``geopotential_thickness`` variable computed from ``z`` between the 300
    and 500 levels.

    Args:
        ds: The forecast dataset to preprocess. It must provide a ``time``
            coordinate and a ``z`` variable.

    Returns:
        The renamed dataset with the ``lead_time`` coordinate replaced and
        ``geopotential_thickness`` added.
    """
    ds = ds.rename({"time": "lead_time"})

    # The original dataset does not carry the lead time range and resolution,
    # so they are set here: 0 h to 240 h inclusive, every 6 h, stored as
    # nanosecond timedeltas.
    # NOTE(review): assumes the renamed dimension already has 41 entries so
    # the assignment lines up -- confirm against the source data.
    lead_hours = np.arange(0, 241, 6)
    ds["lead_time"] = lead_hours.astype("timedelta64[h]").astype("timedelta64[ns]")

    # Geopotential thickness is required for tropical cyclone tracks.
    ds["geopotential_thickness"] = calc.geopotential_thickness(
        ds["z"], top_level=300, bottom_level=500
    )
    return ds


# Preprocessing function for HRES data that includes geopotential thickness calculation
# required for tropical cyclone tracks
def _preprocess_hres_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
    """Add the geopotential thickness needed for tropical cyclone tracks.

    Computes the thickness from the ``geopotential`` variable between the 300
    and 500 levels and stores it on the dataset as ``geopotential_thickness``.

    Args:
        ds: The HRES forecast dataset to preprocess. It must provide a
            ``geopotential`` variable.

    Returns:
        The same dataset object, with ``geopotential_thickness`` assigned.
    """
    # geopotential=True tells the helper that the input is geopotential
    # (presumably rather than geopotential height -- confirm in calc).
    thickness = calc.geopotential_thickness(
        ds["geopotential"], top_level=300, bottom_level=500, geopotential=True
    )
    ds["geopotential_thickness"] = thickness
    return ds


# Load the case collection from the YAML file
case_yaml = cases.load_ewb_events_yaml_into_case_collection()

Expand All @@ -72,7 +26,7 @@ def _preprocess_hres_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
variable_mapping=inputs.HRES_metadata_variable_mapping,
storage_options={"remote_options": {"anon": True}},
# Preprocess the HRES forecast to include geopotential thickness calculation
preprocess=_preprocess_hres_forecast_dataset,
preprocess=defaults._preprocess_hres_tc_forecast_dataset,
)

# Define FCNv2 forecast
Expand All @@ -83,7 +37,7 @@ def _preprocess_hres_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
# Define metadata variable mapping for FCNv2 forecast
variable_mapping=inputs.CIRA_metadata_variable_mapping,
# Preprocess the FCNv2 forecast to include geopotential thickness calculation
preprocess=_preprocess_bb_cira_tc_forecast_dataset,
preprocess=defaults._preprocess_cira_tc_forecast_dataset,
storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}},
)

Expand All @@ -96,7 +50,7 @@ def _preprocess_hres_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
variable_mapping=inputs.CIRA_metadata_variable_mapping,
# Preprocess the Pangu forecast to include geopotential thickness calculation
# which uses the same preprocessing function as the FCNv2 forecast
preprocess=_preprocess_bb_cira_tc_forecast_dataset,
preprocess=defaults._preprocess_cira_tc_forecast_dataset,
storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}},
)

Expand Down
29 changes: 15 additions & 14 deletions src/extremeweatherbench/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
]


def _preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
def _preprocess_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
"""A preprocess function for CIRA data that renames the time coordinate to
lead_time, creates a valid_time coordinate, and sets the lead time range and
resolution not present in the original dataset.
Expand All @@ -79,7 +79,7 @@ def _preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:

# Preprocessing function for CIRA data that includes geopotential thickness calculation
# required for tropical cyclone tracks
def _preprocess_bb_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
def _preprocess_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
"""A preprocess function for CIRA data that includes geopotential thickness
calculation required for tropical cyclone tracks.

Expand All @@ -101,15 +101,15 @@ def _preprocess_bb_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
).astype("timedelta64[ns]")

# Calculate the geopotential thickness required for tropical cyclone tracks
ds["geopotential_thickness"] = calc.geopotential_thickness(
ds["z"], top_level=300, bottom_level=500
ds["geopotential_thickness"] = (
calc.geopotential_thickness(ds["z"], top_level=300, bottom_level=500) / 9.81
)
return ds


# Preprocessing function for HRES data that includes geopotential thickness calculation
# required for tropical cyclone tracks
def _preprocess_bb_hres_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
def _preprocess_hres_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
"""A preprocess function for HRES data that includes geopotential thickness
calculation required for tropical cyclone tracks.

Expand All @@ -125,14 +125,15 @@ def _preprocess_bb_hres_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
"""

# Calculate the geopotential thickness required for tropical cyclone tracks
ds["geopotential_thickness"] = calc.geopotential_thickness(
ds["geopotential"], top_level=300, bottom_level=500
ds["geopotential_thickness"] = (
calc.geopotential_thickness(ds["geopotential"], top_level=300, bottom_level=500)
/ 9.81
)
return ds


# Preprocess function for CIRA data using Brightband kerchunk parquets
def _preprocess_bb_ar_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
def _preprocess_ar_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
"""An example preprocess function that renames the time coordinate to lead_time,
creates a valid_time coordinate, and sets the lead time range and resolution not
present in the original dataset.
Expand Down Expand Up @@ -160,7 +161,7 @@ def _preprocess_bb_ar_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:


# Preprocess function for CIRA data using Brightband kerchunk parquets
def _preprocess_bb_severe_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
def _preprocess_severe_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
"""An example preprocess function that renames the time coordinate to lead_time,
creates a valid_time coordinate, and sets the lead time range and resolution not
present in the original dataset.
Expand Down Expand Up @@ -248,7 +249,7 @@ def get_climatology(quantile: float = 0.85) -> xr.DataArray:
variables=["surface_air_temperature"],
variable_mapping=inputs.CIRA_metadata_variable_mapping,
storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}},
preprocess=_preprocess_bb_cira_forecast_dataset,
preprocess=_preprocess_cira_forecast_dataset,
)

cira_freeze_forecast = inputs.KerchunkForecast(
Expand All @@ -257,7 +258,7 @@ def get_climatology(quantile: float = 0.85) -> xr.DataArray:
variables=["surface_air_temperature"],
variable_mapping=inputs.CIRA_metadata_variable_mapping,
storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}},
preprocess=_preprocess_bb_cira_forecast_dataset,
preprocess=_preprocess_cira_forecast_dataset,
)

cira_tropical_cyclone_forecast = inputs.KerchunkForecast(
Expand All @@ -266,7 +267,7 @@ def get_climatology(quantile: float = 0.85) -> xr.DataArray:
variables=[derived.TropicalCycloneTrackVariables()],
variable_mapping=inputs.CIRA_metadata_variable_mapping,
storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}},
preprocess=_preprocess_bb_cira_tc_forecast_dataset,
preprocess=_preprocess_cira_tc_forecast_dataset,
)
cira_atmospheric_river_forecast = inputs.KerchunkForecast(
name="FourCastNetv2",
Expand All @@ -278,7 +279,7 @@ def get_climatology(quantile: float = 0.85) -> xr.DataArray:
],
variable_mapping=inputs.CIRA_metadata_variable_mapping,
storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}},
preprocess=_preprocess_bb_ar_cira_forecast_dataset,
preprocess=_preprocess_ar_cira_forecast_dataset,
)

cira_severe_convection_forecast = inputs.KerchunkForecast(
Expand All @@ -287,7 +288,7 @@ def get_climatology(quantile: float = 0.85) -> xr.DataArray:
variables=[derived.CravenBrooksSignificantSevere()],
variable_mapping=inputs.CIRA_metadata_variable_mapping,
storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}},
preprocess=_preprocess_bb_severe_cira_forecast_dataset,
preprocess=_preprocess_severe_cira_forecast_dataset,
)


Expand Down
10 changes: 5 additions & 5 deletions tests/test_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
class TestDefaults:
"""Test the defaults module."""

def test_preprocess_bb_cira_forecast_dataset(self):
"""Test the _preprocess_bb_cira_forecast_dataset function."""
def test_preprocess_cira_forecast_dataset(self):
"""Test the _preprocess_cira_forecast_dataset function."""

# Create a mock dataset with 'time' coordinate matching expected output size
# The function creates lead_time with 41 values (0 to 240 by 6)
Expand All @@ -21,7 +21,7 @@ def test_preprocess_bb_cira_forecast_dataset(self):
{"temperature": (["time"], temp_data)}, coords={"time": time_data}
)

result = defaults._preprocess_bb_cira_forecast_dataset(mock_ds)
result = defaults._preprocess_cira_forecast_dataset(mock_ds)

# Check that 'time' was renamed to 'lead_time'
assert "lead_time" in result.coords
Expand Down Expand Up @@ -157,11 +157,11 @@ def test_cira_forecasts_have_preprocess_function(self):
# Test that the preprocess function is the expected one
assert (
defaults.cira_heatwave_forecast.preprocess
== defaults._preprocess_bb_cira_forecast_dataset
== defaults._preprocess_cira_forecast_dataset
)
assert (
defaults.cira_freeze_forecast.preprocess
== defaults._preprocess_bb_cira_forecast_dataset
== defaults._preprocess_cira_forecast_dataset
)

def test_get_brightband_evaluation_objects_no_exceptions(self):
Expand Down
Loading