diff --git a/README.md b/README.md index f0249f50..d90aff16 100644 --- a/README.md +++ b/README.md @@ -67,48 +67,11 @@ $ ewb --default ```python from extremeweatherbench import cases, inputs, metrics, evaluate, utils -# Select model -model = 'FOUR_v200_GFS' - -# Set up path to directory of file - zarr or kerchunk/virtualizarr json/parquet -forecast_dir = f'gs://extremeweatherbench/{model}.parq' - -# Preprocessing function exclusive to handling the CIRA parquets -def preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: - """Preprocess CIRA kerchunk (parquet) data in the ExtremeWeatherBench bucket. - A preprocess function that renames the time coordinate to lead_time, - creates a valid_time coordinate, and sets the lead time range and resolution not - present in the original dataset. - Args: - ds: The forecast dataset to rename. - Returns: - The renamed forecast dataset. - """ - ds = ds.rename({"time": "lead_time"}) - - # The evaluation configuration is used to set the lead time range and resolution. 
- ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") - - return ds - -# Define a forecast object; in this case, a KerchunkForecast -fcnv2_forecast = inputs.KerchunkForecast( - name="fcnv2_forecast", # identifier for this forecast in results - source=forecast_dir, # source path - variables=["surface_air_temperature"], # variables to use in the evaluation - variable_mapping=inputs.CIRA_metadata_variable_mapping, # mapping to use for variables in forecast dataset to EWB variable names - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, # storage options for access - preprocess=preprocess_bb_cira_forecast_dataset # required preprocessing function for CIRA references -) +# Load in a forecast; here, we load in GFS initialized FCNv2 from the CIRA MLWP archive with a default variable built-in for convenience +fcnv2_heatwave_forecast = defaults.cira_fcnv2_heatwave_forecast -# Load in ERA5; source defaults to the ARCO ERA5 dataset from Google and variable mapping is provided by default as well -era5_heatwave_target = inputs.ERA5( - variables=["surface_air_temperature"], # variable to use in the evaluation - storage_options={"remote_options": {"anon": True}}, # storage options for access - chunks=None, # define chunks for the ERA5 data -) +# Load in ERA5 with another default convenience variable +era5_heatwave_target = defaults.era5_heatwave_target # EvaluationObjects are used to evaluate a single forecast source against a single target source with a defined event type. Event types are declared with each case. One or more metrics can be evaluated with each EvaluationObject. 
heatwave_evaluation_list = [ @@ -120,7 +83,7 @@ heatwave_evaluation_list = [ metrics.MaximumLowestMeanAbsoluteError(), ], target=era5_heatwave_target, - forecast=fcnv2_forecast, + forecast=fcnv2_heatwave_forecast, ), ] # Load in the EWB default list of event cases @@ -134,7 +97,7 @@ ewb_instance = evaluate.ExtremeWeatherBench( # Execute a parallel run and return the evaluation results as a pandas DataFrame heatwave_outputs = ewb_instance.run( - parallel_config={'backend':'loky','n_jobs':16} # Uses 16 jobs with the loky backend + parallel_config={'n_jobs':16} # Uses 16 jobs with the loky backend as default ) # Save the results diff --git a/docs/recipes/cira_forecast.md b/docs/recipes/cira_forecast.md index 43a7a6ac..9cf57b90 100644 --- a/docs/recipes/cira_forecast.md +++ b/docs/recipes/cira_forecast.md @@ -2,22 +2,10 @@ We have a dedicated virtual reference icechunk store for CIRA data **up to May 26th, 2025** available at `gs://extremeweatherbench/cira-icechunk`. Compared to using parquet virtual references, we have seen a speed improvements of around 2x with ~25% more memory usage. -## Loading the store - -```python - -from extremeweatherbench import cases, inputs, metrics, evaluate, defaults -import datetime -import icechunk - -storage = icechunk.gcs_storage( - bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True -) -``` - ## Accessing a CIRA Model from the store ```python +from extremeweatherbench import inputs group_list = inputs.list_groups_in_icechunk_datatree(storage) ``` @@ -39,22 +27,33 @@ group_list = inputs.list_groups_in_icechunk_datatree(storage) ```python -# Find FCNv2's name in the group list -fcnv2_group = [n for n in group_list if 'FOUR_v200_GFS' in n][0] - # Helper function to access the virtual dataset -fcnv2 = inputs.open_icechunk_dataset_from_datatree( +fcnv2 = inputs.get_cira_icechunk(model_name='FOUR_v200_IFS') +``` + +`fcnv2` is a `ForecastBase` object ready to be used within EWB's evaluation framework. 
+ +> **Detailed Explanation**: `inputs.get_cira_icechunk` is syntactic sugar for this: +```python +import icechunk + +storage = icechunk.gcs_storage( + bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True +) + +fcnv2_icechunk_ds = inputs.open_icechunk_dataset_from_datatree( storage=storage, - group=fcnv2_group, + group="FOUR_v200_IFS", authorize_virtual_chunk_access=inputs.CIRA_CREDENTIALS ) -fcnv2_icechunk_forecast_object = inputs.XarrayForecast( + +fcnv2 = inputs.XarrayForecast( ds=fcnv2, variable_mapping=inputs.CIRA_metadata_variable_mapping ) ``` -`fcnv2_icechunk_forecast_object` is a `ForecastBase` object ready to be used within EWB's evaluation framework. +Which is a three step process of accessing the icechunk storage, loading the dataset from the datatree/zarr group format, and finally applying that `Dataset` in a `ForecastBase` object. ## Set up metrics and target for evaluation diff --git a/docs/usage.md b/docs/usage.md index a5e67b8c..89b10011 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -37,7 +37,7 @@ To run an evaluation, there are three components required: a forecast, a target, ```python from extremeweatherbench import inputs ``` -There are two built-in `ForecastBase` classes to set up a forecast: `ZarrForecast` and `KerchunkForecast`. Here is an example of a `ZarrForecast`, using Weatherbench2's HRES zarr store: +There are three built-in `ForecastBase` classes to set up a forecast: `ZarrForecast`, `XarrayForecast`, and `KerchunkForecast`. Here is an example of a `ZarrForecast`, using Weatherbench2's HRES zarr store: ```python hres_forecast = inputs.ZarrForecast( @@ -56,9 +56,9 @@ There are required arguments, namely: - `variables`* - `variable_mapping` -* `variables` can be defined within one or more metrics instead of in a `ForecastBase` object. +* `variables` can alternatively be defined within one or more metrics, instead of in a `ForecastBase` object. 
-A forecast needs a `source`, which is a link to the zarr store in this case. A `name` is required to identify the outputs. It also needs `variables` defined, which are based on CF Conventions. A list of variable namings exists in `defaults.py` as `DEFAULT_VARIABLE_NAMES`. Each forecast will likely have different names for their variables, so a `variable_mapping` dictionary is also essential to process the variables, as well as the coordinates and dimensions. EWB uses `lead_time`, `init_time`, and `valid_time` as time coordinates. The HRES data is mapped from `prediction_timedelta` to `lead_time`, as an example. `storage_options` define access patterns for the data if needed. These are passed to the opening function, e.g. `xarray.open_zarr`. +> **Detailed Explanation**: A forecast needs a `source`, which is a link to the zarr store in this case. A `name` is required to identify the outputs. It also needs `variables` defined, which are based on CF Conventions. A list of variable namings exists in `defaults.py` as `DEFAULT_VARIABLE_NAMES`. Each forecast will likely have different names for their variables, so a `variable_mapping` dictionary is also essential to process the variables, as well as the coordinates and dimensions. EWB uses `lead_time`, `init_time`, and `valid_time` as time coordinates. The HRES data is mapped from `prediction_timedelta` to `lead_time`, as an example. `storage_options` define access patterns for the data if needed. These are passed to the opening function, e.g. `xarray.open_zarr`. Next, a target dataset must be defined as well to evaluate against. For this evaluation, we'll use ERA5: @@ -71,7 +71,19 @@ era5_heatwave_target = inputs.ERA5( ) ``` -Similarly to forecasts, we need to define the `source`, which here is the ARCO ERA5 provided by Google. 
`variables` are again required to be set for the `inputs.ERA5` class; `variable_mapping` defaults to `inputs.ERA5_metadata_variable_mapping` for many existing variables and likely is not required to be set unless your use case is for less common variables. Both forecasts and targets, if relevant, have an optional `chunks` parameter which defaults to what should be the most efficient value - usually `None` or `'auto'`, but can be changed as seen above. +Note that EWB provides defaults for arguments, so most users will be able to instead write this (if defining variables with the intent of it applying to all metrics): + +```python +era5_heatwave_target = inputs.ERA5(variables=['surface_air_temperature']) +``` + +Or (if defining variables as arguments to the metrics): + +```python +era5_heatwave_target = inputs.ERA5() +``` + +> **Detailed Explanation**: Similarly to forecasts, we need to define the `source`, which here is the ARCO ERA5 provided by Google. `variables` are used to subset `inputs.ERA5` in an evaluation; `variable_mapping` defaults to `inputs.ERA5_metadata_variable_mapping` for many existing variables and likely is not required to be set unless your use case is for less common variables. Both forecasts and targets, if relevant, have an optional `chunks` parameter which defaults to what should be the most efficient value - usually `None` or `'auto'`, but can be changed as seen above. *If using the ARCO ERA5 and setting `chunks=None`, it is critical to order your subsetting by variables -> time -> `.sel` or `.isel` latitude & longitude -> rechunk. [See this Github comment](https://github.com/pydata/xarray/issues/8902#issuecomment-2036435045). 
We then set up an `EvaluationObject` list: @@ -98,11 +110,11 @@ Plugging these all in: ```python from extremeweatherbench import cases, evaluate -case_yaml = cases.load_ewb_events_yaml_into_case_list() +case_list = cases.load_ewb_events_yaml_into_case_list() ewb_instance = evaluate.ExtremeWeatherBench( - cases=case_yaml, + cases=case_list, evaluation_objects=heatwave_evaluation_list, ) @@ -111,6 +123,8 @@ outputs = ewb_instance.run() outputs.to_csv('your_file_name.csv') ``` -Where the EWB default events YAML file is loaded in using a built-in utility helper function, then applied to an instance of `evaluate.ExtremeWeatherBench` along with the `EvaluationObject` list. Finally, we run the evaluation with the `.run()` method, where defaults are typically sufficient to run with a small to moderate-sized virtual machine. after subsetting and prior to metric calculation. +Where the EWB default events YAML file is loaded in using a built-in utility helper function, then applied to an instance of `evaluate.ExtremeWeatherBench` along with the `EvaluationObject` list. Finally, we trigger the evaluation with the `.run()` method, where defaults are typically sufficient to run with a small to moderate-sized virtual machine. + +Running locally is feasible but is typically bottlenecked heavily by IO and network bandwidth. Even on a gigabit connection, the rate of data access is significantly slower compared to within a cloud provider VM. The outputs are returned as a pandas DataFrame and can be manipulated in the script, a notebook, or post-hoc after saving it. 
diff --git a/src/extremeweatherbench/defaults.py b/src/extremeweatherbench/defaults.py index a78cdcca..7dcc68d6 100644 --- a/src/extremeweatherbench/defaults.py +++ b/src/extremeweatherbench/defaults.py @@ -58,28 +58,36 @@ ] -def _preprocess_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_cira_forecast_dataset( + ds: xr.Dataset, kerchunk: bool = True +) -> xr.Dataset: """A preprocess function for CIRA data that renames the time coordinate to lead_time, creates a valid_time coordinate, and sets the lead time range and resolution not present in the original dataset. Args: ds: The forecast dataset to preprocess. - + kerchunk: Whether the dataset is a kerchunk reference. Defaults to True. Returns: The preprocessed forecast dataset. """ - ds = ds.rename({"time": "lead_time"}) - # The evaluation configuration is used to set the lead time range and resolution. - ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") + + # If the dataset is a kerchunk, we need to rename the time coordinate to lead_time + # and set the lead time range and resolution. Otherwise, pass through the dataset. + if kerchunk: + ds = ds.rename({"time": "lead_time"}) + # The evaluation configuration is used to set the lead time range and resolution. + ds["lead_time"] = np.array( + [i for i in range(0, 241, 6)], dtype="timedelta64[h]" + ).astype("timedelta64[ns]") return ds # Preprocessing function for CIRA data that includes geopotential thickness calculation # required for tropical cyclone tracks -def _preprocess_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_cira_tc_forecast_dataset( + ds: xr.Dataset, kerchunk: bool = True +) -> xr.Dataset: """A preprocess function for CIRA data that includes geopotential thickness calculation required for tropical cyclone tracks. 
@@ -89,16 +97,18 @@ def _preprocess_cira_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: Args: ds: The forecast dataset to rename. - + kerchunk: Whether the dataset is a kerchunk reference. Defaults to True. Returns: The renamed forecast dataset. """ - ds = ds.rename({"time": "lead_time"}) - - # The evaluation configuration is used to set the lead time range and resolution. - ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") + # If the dataset is a kerchunk, we need to rename the time coordinate to lead_time + # and set the lead time range and resolution. Otherwise, pass through the dataset. + if kerchunk: + ds = ds.rename({"time": "lead_time"}) + # The evaluation configuration is used to set the lead time range and resolution. + ds["lead_time"] = np.array( + [i for i in range(0, 241, 6)], dtype="timedelta64[h]" + ).astype("timedelta64[ns]") # Calculate the geopotential thickness required for tropical cyclone tracks ds["geopotential_thickness"] = ( @@ -133,23 +143,27 @@ def _preprocess_hres_tc_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: # Preprocess function for CIRA data using Brightband kerchunk parquets -def _preprocess_ar_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_cira_ar_forecast_dataset( + ds: xr.Dataset, kerchunk: bool = True +) -> xr.Dataset: """An example preprocess function that renames the time coordinate to lead_time, creates a valid_time coordinate, and sets the lead time range and resolution not present in the original dataset. Args: ds: The forecast dataset to rename. - + kerchunk: Whether the dataset is a kerchunk reference. Defaults to True. Returns: The renamed forecast dataset. """ - ds = ds.rename({"time": "lead_time"}) - - # The evaluation configuration is used to set the lead time range and resolution. 
- ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") + # If the dataset is a kerchunk, we need to rename the time coordinate to lead_time + # and set the lead time range and resolution. Otherwise, pass through the dataset. + if kerchunk: + ds = ds.rename({"time": "lead_time"}) + # The evaluation configuration is used to set the lead time range and resolution. + ds["lead_time"] = np.array( + [i for i in range(0, 241, 6)], dtype="timedelta64[h]" + ).astype("timedelta64[ns]") if "q" not in ds.variables: # Calculate specific humidity from relative humidity and air temperature ds["specific_humidity"] = calc.specific_humidity_from_relative_humidity( @@ -161,23 +175,27 @@ def _preprocess_ar_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: # Preprocess function for CIRA data using Brightband kerchunk parquets -def _preprocess_severe_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset: +def _preprocess_severe_cira_forecast_dataset( + ds: xr.Dataset, kerchunk: bool = True +) -> xr.Dataset: """An example preprocess function that renames the time coordinate to lead_time, creates a valid_time coordinate, and sets the lead time range and resolution not present in the original dataset. Args: ds: The forecast dataset to rename. - + kerchunk: Whether the dataset is a kerchunk reference. Defaults to True. Returns: The renamed forecast dataset. """ - ds = ds.rename({"time": "lead_time"}) - - # The evaluation configuration is used to set the lead time range and resolution. - ds["lead_time"] = np.array( - [i for i in range(0, 241, 6)], dtype="timedelta64[h]" - ).astype("timedelta64[ns]") + # If the dataset is a kerchunk, we need to rename the time coordinate to lead_time + # and set the lead time range and resolution. Otherwise, pass through the dataset. + if kerchunk: + ds = ds.rename({"time": "lead_time"}) + # The evaluation configuration is used to set the lead time range and resolution. 
+ ds["lead_time"] = np.array( + [i for i in range(0, 241, 6)], dtype="timedelta64[h]" + ).astype("timedelta64[ns]") if "q" not in ds.variables: # Calculate specific humidity from relative humidity and air temperature ds["specific_humidity"] = calc.specific_humidity_from_relative_humidity( @@ -243,51 +261,39 @@ def get_climatology(quantile: float = 0.85) -> xr.DataArray: ibtracs_target = inputs.IBTrACS() # Forecasts -cira_heatwave_forecast = inputs.KerchunkForecast( - name="FourCastNetv2", - source="gs://extremeweatherbench/FOUR_v200_GFS.parq", +cira_fcnv2_heatwave_forecast = inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", variables=["surface_air_temperature"], - variable_mapping=inputs.CIRA_metadata_variable_mapping, - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_cira_forecast_dataset, + name="FourCastNetv2", ) -cira_freeze_forecast = inputs.KerchunkForecast( - name="FourCastNetv2", - source="gs://extremeweatherbench/FOUR_v200_GFS.parq", +cira_fcnv2_freeze_forecast = inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", variables=["surface_air_temperature"], - variable_mapping=inputs.CIRA_metadata_variable_mapping, - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_cira_forecast_dataset, + name="FourCastNetv2", ) -cira_tropical_cyclone_forecast = inputs.KerchunkForecast( - name="FourCastNetv2", - source="gs://extremeweatherbench/FOUR_v200_GFS.parq", +cira_fcnv2_tropical_cyclone_forecast = inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", variables=[derived.TropicalCycloneTrackVariables()], - variable_mapping=inputs.CIRA_metadata_variable_mapping, - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, + name="FourCastNetv2", preprocess=_preprocess_cira_tc_forecast_dataset, ) -cira_atmospheric_river_forecast = inputs.KerchunkForecast( - name="FourCastNetv2", - source="gs://extremeweatherbench/FOUR_v200_GFS.parq", 
+cira_fcnv2_atmospheric_river_forecast = inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", variables=[ derived.AtmosphericRiverVariables( output_variables=["atmospheric_river_land_intersection"] ) ], - variable_mapping=inputs.CIRA_metadata_variable_mapping, - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, - preprocess=_preprocess_ar_cira_forecast_dataset, + name="FourCastNetv2", + preprocess=_preprocess_cira_ar_forecast_dataset, ) -cira_severe_convection_forecast = inputs.KerchunkForecast( - name="FourCastNetv2", - source="gs://extremeweatherbench/FOUR_v200_GFS.parq", +cira_fcnv2_severe_convection_forecast = inputs.get_cira_icechunk( + model_name="FOUR_v200_GFS", variables=[derived.CravenBrooksSignificantSevere()], - variable_mapping=inputs.CIRA_metadata_variable_mapping, - storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, + name="FourCastNetv2", preprocess=_preprocess_severe_cira_forecast_dataset, ) @@ -363,37 +369,37 @@ def get_brightband_evaluation_objects() -> list[inputs.EvaluationObject]: event_type="heat_wave", metric_list=heatwave_metric_list, target=era5_heatwave_target, - forecast=cira_heatwave_forecast, + forecast=cira_fcnv2_heatwave_forecast, ), inputs.EvaluationObject( event_type="heat_wave", metric_list=heatwave_metric_list, target=ghcn_heatwave_target, - forecast=cira_heatwave_forecast, + forecast=cira_fcnv2_heatwave_forecast, ), inputs.EvaluationObject( event_type="freeze", metric_list=freeze_metric_list, target=era5_freeze_target, - forecast=cira_freeze_forecast, + forecast=cira_fcnv2_freeze_forecast, ), inputs.EvaluationObject( event_type="freeze", metric_list=freeze_metric_list, target=ghcn_freeze_target, - forecast=cira_freeze_forecast, + forecast=cira_fcnv2_freeze_forecast, ), inputs.EvaluationObject( event_type="severe_convection", metric_list=pph_metric_list, target=pph_target, - forecast=cira_severe_convection_forecast, + forecast=cira_fcnv2_severe_convection_forecast, ), 
inputs.EvaluationObject( event_type="severe_convection", metric_list=lsr_metric_list, target=lsr_target, - forecast=cira_severe_convection_forecast, + forecast=cira_fcnv2_severe_convection_forecast, ), inputs.EvaluationObject( event_type="atmospheric_river", @@ -403,12 +409,12 @@ def get_brightband_evaluation_objects() -> list[inputs.EvaluationObject]: metrics.EarlySignal(), ], target=era5_atmospheric_river_target, - forecast=cira_atmospheric_river_forecast, + forecast=cira_fcnv2_atmospheric_river_forecast, ), inputs.EvaluationObject( event_type="tropical_cyclone", metric_list=composite_landfall_metrics, target=ibtracs_target, - forecast=cira_tropical_cyclone_forecast, + forecast=cira_fcnv2_tropical_cyclone_forecast, ), ] diff --git a/src/extremeweatherbench/inputs.py b/src/extremeweatherbench/inputs.py index cd0f52cd..8708dde9 100644 --- a/src/extremeweatherbench/inputs.py +++ b/src/extremeweatherbench/inputs.py @@ -145,6 +145,17 @@ {"s3://noaa-oar-mlwp-data/": icechunk.s3_credentials(anonymous=True)} ) +CIRA_MODEL_NAMES = [ + "AURO_v100_GFS", + "FOUR_v200_IFS", + "PANG_v100_IFS", + "FOUR_v200_GFS", + "GRAP_v100_GFS", + "AURO_v100_IFS", + "PANG_v100_GFS", + "GRAP_v100_IFS", +] + def _default_preprocess(input_data: IncomingDataInput) -> IncomingDataInput: """Default forecast preprocess function that does nothing.""" @@ -1268,3 +1279,51 @@ def check_for_missing_data( return False else: return True + + +def get_cira_icechunk( + model_name: str, + variables: list[Union[str, derived.DerivedVariable]] = [], + preprocess: Callable = _default_preprocess, + name: Optional[str] = None, +) -> XarrayForecast: + """Get a CIRA icechunk forecast object for a given model name. + + Args: + model_name: The name of the model from CIRA to get the forecast object for. For + example, "FOUR_v200_GFS". For a list of available models, see + `extremeweatherbench.inputs.CIRA_MODEL_NAMES`. + variables: The variables to select from the model. Defaults to all variables. 
+ preprocess: The preprocessing function to apply to the model. Defaults to the + default passthrough preprocess function. + name: The name of the forecast object. Defaults to model_name unless + `name` is explicitly provided. + Returns: + An XarrayForecast object for the given model. + """ + # Check if the model name is valid + if model_name not in CIRA_MODEL_NAMES: + raise ValueError( + f"Model name {model_name} not found in CIRA_MODEL_NAMES. Model names must be one of: {CIRA_MODEL_NAMES}" + ) + + # Get the CIRA icechunk storage + cira_storage = icechunk.gcs_storage( + bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True + ) + + # The models are distinct groups within the icechunk store; open the group + # corresponding to the model name + cira_model_ds = open_icechunk_dataset_from_datatree( + cira_storage, model_name, authorize_virtual_chunk_access=CIRA_CREDENTIALS + ) + + # Create the XarrayForecast object for the given model + cira_model_forecast = XarrayForecast( + ds=cira_model_ds, + variables=variables, + variable_mapping=CIRA_metadata_variable_mapping, + name=name if name else model_name, + preprocess=preprocess, + ) + return cira_model_forecast diff --git a/tests/test_defaults.py b/tests/test_defaults.py index cbbea8bd..ff452c91 100644 --- a/tests/test_defaults.py +++ b/tests/test_defaults.py @@ -120,10 +120,22 @@ def test_target_objects_exist(self): def test_forecast_objects_exist(self): """Test that forecast objects are properly defined.""" - assert hasattr(defaults, "cira_heatwave_forecast") - assert hasattr(defaults, "cira_freeze_forecast") - assert isinstance(defaults.cira_heatwave_forecast, inputs.KerchunkForecast) - assert isinstance(defaults.cira_freeze_forecast, inputs.KerchunkForecast) + assert hasattr(defaults, "cira_fcnv2_heatwave_forecast") + assert hasattr(defaults, "cira_fcnv2_freeze_forecast") + assert hasattr(defaults, "cira_fcnv2_tropical_cyclone_forecast") + assert hasattr(defaults, 
"cira_fcnv2_atmospheric_river_forecast") + assert hasattr(defaults, "cira_fcnv2_severe_convection_forecast") + assert isinstance(defaults.cira_fcnv2_heatwave_forecast, inputs.XarrayForecast) + assert isinstance(defaults.cira_fcnv2_freeze_forecast, inputs.XarrayForecast) + assert isinstance( + defaults.cira_fcnv2_tropical_cyclone_forecast, inputs.XarrayForecast + ) + assert isinstance( + defaults.cira_fcnv2_atmospheric_river_forecast, inputs.XarrayForecast + ) + assert isinstance( + defaults.cira_fcnv2_severe_convection_forecast, inputs.XarrayForecast + ) def test_era5_heatwave_target_configuration(self): """Test ERA5 heatwave target configuration.""" @@ -149,21 +161,6 @@ def test_era5_freeze_target_configuration(self): for key, value in expected_mapping.items(): assert target.variable_mapping[key] == value - def test_cira_forecasts_have_preprocess_function(self): - """Test that CIRA forecasts have the preprocess function set.""" - assert defaults.cira_heatwave_forecast.preprocess is not None - assert defaults.cira_freeze_forecast.preprocess is not None - - # Test that the preprocess function is the expected one - assert ( - defaults.cira_heatwave_forecast.preprocess - == defaults._preprocess_cira_forecast_dataset - ) - assert ( - defaults.cira_freeze_forecast.preprocess - == defaults._preprocess_cira_forecast_dataset - ) - def test_get_brightband_evaluation_objects_no_exceptions(self): """Test that get_brightband_evaluation_objects runs without exceptions.""" try: @@ -173,3 +170,63 @@ def test_get_brightband_evaluation_objects_no_exceptions(self): assert len(result) > 0 except Exception as e: pytest.fail(f"get_brightband_evaluation_objects raised an exception: {e}") + + +class TestCiraFcnv2PreprocessFunctions: + """Tests that each cira_fcnv2 forecast has the correct preprocessing function.""" + + def test_heatwave_forecast_has_default_preprocess(self): + """Test that cira_fcnv2_heatwave_forecast uses default preprocess.""" + forecast = 
defaults.cira_fcnv2_heatwave_forecast + assert forecast.preprocess == inputs._default_preprocess + + def test_freeze_forecast_has_default_preprocess(self): + """Test that cira_fcnv2_freeze_forecast uses default preprocess.""" + forecast = defaults.cira_fcnv2_freeze_forecast + assert forecast.preprocess == inputs._default_preprocess + + def test_tropical_cyclone_forecast_has_tc_preprocess(self): + """Test that cira_fcnv2_tropical_cyclone_forecast uses TC preprocess.""" + forecast = defaults.cira_fcnv2_tropical_cyclone_forecast + assert forecast.preprocess == defaults._preprocess_cira_tc_forecast_dataset + + def test_atmospheric_river_forecast_has_ar_preprocess(self): + """Test that cira_fcnv2_atmospheric_river_forecast uses AR preprocess.""" + forecast = defaults.cira_fcnv2_atmospheric_river_forecast + assert forecast.preprocess == defaults._preprocess_cira_ar_forecast_dataset + + def test_severe_convection_forecast_has_severe_preprocess(self): + """Test that cira_fcnv2_severe_convection_forecast uses severe preprocess.""" + forecast = defaults.cira_fcnv2_severe_convection_forecast + assert forecast.preprocess == defaults._preprocess_severe_cira_forecast_dataset + + def test_all_forecasts_have_preprocess_attribute(self): + """Test that all cira_fcnv2 forecasts have a preprocess attribute set.""" + forecasts = [ + defaults.cira_fcnv2_heatwave_forecast, + defaults.cira_fcnv2_freeze_forecast, + defaults.cira_fcnv2_tropical_cyclone_forecast, + defaults.cira_fcnv2_atmospheric_river_forecast, + defaults.cira_fcnv2_severe_convection_forecast, + ] + for forecast in forecasts: + assert hasattr(forecast, "preprocess") + assert forecast.preprocess is not None + assert callable(forecast.preprocess) + + def test_preprocess_functions_are_distinct_where_expected(self): + """Test that different event types use different preprocess functions.""" + # TC, AR, and severe should have distinct preprocess functions + tc_preprocess = defaults.cira_fcnv2_tropical_cyclone_forecast.preprocess 
class TestGetCIRAIcechunk:
    """Tests for get_cira_icechunk function."""

    @staticmethod
    def _prime_mocks(*mocks):
        """Give each patched callable a fresh MagicMock return value."""
        for patched in mocks:
            patched.return_value = mock.MagicMock()

    def test_invalid_model_name_raises_value_error(self):
        """Test that an invalid model name raises ValueError."""
        with pytest.raises(ValueError) as exc_info:
            inputs.get_cira_icechunk(model_name="INVALID_MODEL")

        assert "INVALID_MODEL" in str(exc_info.value)
        assert "CIRA_MODEL_NAMES" in str(exc_info.value)

    def test_empty_model_name_raises_value_error(self):
        """Test that an empty model name raises ValueError."""
        with pytest.raises(ValueError):
            inputs.get_cira_icechunk(model_name="")

    def test_none_model_name_raises_error(self):
        """Test that None as model name raises appropriate error."""
        with pytest.raises((ValueError, TypeError)):
            inputs.get_cira_icechunk(model_name=None)  # type: ignore

    def test_case_sensitive_model_name(self):
        """Test that model name matching is case-sensitive."""
        # Lowercase version of a valid model name should fail
        with pytest.raises(ValueError):
            inputs.get_cira_icechunk(model_name="four_v200_gfs")

        # Mixed case should fail
        with pytest.raises(ValueError):
            inputs.get_cira_icechunk(model_name="Four_V200_GFS")

    def test_partial_model_name_raises_value_error(self):
        """Test that partial model names are rejected."""
        with pytest.raises(ValueError):
            inputs.get_cira_icechunk(model_name="FOUR")

        with pytest.raises(ValueError):
            inputs.get_cira_icechunk(model_name="GFS")

    def test_model_name_with_extra_chars_raises_value_error(self):
        """Test that model names with extra characters are rejected."""
        with pytest.raises(ValueError):
            inputs.get_cira_icechunk(model_name="FOUR_v200_GFS_extra")

        with pytest.raises(ValueError):
            inputs.get_cira_icechunk(model_name=" FOUR_v200_GFS")

    def test_error_message_lists_valid_model_names(self):
        """Test that the error message includes the list of valid model names."""
        with pytest.raises(ValueError) as exc_info:
            inputs.get_cira_icechunk(model_name="BAD_MODEL")

        error_msg = str(exc_info.value)
        # Check that at least some valid model names are shown in the error
        assert "FOUR_v200_GFS" in error_msg or "CIRA_MODEL_NAMES" in error_msg

    @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage")
    @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree")
    @mock.patch("extremeweatherbench.inputs.XarrayForecast")
    def test_valid_model_name_four_v200_gfs(
        self, mock_forecast, mock_open, mock_storage
    ):
        """Test that FOUR_v200_GFS is a valid model name."""
        self._prime_mocks(mock_storage, mock_open, mock_forecast)

        result = inputs.get_cira_icechunk(model_name="FOUR_v200_GFS")

        assert result is not None
        mock_storage.assert_called_once()
        mock_open.assert_called_once()

    @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage")
    @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree")
    @mock.patch("extremeweatherbench.inputs.XarrayForecast")
    def test_valid_model_name_auro_v100_gfs(
        self, mock_forecast, mock_open, mock_storage
    ):
        """Test that AURO_v100_GFS is a valid model name."""
        self._prime_mocks(mock_storage, mock_open, mock_forecast)

        result = inputs.get_cira_icechunk(model_name="AURO_v100_GFS")

        assert result is not None

    @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage")
    @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree")
    @mock.patch("extremeweatherbench.inputs.XarrayForecast")
    def test_all_cira_model_names_are_valid(
        self, mock_forecast, mock_open, mock_storage
    ):
        """Test that all model names in CIRA_MODEL_NAMES are accepted."""
        self._prime_mocks(mock_storage, mock_open, mock_forecast)

        for model_name in inputs.CIRA_MODEL_NAMES:
            result = inputs.get_cira_icechunk(model_name=model_name)
            assert result is not None, f"Model {model_name} should be valid"

    @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage")
    @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree")
    @mock.patch("extremeweatherbench.inputs.XarrayForecast")
    def test_custom_name_parameter(self, mock_forecast, mock_open, mock_storage):
        """Test that a custom name parameter is passed to XarrayForecast."""
        self._prime_mocks(mock_storage, mock_open, mock_forecast)

        inputs.get_cira_icechunk(model_name="FOUR_v200_GFS", name="CustomName")

        # Check that XarrayForecast was called with the custom name
        call_kwargs = mock_forecast.call_args[1]
        assert call_kwargs["name"] == "CustomName"

    @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage")
    @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree")
    @mock.patch("extremeweatherbench.inputs.XarrayForecast")
    def test_default_name_uses_model_name(self, mock_forecast, mock_open, mock_storage):
        """Test that name defaults to model_name when not provided."""
        self._prime_mocks(mock_storage, mock_open, mock_forecast)

        inputs.get_cira_icechunk(model_name="FOUR_v200_GFS")

        call_kwargs = mock_forecast.call_args[1]
        assert call_kwargs["name"] == "FOUR_v200_GFS"

    @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage")
    @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree")
    @mock.patch("extremeweatherbench.inputs.XarrayForecast")
    def test_empty_variables_list(self, mock_forecast, mock_open, mock_storage):
        """Test that empty variables list is valid."""
        self._prime_mocks(mock_storage, mock_open, mock_forecast)

        result = inputs.get_cira_icechunk(model_name="FOUR_v200_GFS", variables=[])

        assert result is not None
        call_kwargs = mock_forecast.call_args[1]
        assert call_kwargs["variables"] == []

    @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage")
    @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree")
    @mock.patch("extremeweatherbench.inputs.XarrayForecast")
    def test_custom_variables_list(self, mock_forecast, mock_open, mock_storage):
        """Test that a custom variables list is passed through."""
        self._prime_mocks(mock_storage, mock_open, mock_forecast)

        variables = ["surface_air_temperature", "air_pressure"]
        inputs.get_cira_icechunk(model_name="FOUR_v200_GFS", variables=variables)

        call_kwargs = mock_forecast.call_args[1]
        assert call_kwargs["variables"] == variables

    @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage")
    @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree")
    @mock.patch("extremeweatherbench.inputs.XarrayForecast")
    def test_custom_preprocess_function(self, mock_forecast, mock_open, mock_storage):
        """Test that a custom preprocess function is passed through."""
        self._prime_mocks(mock_storage, mock_open, mock_forecast)

        def custom_preprocess(ds: xr.Dataset) -> xr.Dataset:
            return ds

        inputs.get_cira_icechunk(
            model_name="FOUR_v200_GFS", preprocess=custom_preprocess
        )

        call_kwargs = mock_forecast.call_args[1]
        assert call_kwargs["preprocess"] == custom_preprocess

    @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage")
    @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree")
    @mock.patch("extremeweatherbench.inputs.XarrayForecast")
    def test_returns_xarray_forecast_object(
        self, mock_forecast, mock_open, mock_storage
    ):
        """Test that the function returns an XarrayForecast object."""
        # Only the storage/open mocks get generic return values; the
        # forecast constructor must return a known sentinel to assert on.
        self._prime_mocks(mock_storage, mock_open)
        expected_forecast = mock.MagicMock()
        mock_forecast.return_value = expected_forecast

        result = inputs.get_cira_icechunk(model_name="FOUR_v200_GFS")

        assert result is expected_forecast

    @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage")
    @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree")
    @mock.patch("extremeweatherbench.inputs.XarrayForecast")
    def test_gcs_storage_configuration(self, mock_forecast, mock_open, mock_storage):
        """Test that GCS storage is configured with correct parameters."""
        self._prime_mocks(mock_storage, mock_open, mock_forecast)

        inputs.get_cira_icechunk(model_name="FOUR_v200_GFS")

        mock_storage.assert_called_once_with(
            bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True
        )

    @mock.patch("extremeweatherbench.inputs.icechunk.gcs_storage")
    @mock.patch("extremeweatherbench.inputs.open_icechunk_dataset_from_datatree")
    @mock.patch("extremeweatherbench.inputs.XarrayForecast")
    def test_uses_cira_variable_mapping(self, mock_forecast, mock_open, mock_storage):
        """Test that CIRA metadata variable mapping is used."""
        self._prime_mocks(mock_storage, mock_open, mock_forecast)

        inputs.get_cira_icechunk(model_name="FOUR_v200_GFS")

        call_kwargs = mock_forecast.call_args[1]
        assert call_kwargs["variable_mapping"] == inputs.CIRA_metadata_variable_mapping


class TestCiraModelNames:
    """Tests for CIRA_MODEL_NAMES constant."""

    def test_cira_model_names_is_list(self):
        """Test that CIRA_MODEL_NAMES is a list."""
        assert isinstance(inputs.CIRA_MODEL_NAMES, list)

    def test_cira_model_names_not_empty(self):
        """Test that CIRA_MODEL_NAMES is not empty."""
        assert len(inputs.CIRA_MODEL_NAMES) > 0

    def test_cira_model_names_contains_expected_models(self):
        """Test that CIRA_MODEL_NAMES contains expected model names."""
        expected_models = [
            "FOUR_v200_GFS",
            "FOUR_v200_IFS",
            "AURO_v100_GFS",
            "AURO_v100_IFS",
            "PANG_v100_GFS",
            "PANG_v100_IFS",
            "GRAP_v100_GFS",
            "GRAP_v100_IFS",
        ]
        for model in expected_models:
            assert model in inputs.CIRA_MODEL_NAMES

    def test_cira_model_names_all_strings(self):
        """Test that all entries in CIRA_MODEL_NAMES are strings."""
        for model in inputs.CIRA_MODEL_NAMES:
            assert isinstance(model, str)

    def test_cira_model_names_no_duplicates(self):
        """Test that CIRA_MODEL_NAMES has no duplicate entries."""
        assert len(inputs.CIRA_MODEL_NAMES) == len(set(inputs.CIRA_MODEL_NAMES))