Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 10 additions & 18 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,18 @@ Please include a summary of the change and which issue is fixed. Please also inc

## Type of change

Please delete options that are not relevant.
- [ ] Refactor
- [ ] Bug fix
- [ ] New feature
- [ ] Chore

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] This change requires a documentation update
### If this PR is a breaking change, please explain:

## How Has This Been Tested?

Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration.
What functionality is now broken?

What is the new approach?

Did you update the documentation and any upstream/downstream code accordingly?
## How Has This Been Tested?

## Checklist:

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged and published in downstream modules
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration.
51 changes: 7 additions & 44 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,48 +67,11 @@ $ ewb --default
```python
from extremeweatherbench import cases, inputs, metrics, evaluate, utils

# Select model
model = 'FOUR_v200_GFS'

# Set up path to directory of file - zarr or kerchunk/virtualizarr json/parquet
forecast_dir = f'gs://extremeweatherbench/{model}.parq'

# Preprocessing function exclusive to handling the CIRA parquets
def preprocess_bb_cira_forecast_dataset(ds: xr.Dataset) -> xr.Dataset:
"""Preprocess CIRA kerchunk (parquet) data in the ExtremeWeatherBench bucket.
A preprocess function that renames the time coordinate to lead_time,
creates a valid_time coordinate, and sets the lead time range and resolution not
present in the original dataset.
Args:
ds: The forecast dataset to rename.
Returns:
The renamed forecast dataset.
"""
ds = ds.rename({"time": "lead_time"})

# The evaluation configuration is used to set the lead time range and resolution.
ds["lead_time"] = np.array(
[i for i in range(0, 241, 6)], dtype="timedelta64[h]"
).astype("timedelta64[ns]")

return ds

# Define a forecast object; in this case, a KerchunkForecast
fcnv2_forecast = inputs.KerchunkForecast(
name="fcnv2_forecast", # identifier for this forecast in results
source=forecast_dir, # source path
variables=["surface_air_temperature"], # variables to use in the evaluation
variable_mapping=inputs.CIRA_metadata_variable_mapping, # mapping to use for variables in forecast dataset to EWB variable names
storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, # storage options for access
preprocess=preprocess_bb_cira_forecast_dataset # required preprocessing function for CIRA references
)
# Load in a forecast; here, we load in GFS initialized FCNv2 from the CIRA MLWP archive with a default variable built-in for convenience
fcnv2_heatwave_forecast = defaults.cira_fcnv2_heatwave_forecast

# Load in ERA5; source defaults to the ARCO ERA5 dataset from Google and variable mapping is provided by default as well
era5_heatwave_target = inputs.ERA5(
variables=["surface_air_temperature"], # variable to use in the evaluation
storage_options={"remote_options": {"anon": True}}, # storage options for access
chunks=None, # define chunks for the ERA5 data
)
# Load in ERA5 with another default convenience variable
era5_heatwave_target = defaults.era5_heatwave_target

# EvaluationObjects are used to evaluate a single forecast source against a single target source with a defined event type. Event types are declared with each case. One or more metrics can be evaluated with each EvaluationObject.
heatwave_evaluation_list = [
Expand All @@ -120,11 +83,11 @@ heatwave_evaluation_list = [
metrics.MaximumLowestMeanAbsoluteError(),
],
target=era5_heatwave_target,
forecast=fcnv2_forecast,
forecast=fcnv2_heatwave_forecast,
),
]
# Load in the EWB default list of event cases
case_metadata = cases.load_ewb_events_yaml_into_case_collection()
case_metadata = cases.load_ewb_events_yaml_into_case_list()

# Create the evaluation class, with cases and evaluation objects declared
ewb_instance = evaluate.ExtremeWeatherBench(
Expand All @@ -134,7 +97,7 @@ ewb_instance = evaluate.ExtremeWeatherBench(

# Execute a parallel run and return the evaluation results as a pandas DataFrame
heatwave_outputs = ewb_instance.run(
parallel_config={'backend':'loky','n_jobs':16} # Uses 16 jobs with the loky backend
parallel_config={'n_jobs':16} # Uses 16 jobs with the loky backend as default
)

# Save the results
Expand Down
8 changes: 4 additions & 4 deletions data_prep/ar_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,8 +810,8 @@ def process_ar_event(
"\nProcessing: %s (Case %s)", single_case.title, single_case.case_id_number
)
# Create a case object for this event
case_collection = cases.load_individual_cases({"cases": [single_case]})
case = case_collection.cases[0]
case_list = cases.load_individual_cases([single_case])
case = case_list[0]
case.start_date = case.start_date - pd.Timedelta(days=3)
case.end_date = case.end_date + pd.Timedelta(days=3)

Expand Down Expand Up @@ -1034,8 +1034,8 @@ def main():
parallel = True

# Load atmospheric river events from the events.yaml file
events_yaml = cases.load_ewb_events_yaml_into_case_collection()
ar_events = events_yaml.select_cases(by="event_type", value="atmospheric_river")
events_yaml = cases.load_ewb_events_yaml_into_case_list()
ar_events = [n for n in events_yaml if n.event_type == "atmospheric_river"]
logger.info("Found %s atmospheric river events in events.yaml", len(ar_events))

# Process each atmospheric river event with enhanced object-based bounds calculation
Expand Down
236 changes: 236 additions & 0 deletions data_prep/cira_icechunk_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
"""
This script is used to generate an icechunk store for the CIRA MLWP data.

Credit to CIRA for producing the AIWP model data, Tom Nicholas for VirtualiZarr,
the Earthmover team for icechunk.

To access the icechunk store, you can use the following code:

import icechunk
import xarray as xr

test_storage = icechunk.gcs_storage(
bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True
)
test_repo = icechunk.Repository.open(test_storage)
session = test_repo.readonly_session(branch="main")
dt = xr.open_datatree(session.store, engine="zarr")

Which will return a DataTree object with the CIRA MLWP data for the models (except
FCNv1, which is not compatible with this approach).
"""

import logging
import warnings

import icechunk
import joblib
import numpy as np
import obstore as obs
import pandas as pd
import virtualizarr
import virtualizarr.parsers
import virtualizarr.registry
import xarray as xr

from extremeweatherbench import utils

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Try to suppress warnings about numcodecs codecs not being in the Zarr version 3
# specification. Warnings will still output when running it in parallel in joblib as
# there doesn't seem to be a way to suppress them in each newly spawned process.
warnings.filterwarnings(
"ignore",
message="Numcodecs codecs are not in the Zarr version 3 specification*",
category=UserWarning,
)


def process_single_virtual_dataset(
    path: str,
    parser: virtualizarr.parsers.HDFParser,
    registry: virtualizarr.registry.ObjectStoreRegistry,
    loadable_variables: list[str] | None = None,
    decode_times: bool = True,
) -> xr.Dataset:
    """Open a single HDF/netCDF file as a virtual (reference-only) dataset.

    Args:
        path: The path to the virtual dataset (e.g. "s3://this/path/to/a/netcdf.nc").
        parser: The parser to use to parse the virtual dataset.
        registry: The virtualizarr.registry.ObjectStoreRegistry to use to access
            the virtual dataset.
        loadable_variables: Variables to load eagerly rather than reference
            virtually. Defaults to the coordinate variables
            ["time", "latitude", "longitude", "level"].
        decode_times: Whether to decode the time coordinate into datetimes.

    Returns:
        A Virtualizarr dataset.
    """
    # Resolve the default per call instead of using a mutable default argument,
    # which would be shared across all calls.
    if loadable_variables is None:
        loadable_variables = ["time", "latitude", "longitude", "level"]
    vds = virtualizarr.open_virtual_dataset(
        url=path,
        parser=parser,
        registry=registry,
        loadable_variables=loadable_variables,
        decode_times=decode_times,
    )
    return vds


def process_cira_model(
    model_key: str, model_data: list[xr.Dataset]
) -> tuple[str, xr.Dataset | None]:
    """Merge a list of singular virtual datasets into a single concatenated dataset.

    Args:
        model_key: The key of the model.
        model_data: A list of singular virtual datasets, one per initialization
            time, each carrying an "initialization_time" attribute and a "time"
            coordinate of valid times.

    Returns:
        A tuple of the model key and the concatenated dataset if successful,
        otherwise the model key and None.
    """

    # Some models e.g. FCNv1 are not compatible with this approach
    # Also, in the scenario of a problem like variable chunking (ZEP003),
    # concatenation will fail, so return None in that case.
    try:
        # Combine the virtual datasets into a single dataset. Args here are
        # established defaults for virtualizarr that also work for this case.
        # When creating a new dimension with concat_dim, join="override" and
        # combine_attrs="drop_conflicts" prevents fancy indexing errors.
        combined_vds = xr.combine_nested(
            model_data,
            concat_dim="init_time",
            coords="minimal",
            compat="override",
            combine_attrs="drop_conflicts",
            join="override",
        )

        # Rename the time coordinate to valid_time to be consistent with EWB
        # conventions
        combined_vds = combined_vds.rename({"time": "valid_time"})

        # Assign the init_time coordinate from each member dataset's metadata
        combined_vds = combined_vds.assign_coords(
            init_time=[
                pd.to_datetime(f.attrs["initialization_time"]) for f in model_data
            ],
        )

        # Hard code lead times for now (0 to 240 h every 6 h); the netcdf
        # attributes carry inconsistent values
        lead_times = np.linspace(0, 240, 41).astype("timedelta64[h]")

        # Assign the lead time coordinate to the concatenated dataset
        combined_vds = combined_vds.assign_coords(lead_time=("valid_time", lead_times))

        # Swap the valid_time and lead_time dimensions to be consistent with EWB
        # conventions
        combined_vds = combined_vds.swap_dims({"valid_time": "lead_time"})

        # Return the model key and the concatenated dataset
        return model_key, combined_vds

    # If there is an error, log it (with lazy %-style args so formatting is
    # skipped when the log level filters the record) and return None
    except Exception as e:
        logger.error("Error processing model %s: %s, returning None", model_key, e)
        return model_key, None


def generate_cira_icechunk_store():
    """Generate a CIRA icechunk store from the CIRA MLWP data.

    Lists the CIRA S3 bucket, virtualizes each netCDF file in parallel,
    concatenates per-model datasets, and commits the resulting DataTree to an
    icechunk repository in the EWB GCS bucket (requires write credentials).
    """

    # CIRA bucket URI
    bucket = "s3://noaa-oar-mlwp-data"

    # Build the ObjectStore from the URI, knowing the region and skipping signature
    store = obs.store.from_url(bucket, region="us-east-1", skip_signature=True)

    # Subset the prefixes to only include the model directories
    prefix_list = [
        n for n in obs.list_with_delimiter(store)["common_prefixes"] if n.endswith("FS")
    ]

    # Build the ObjectStoreRegistry and HDFParser
    registry = virtualizarr.registry.ObjectStoreRegistry({bucket: store})
    parser = virtualizarr.parsers.HDFParser()

    model_dict = {}
    for model in prefix_list:
        # List every object under the model prefix ONCE and flatten the
        # one-entry pages (chunk_size=1) into a single list. Listing once both
        # avoids a second round of network calls and guarantees the progress
        # total matches the entries actually processed.
        entries = [
            entry
            for page in obs.list(store, model + "/", chunk_size=1)
            for entry in page
        ]

        with joblib.parallel_config(**{"backend": "loky", "n_jobs": -1}):
            model_dict[model] = utils.ParallelTqdm(total=len(entries))(
                # None is the cache_dir, we can't cache in parallel mode
                joblib.delayed(process_single_virtual_dataset)(
                    "s3://noaa-oar-mlwp-data/" + entry["path"], parser, registry
                )
                for entry in entries
            )

    # Runs starting 27 May 2025 have a different chunking scheme, which cannot be
    # concatenated for now (ZEP003)
    single_chunk_model_dict = {
        n: [
            item
            for item in model_dict[n]
            if item["time"][0] < pd.to_datetime("2025-05-27T00:00:00.000000000")
        ]
        for n in model_dict.keys()
    }

    concat_model_dict = {}

    with joblib.parallel_config(**{"backend": "loky", "n_jobs": -1}):
        results = utils.ParallelTqdm(total=len(single_chunk_model_dict))(
            joblib.delayed(process_cira_model)(
                model_key, single_chunk_model_dict[model_key]
            )
            for model_key in single_chunk_model_dict.keys()
        )

    # Filter out any None results and create a dictionary of model keys and
    # concatenated datasets
    concat_model_dict = {
        model_key: result for model_key, result in results if result is not None
    }

    # Create a DataTree from the dictionary of model keys and concatenated datasets
    cira_datatree = xr.DataTree.from_dict(concat_model_dict)

    # Build the GCS storage for the icechunk repository. This will fail if you do not
    # have the application credentials for write access to the EWB bucket. For another
    # GCS store, run gcloud auth application-default login and find where the generated
    # json credentials are stored, and use that path in application_credentials.
    storage = icechunk.gcs_storage(
        bucket="extremeweatherbench", prefix="cira-icechunk", application_credentials=""
    )

    # Build the RepositoryConfig with default config settings
    config = icechunk.RepositoryConfig.default()

    # Set the virtual chunk container to the CIRA bucket in S3 so the icechunk
    # store can resolve virtual chunk references back to the source files
    config.set_virtual_chunk_container(
        icechunk.VirtualChunkContainer(
            url_prefix="s3://noaa-oar-mlwp-data/",
            store=icechunk.s3_store(region="us-east-1", anonymous=True),
        ),
    )

    # Create the repository
    repo = icechunk.Repository.create(storage, config)

    # Create a writable session to the repository
    session = repo.writable_session(branch="main")

    # Convert the DataTree to icechunk and commit to the repository
    cira_datatree.vz.to_icechunk(session.store)

    # Commit the changes to the repository; required for icechunk to be used from the
    # GCS store
    session.commit("drop in cira icechunk store")


if __name__ == "__main__":
generate_cira_icechunk_store()
Loading
Loading