diff --git a/docs/examples/applied_ar.py b/docs/examples/applied_ar.py index d07a3edf..bc607f4d 100644 --- a/docs/examples/applied_ar.py +++ b/docs/examples/applied_ar.py @@ -14,7 +14,7 @@ # Preprocess function for CIRA data using Brightband kerchunk parquets -def _preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: """An example preprocess function that renames the time coordinate to lead_time, creates a valid_time coordinate, and sets the lead time range and resolution not present in the original dataset. @@ -75,7 +75,7 @@ def _preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: ], variable_mapping=inputs.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_bb_cira_forecast_dataset, + preprocess=_preprocess_cira_forecast_dataset, ) pang_forecast = inputs.KerchunkForecast( @@ -88,7 +88,7 @@ def _preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: ], variable_mapping=inputs.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_bb_cira_forecast_dataset, + preprocess=_preprocess_cira_forecast_dataset, ) # Create a list of evaluation objects for atmospheric river ar_evaluation_objects = [ diff --git a/docs/examples/applied_freeze.py b/docs/examples/applied_freeze.py index 70cf00d9..6505081c 100644 --- a/docs/examples/applied_freeze.py +++ b/docs/examples/applied_freeze.py @@ -1,7 +1,7 @@ import logging import operator -from extremeweatherbench import cases, evaluate, inputs, metrics, defaults +from extremeweatherbench import cases, defaults, evaluate, inputs, metrics # Set the logger level to INFO logger = logging.getLogger("extremeweatherbench") @@ -28,7 +28,7 @@ variables=["surface_air_temperature"], variable_mapping=inputs.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=defaults._preprocess_bb_cira_forecast_dataset, + preprocess=defaults._preprocess_cira_forecast_dataset, ) # Load the climatology for DurationMeanError diff --git a/docs/examples/applied_tc.py b/docs/examples/applied_tc.py index 743c2163..b5365218 100644 --- a/docs/examples/applied_tc.py +++ b/docs/examples/applied_tc.py @@ -1,58 +1,12 @@ import logging -import numpy as np -import xarray as xr - -from extremeweatherbench import calc, cases, derived, evaluate, inputs, metrics +from extremeweatherbench import cases, defaults, derived, evaluate, inputs, metrics # Set the logger level to INFO logger = logging.getLogger("extremeweatherbench") logger.setLevel(logging.INFO) -# Preprocessing function for CIRA data that includes geopotential thickness calculation -# required for tropical cyclone tracks -def _preprocess_bb_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: - """An example preprocess function that renames the time coordinate to lead_time, - creates a valid_time coordinate, and sets the lead time range and resolution not - present in the original dataset. - - Args: - ds: The forecast dataset to rename. - - Returns: - The renamed forecast dataset. - """ - ds = ds.rename({"time": "lead_time"}) - # The evaluation configuration is used to set the lead time range and resolution. - ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") - ds["geopotential_thickness"] = calc.geopotential_thickness( - ds["z"], top_level=300, bottom_level=500 - ) - return ds - - -# Preprocessing function for HRES data that includes geopotential thickness calculation -# required for tropical cyclone tracks -def _preprocess_hres_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: - """An example preprocess function that renames the time coordinate to lead_time, - creates a valid_time coordinate, and sets the lead time range and resolution not - present in the original dataset. - - Args: - ds: The forecast dataset to rename. - """ - ds["geopotential_thickness"] = calc.geopotential_thickness( - ds["geopotential"], - top_level=300, - bottom_level=500, - geopotential=True, - ) - return ds - - # Load the case collection from the YAML file case_yaml = cases.load_ewb_events_yaml_into_case_collection() @@ -72,7 +26,7 @@ def _preprocess_hres_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: variable_mapping=inputs.HRES_metadata_variable_mapping, storage_options={"remote_options": {"anon": True}}, # Preprocess the HRES forecast to include geopotential thickness calculation - preprocess=_preprocess_hres_forecast_dataset, + preprocess=defaults._preprocess_hres_tc_forecast_dataset, ) # Define FCNv2 forecast @@ -83,7 +37,7 @@ def _preprocess_hres_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: # Define metadata variable mapping for FCNv2 forecast variable_mapping=inputs.CIRA_metadata_variable_mapping, # Preprocess the FCNv2 forecast to include geopotential thickness calculation - preprocess=_preprocess_bb_cira_tc_forecast_dataset, + preprocess=defaults._preprocess_cira_tc_forecast_dataset, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, ) @@ -96,7 +50,7 @@ def _preprocess_hres_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: variable_mapping=inputs.CIRA_metadata_variable_mapping, # Preprocess the Pangu forecast to include geopotential thickness calculation # which uses the same preprocessing function as the FCNv2 forecast - preprocess=_preprocess_bb_cira_tc_forecast_dataset, + preprocess=defaults._preprocess_cira_tc_forecast_dataset, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, ) diff --git a/src/extremeweatherbench/defaults.py b/src/extremeweatherbench/defaults.py index 461e55b4..a78cdcca 100644 --- a/src/extremeweatherbench/defaults.py +++ b/src/extremeweatherbench/defaults.py @@ -58,7 +58,7 @@ ] -def _preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: """A preprocess function for CIRA data that renames the time coordinate to lead_time, creates a valid_time coordinate, and sets the lead time range and resolution not present in the original dataset. @@ -79,7 +79,7 @@ def _preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: # Preprocessing function for CIRA data that includes geopotential thickness calculation # required for tropical cyclone tracks -def _preprocess_bb_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: """A preprocess function for CIRA data that includes geopotential thickness calculation required for tropical cyclone tracks. @@ -101,15 +101,15 @@ def _preprocess_bb_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: ).astype("timedelta64[ns]") # Calculate the geopotential thickness required for tropical cyclone tracks - ds["geopotential_thickness"] = calc.geopotential_thickness( - ds["z"], top_level=300, bottom_level=500 + ds["geopotential_thickness"] = ( + calc.geopotential_thickness(ds["z"], top_level=300, bottom_level=500) / 9.81 ) return ds # Preprocessing function for HRES data that includes geopotential thickness calculation # required for tropical cyclone tracks -def _preprocess_bb_hres_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_hres_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: """A preprocess function for CIRA data that includes geopotential thickness calculation required for tropical cyclone tracks. @@ -125,14 +125,15 @@ def _preprocess_bb_hres_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: """ # Calculate the geopotential thickness required for tropical cyclone tracks - ds["geopotential_thickness"] = calc.geopotential_thickness( - ds["geopotential"], top_level=300, bottom_level=500 + ds["geopotential_thickness"] = ( + calc.geopotential_thickness(ds["geopotential"], top_level=300, bottom_level=500) + / 9.81 ) return ds # Preprocess function for CIRA data using Brightband kerchunk parquets -def _preprocess_bb_ar_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_ar_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: """An example preprocess function that renames the time coordinate to lead_time, creates a valid_time coordinate, and sets the lead time range and resolution not present in the original dataset. @@ -160,7 +161,7 @@ def _preprocess_bb_ar_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: # Preprocess function for CIRA data using Brightband kerchunk parquets -def _preprocess_bb_severe_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_severe_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: """An example preprocess function that renames the time coordinate to lead_time, creates a valid_time coordinate, and sets the lead time range and resolution not present in the original dataset. @@ -248,7 +249,7 @@ def get_climatology(quantile: float = 0.85) -> xr.DataArray: variables=["surface_air_temperature"], variable_mapping=inputs.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_bb_cira_forecast_dataset, + preprocess=_preprocess_cira_forecast_dataset, ) cira_freeze_forecast = inputs.KerchunkForecast( @@ -257,7 +258,7 @@ def get_climatology(quantile: float = 0.85) -> xr.DataArray: variables=["surface_air_temperature"], variable_mapping=inputs.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_bb_cira_forecast_dataset, + preprocess=_preprocess_cira_forecast_dataset, ) cira_tropical_cyclone_forecast = inputs.KerchunkForecast( @@ -266,7 +267,7 @@ def get_climatology(quantile: float = 0.85) -> xr.DataArray: variables=[derived.TropicalCycloneTrackVariables()], variable_mapping=inputs.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_bb_cira_tc_forecast_dataset, + preprocess=_preprocess_cira_tc_forecast_dataset, ) cira_atmospheric_river_forecast = inputs.KerchunkForecast( name="FourCastNetv2", @@ -278,7 +279,7 @@ def get_climatology(quantile: float = 0.85) -> xr.DataArray: ], variable_mapping=inputs.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_bb_ar_cira_forecast_dataset, + preprocess=_preprocess_ar_cira_forecast_dataset, ) cira_severe_convection_forecast = inputs.KerchunkForecast( @@ -287,7 +288,7 @@ def get_climatology(quantile: float = 0.85) -> xr.DataArray: variables=[derived.CravenBrooksSignificantSevere()], variable_mapping=inputs.CIRA_metadata_variable_mapping, storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_bb_severe_cira_forecast_dataset, + preprocess=_preprocess_severe_cira_forecast_dataset, ) diff --git a/tests/test_defaults.py b/tests/test_defaults.py index 092c7ced..cbbea8bd 100644 --- a/tests/test_defaults.py +++ b/tests/test_defaults.py @@ -10,8 +10,8 @@ class TestDefaults: """Test the defaults module.""" - def test_preprocess_bb_cira_forecast_dataset(self): - """Test the _preprocess_bb_cira_forecast_dataset function.""" + def test_preprocess_cira_forecast_dataset(self): + """Test the _preprocess_cira_forecast_dataset function.""" # Create a mock dataset with 'time' coordinate matching expected output size # The function creates lead_time with 41 values (0 to 240 by 6) @@ -21,7 +21,7 @@ def test_preprocess_bb_cira_forecast_dataset(self): {"temperature": (["time"], temp_data)}, coords={"time": time_data} ) - result = defaults._preprocess_bb_cira_forecast_dataset(mock_ds) + result = defaults._preprocess_cira_forecast_dataset(mock_ds) # Check that 'time' was renamed to 'lead_time' assert "lead_time" in result.coords @@ -157,11 +157,11 @@ def test_cira_forecasts_have_preprocess_function(self): # Test that the preprocess function is the expected one assert ( defaults.cira_heatwave_forecast.preprocess - == defaults._preprocess_bb_cira_forecast_dataset + == defaults._preprocess_cira_forecast_dataset ) assert ( defaults.cira_freeze_forecast.preprocess - == defaults._preprocess_bb_cira_forecast_dataset + == defaults._preprocess_cira_forecast_dataset ) def test_get_brightband_evaluation_objects_no_exceptions(self):