Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 236 additions & 0 deletions data_prep/cira_icechunk_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
"""
This script is used to generate an icechunk store for the CIRA MLWP data.

Credit to CIRA for producing the AIWP model data, Tom Nicholas for VirtualiZarr,
the Earthmover team for icechunk.

To access the icechunk store, you can use the following code:

import icechunk
import xarray as xr

test_storage = icechunk.gcs_storage(
bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True
)
test_repo = icechunk.Repository.open(test_storage)
session = test_repo.readonly_session(branch="main")
dt = xr.open_datatree(session.store, engine="zarr")

Which will return a DataTree object with the CIRA MLWP data for the models (except
FCNv1, which is not compatible with this approach).
"""

import logging
import warnings

import icechunk
import joblib
import numpy as np
import obstore as obs
import pandas as pd
import virtualizarr
import virtualizarr.parsers
import virtualizarr.registry
import xarray as xr

from extremeweatherbench import utils

# Configure root logging once at import time so module-level loggers emit INFO.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Try to suppress warnings about numcodecs codecs not being in the Zarr version 3
# specification. Warnings will still output when running it in parallel in joblib as
# there doesn't seem to be a way to suppress them in each newly spawned process.
warnings.filterwarnings(
    "ignore",
    message="Numcodecs codecs are not in the Zarr version 3 specification*",
    category=UserWarning,
)


def process_single_virtual_dataset(
    path: str,
    parser: virtualizarr.parsers.HDFParser,
    registry: virtualizarr.registry.ObjectStoreRegistry,
    loadable_variables: list[str] | None = None,
    decode_times: bool = True,
) -> xr.Dataset:
    """Process a single HDF/netCDF virtual dataset from a path.

    Args:
        path: The path to the virtual dataset
            (e.g. "s3://this/path/to/a/netcdf.nc").
        parser: The parser to use to parse the virtual dataset.
        registry: The virtualizarr.registry.ObjectStoreRegistry to use to
            access the virtual dataset.
        loadable_variables: Variables to load eagerly instead of virtualizing.
            Defaults to ["time", "latitude", "longitude", "level"].
        decode_times: Whether to decode time variables per CF conventions.

    Returns:
        A Virtualizarr dataset
    """
    # Resolve the default here rather than using a mutable list literal as a
    # default argument (shared across calls and mutable by callees).
    if loadable_variables is None:
        loadable_variables = ["time", "latitude", "longitude", "level"]
    vds = virtualizarr.open_virtual_dataset(
        url=path,
        parser=parser,
        registry=registry,
        loadable_variables=loadable_variables,
        decode_times=decode_times,
    )
    return vds


def process_cira_model(
    model_key: str, model_data: list[xr.Dataset]
) -> tuple[str, xr.Dataset | None]:
    """Merge a list of singular virtual datasets into a single concatenated dataset.

    Args:
        model_key: The key of the model.
        model_data: A list of singular virtual datasets, one per initialization
            time (each is expected to carry an "initialization_time" attribute).

    Returns:
        A tuple of the model key and the concatenated dataset if successful,
        otherwise a tuple of the model key and None.
    """

    # Some models e.g. FCNv1 are not compatible with this approach
    # Also, in the scenario of a problem like variable chunking (ZEP003), concatenation
    # will fail, so return None in that case.
    try:
        # Combine the virtual datasets into a single dataset. Args here are established
        # defaults for virtualizarr that also work for this case. When creating a new
        # dimension with concat_dim, join="override" and combine_attrs="drop_conflicts"
        # prevents fancy indexing errors.
        combined_vds = xr.combine_nested(
            model_data,
            concat_dim="init_time",
            coords="minimal",
            compat="override",
            combine_attrs="drop_conflicts",
            join="override",
        )

        # Rename the time coordinate to valid_time to be consistent with EWB conventions
        combined_vds = combined_vds.rename({"time": "valid_time"})

        # Assign the init_time attribute to the concatenated dataset
        combined_vds = combined_vds.assign_coords(
            init_time=[
                pd.to_datetime(f.attrs["initialization_time"]) for f in model_data
            ],
        )

        # Hard code lead times for now, inconsistent values in netcdf attributes:
        # 0 to 240 hours in 6-hourly steps (41 values).
        lead_times = np.linspace(0, 240, 41).astype("timedelta64[h]")

        # Assign the lead time coordinate to the concatenated dataset
        combined_vds = combined_vds.assign_coords(lead_time=("valid_time", lead_times))

        # Swap the valid_time and lead_time dimensions to be consistent with EWB
        # conventions
        combined_vds = combined_vds.swap_dims({"valid_time": "lead_time"})

        # Return the model key and the concatenated dataset
        return model_key, combined_vds

    # If there is an error, log it (with traceback, via logger.exception and
    # lazy %-formatting) and return None for this model.
    except Exception:
        logger.exception("Error processing model %s, returning None", model_key)
        return model_key, None


def generate_cira_icechunk_store():
    """Generate a CIRA icechunk store from the CIRA MLWP data.

    Lists every object under each model prefix in the CIRA S3 bucket, builds
    virtual references in parallel, concatenates them per model, assembles a
    DataTree, and commits the result to an icechunk repository on GCS.
    Requires write credentials for the target GCS bucket.
    """

    # CIRA bucket URI
    bucket = "s3://noaa-oar-mlwp-data"

    # Build the ObjectStore from the URI, knowing the region and skipping signature
    store = obs.store.from_url(bucket, region="us-east-1", skip_signature=True)

    # Subset the prefixes to only include the model directories
    prefix_list = [
        n for n in obs.list_with_delimiter(store)["common_prefixes"] if n.endswith("FS")
    ]

    # Build the ObjectStoreRegistry and HDFParser
    registry = virtualizarr.registry.ObjectStoreRegistry({bucket: store})
    parser = virtualizarr.parsers.HDFParser()

    model_dict = {}
    for model in prefix_list:
        # Materialize the listing once. Each chunk yielded by obs.list is a
        # list of metadata dicts (chunk_size=1 -> one dict per chunk), so
        # flatten into a single list instead of listing the prefix a second
        # time to re-create a consumed stream.
        object_meta = [
            meta
            for chunk in obs.list(store, model + "/", chunk_size=1)
            for meta in chunk
        ]

        with joblib.parallel_config(backend="loky", n_jobs=-1):
            model_dict[model] = utils.ParallelTqdm(total=len(object_meta))(
                # None is the cache_dir, we can't cache in parallel mode
                joblib.delayed(process_single_virtual_dataset)(
                    f"{bucket}/{meta['path']}", parser, registry
                )
                for meta in object_meta
            )

    # Runs starting 27 May 2025 have a different chunking scheme, which cannot be
    # concatenated for now (ZEP003)
    single_chunk_model_dict = {
        model: [
            item
            for item in datasets
            if item["time"][0] < pd.to_datetime("2025-05-27T00:00:00.000000000")
        ]
        for model, datasets in model_dict.items()
    }

    with joblib.parallel_config(backend="loky", n_jobs=-1):
        results = utils.ParallelTqdm(total=len(single_chunk_model_dict))(
            joblib.delayed(process_cira_model)(model_key, model_data)
            for model_key, model_data in single_chunk_model_dict.items()
        )

    # Filter out any None results and create a dictionary of model keys and concatenated
    # datasets
    concat_model_dict = {
        model_key: result for model_key, result in results if result is not None
    }

    # Create a DataTree from the dictionary of model keys and concatenated datasets
    cira_datatree = xr.DataTree.from_dict(concat_model_dict)

    # Build the GCS storage for the icechunk repository. This will fail if you do not
    # have the application credentials for write access to the EWB bucket. For another
    # GCS store, run gcloud auth application-default login and find where the generated
    # json credentials are stored, and use that path in application_credentials.
    storage = icechunk.gcs_storage(
        bucket="extremeweatherbench", prefix="cira-icechunk", application_credentials=""
    )

    # Build the RepositoryConfig with default config settings
    config = icechunk.RepositoryConfig.default()

    # Set the virtual chunk container to the CIRA bucket in S3 so the store can
    # resolve virtual chunk references back to the source netCDF objects.
    config.set_virtual_chunk_container(
        icechunk.VirtualChunkContainer(
            url_prefix="s3://noaa-oar-mlwp-data/",
            store=icechunk.s3_store(region="us-east-1", anonymous=True),
        ),
    )

    # Create the repository
    repo = icechunk.Repository.create(storage, config)

    # Create a writable session to the repository
    session = repo.writable_session(branch="main")

    # Convert the DataTree to icechunk and commit to the repository
    cira_datatree.vz.to_icechunk(session.store)

    # Commit the changes to the repository; required for icechunk to be used from the
    # GCS store
    session.commit("drop in cira icechunk store")


if __name__ == "__main__":
    generate_cira_icechunk_store()
121 changes: 121 additions & 0 deletions docs/recipes/cira_forecast.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Accessing a CIRA Forecast

We have a dedicated virtual reference icechunk store for CIRA data **up to May 26th, 2025** available at `gs://extremeweatherbench/cira-icechunk`. Compared to using parquet virtual references, we have seen a speed improvement of around 2x with ~25% more memory usage.

## Loading the store

```python

from extremeweatherbench import cases, inputs, metrics, evaluate, defaults
import datetime
import icechunk

storage = icechunk.gcs_storage(
bucket="extremeweatherbench", prefix="cira-icechunk", anonymous=True
)
```

## Accessing a CIRA Model from the store

```python

group_list = inputs.list_groups_in_icechunk_datatree(storage)
```

`group_list` is a list of each group within the DataTree. Note that the list order may change; do not hard-code a numerical index based on this output.

```['/',
'/GRAP_v100_IFS',
'/FOUR_v200_GFS',
'/PANG_v100_IFS',
'/PANG_v100_GFS',
'/AURO_v100_IFS',
'/FOUR_v200_IFS',
'/GRAP_v100_GFS',
'/AURO_v100_GFS']
```

## Loading the data as an XarrayObject

```python

# Find FCNv2's name in the group list
fcnv2_group = [n for n in group_list if 'FOUR_v200_GFS' in n][0]

# Helper function to access the virtual dataset
fcnv2 = inputs.open_icechunk_dataset_from_datatree(
storage=storage,
group=fcnv2_group,
authorize_virtual_chunk_access=inputs.CIRA_CREDENTIALS
)
fcnv2_icechunk_forecast_object = inputs.XarrayForecast(
ds=fcnv2,
variable_mapping=inputs.CIRA_metadata_variable_mapping
)
```

`fcnv2_icechunk_forecast_object` is a `ForecastBase` object ready to be used within EWB's evaluation framework.

## Set up metrics and target for evaluation

```python
metrics_list = [

    # Assign the forecast and target variable based on EWB's variable naming
    metrics.MaximumMeanAbsoluteError(
        forecast_variable='surface_air_temperature',
        target_variable='surface_air_temperature'
    ),

    # Arbitrary thresholds to check CSI on the temperature; how did the models do
    # spatially for the upper echelons of heat?
    metrics.CriticalSuccessIndex(
        forecast_variable='surface_air_temperature',
        target_variable='surface_air_temperature',
        forecast_threshold=310,
        target_threshold=310
    ),

]

# Load in GHCNh target
ghcn_target = inputs.GHCN()
```

## Load in case metadata

```python

# Use EWB's cases and subset to the first two heat waves
case_vals = cases.load_ewb_events_yaml_into_case_collection()
case_vals.select_cases('case_id_number', [1,2],inplace=True)
```

From here, all we need to do is plug in the event type, metric list, target, and forecast
to an `EvaluationObject` and run EWB's evaluation engine:

```python

evaluation_object = [
inputs.EvaluationObject(
event_type="heat_wave",
metric_list=metrics_list,
target=ghcn_target,
forecast=fcnv2_icechunk_forecast_object,
),
]

ewb = evaluate.ExtremeWeatherBench(
case_metadata=case_vals,
evaluation_objects=evaluation_object
)

# Set up parallel configuration for the run to pass into joblib
parallel_config = {
'backend':'loky',
'n_jobs':4,
'backend_params':{'timeout':1}
}

output = ewb.run(parallel_config=parallel_config)
```
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@ dependencies = [
[project.optional-dependencies]
data-prep = [
"fsspec>=2024.12.0",
"icechunk>=1.1.14",
"matplotlib>=3.10.0",
"obstore>=0.8.2",
"scipy>=1.13",
"seaborn>=0.13.2",
"ujson>=5.10.0",
"scipy>=1.13",
"virtualizarr==2.1.2",
]
multiprocessing = ["dask[complete]>=2025.1.0", "distributed>=2025.1.0"]

Expand Down
Loading
Loading