Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datasets/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
/*/download.zarr
/*/download
/*/standardized.zarr
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ dependencies = [
"cftime~=1.6.0",
"dask~=2024.12.0",
"fsspec~=2024.10.0",
"requests~=2.32.3",
"tqdm~=4.67.1",
"typed-classproperties~=1.1.0",
"xarray~=2024.11.0",
"zarr~=2.18.0",
]

optional-dependencies.data = [
"aiohttp~=3.11.0",
"netcdf4~=1.7.2",
"pandas~=2.2.0",
]

Expand All @@ -26,6 +29,8 @@ dev = [
"pandas-stubs~=2.2.0",
"pytest~=8.3",
"ruff~=0.8",
"types-requests~=2.32.0.20241016",
"types-tqdm~=4.67.0.20250301",
"universal-pathlib~=0.2.0",
]

Expand Down
43 changes: 29 additions & 14 deletions src/climatebenchpress/data_loader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
]

from pathlib import Path
from typing import Optional

import xarray as xr

Expand All @@ -20,16 +21,17 @@ def open_downloaded_canonicalized_dataset(
) -> xr.Dataset:
datasets = basepath / "datasets"

download = datasets / cls.name / "download.zarr"
download = datasets / cls.name / "download"
if not download.exists():
ds = cls.open()

with monitor.progress_bar(progress):
ds.to_zarr(download, encoding=dict(), compute=False).compute()
download.mkdir(parents=True, exist_ok=True)
# The download function is responsible for checking whether the download is
# complete or not. If the previous download was interrupted it will resume the download.
# If the download is complete it will skip the download.
cls.download(download, progress)

standardized = datasets / cls.name / "standardized.zarr"
if not standardized.exists():
ds = xr.open_dataset(download, chunks=dict(), engine="zarr")
ds = cls.open(download)
ds = canon.canonicalize_dataset(ds)

with monitor.progress_bar(progress):
Expand All @@ -42,23 +44,36 @@ def open_downloaded_tiny_canonicalized_dataset(
cls: type[Dataset],
basepath: Path = Path(),
progress: bool = True,
slices: Optional[dict[str, slice]] = None,
) -> xr.Dataset:
datasets = basepath / "datasets"

download = datasets / f"{cls.name}-tiny" / "download.zarr"
download = datasets / f"{cls.name}" / "download"
Comment thread
treigerm marked this conversation as resolved.
if not download.exists():
ds = cls.open()
ds = canon.canonical_tiny_dataset(ds)

with monitor.progress_bar(progress):
ds.to_zarr(download, encoding=dict(), compute=False).compute()
download.mkdir(parents=True, exist_ok=True)
cls.download(download, progress)

standardized = datasets / f"{cls.name}-tiny" / "standardized.zarr"
if not standardized.exists():
ds = xr.open_dataset(download, chunks=dict(), engine="zarr")
ds = cls.open(download)
ds = canon.canonicalize_dataset(ds)
ds = canon.canonical_tiny_dataset(ds, slices=slices)
# Rechunk the data because "tiny-fication" can lead to inconsistent or
# suboptimal chunking.
ds = _rechunk_dataset(ds)

with monitor.progress_bar(progress):
ds.to_zarr(standardized, encoding=dict(), compute=False).compute()
ds.to_zarr(
standardized, encoding=dict(), compute=False, consolidated=True
).compute()

return xr.open_dataset(standardized, chunks=dict(), engine="zarr")


def _rechunk_dataset(ds: xr.Dataset) -> xr.Dataset:
    """Return a shallow copy of *ds* with every chunked variable re-chunked.

    Each data variable backed by a chunked (dask) array is re-chunked with
    the ``"auto"`` heuristic; eagerly loaded (non-chunked) variables are
    left untouched.
    """
    result = ds.copy()
    for name in ds.data_vars:
        variable = ds[name]
        # Only dask-backed arrays expose ``chunks``; skip in-memory arrays.
        if hasattr(variable.data, "chunks"):
            result[name] = variable.chunk("auto")
    return result
16 changes: 15 additions & 1 deletion src/climatebenchpress/data_loader/canon.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,27 @@ def _ensure_axis(da: xr.DataArray, c: str) -> tuple[xr.DataArray, str]:


def canonicalize_variable(da: xr.DataArray) -> xr.DataArray:
    """Reorder *da*'s dimensions into the canonical (E, T, Z, Y, X) order.

    Variables that cannot be canonicalized are returned unchanged:
    zero-dimensional variables (there is nothing to reorder) and variables
    carrying extra, non-axis dimensions (e.g. ``DIM_bnds``).
    """
    # Nothing to reorder on a scalar variable — inventing every coordinate
    # for it would make little sense.
    if len(da.dims) == 0:
        return da

    original = da.copy(deep=False)

    ordered_dims = []
    for axis in ("E", "T", "Z", "Y", "X"):
        da, dim_name = _ensure_axis(da, axis)
        ordered_dims.append(dim_name)

    # Some variables carry additional dimensions (e.g. DIM_bnds); leave
    # those untouched rather than guessing an ordering for them.
    if any(d not in ordered_dims for d in da.dims):
        return original

    return da.transpose(*ordered_dims)


def canonicalize_dataset(ds: xr.Dataset):
Expand Down
8 changes: 7 additions & 1 deletion src/climatebenchpress/data_loader/datasets/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from inspect import isabstract
from pathlib import Path
from types import MappingProxyType

import xarray as xr
Expand All @@ -15,7 +16,12 @@ class Dataset(ABC):

@staticmethod
@abstractmethod
def open() -> xr.Dataset:
def download(download_path: Path, progress: bool = True):
    """Download the raw dataset into *download_path*.

    Implementations are responsible for checking whether a previous
    download already completed: an interrupted download is resumed and a
    finished one is skipped (see ``open_downloaded_canonicalized_dataset``).

    :param download_path: directory the raw data is written into.
    :param progress: whether to display a progress bar while downloading.
    """
    pass

@staticmethod
@abstractmethod
def open(download_path: Path) -> xr.Dataset:
    """Open the previously downloaded raw data as an :class:`xarray.Dataset`.

    :param download_path: directory a prior :meth:`download` wrote into.
    :return: the dataset, ready for canonicalization.
    """
    pass

# Class interface
Expand Down
3 changes: 2 additions & 1 deletion src/climatebenchpress/data_loader/datasets/all.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# ruff: noqa: F403

from .era5 import *
from .cmip6.all import *
from .era5 import *
from .esa_biomass_cci import *
32 changes: 28 additions & 4 deletions src/climatebenchpress/data_loader/datasets/cmip6/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
]

from functools import lru_cache
from pathlib import Path
from typing import Optional

import fsspec
import pandas as pd
import xarray as xr

from ... import monitor
from ..abc import Dataset


Expand All @@ -20,9 +23,20 @@ class Cmip6Dataset(Dataset):
table_id: str

@staticmethod
def open_with(
model_id: str, ssp_id: str, variable_id: str, table_id: str
) -> xr.Dataset:
def download_with(
download_path: Path,
model_id: str,
ssp_id: str,
variable_id: str,
table_id: str,
variable_selector: Optional[list[str]] = None,
progress: bool = True,
):
downloadfile = download_path / "download.zarr"
donefile = downloadfile.parent / (downloadfile.name + ".done")
if donefile.exists():
return

df = Cmip6Dataset.get_stores()

df_ta = df.query(
Expand All @@ -33,7 +47,17 @@ def open_with(
zstore = df_ta.zstore.values[-1]
zstore = zstore.replace("gs://", "https://storage.googleapis.com/")

return xr.open_zarr(fsspec.get_mapper(zstore), consolidated=True)
ds = xr.open_zarr(fsspec.get_mapper(zstore), consolidated=True)
if variable_selector is not None:
ds = ds[variable_selector]
with monitor.progress_bar(progress):
ds.to_zarr(downloadfile, mode="w", encoding=dict(), compute=False).compute()

donefile.touch()

@staticmethod
def open(download_path: Path) -> xr.Dataset:
    """Open the zarr store a prior download wrote under *download_path*."""
    store = download_path / "download.zarr"
    return xr.open_zarr(store)

@lru_cache
@staticmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ["Cmip6AtmosphereAccessDataset"]

import xarray as xr
from pathlib import Path

from ... import (
open_downloaded_canonicalized_dataset,
Expand All @@ -16,12 +16,14 @@ class Cmip6AtmosphereAccessDataset(Cmip6AtmosphereDataset):
ssp_id = "ssp585"

@staticmethod
def open() -> xr.Dataset:
return Cmip6Dataset.open_with(
def download(download_path: Path, progress: bool = True):
    """Download the ACCESS atmosphere CMIP6 data into *download_path*."""
    cfg = Cmip6AtmosphereAccessDataset
    Cmip6Dataset.download_with(
        download_path,
        cfg.model_id,
        cfg.ssp_id,
        cfg.variable_id,
        cfg.table_id,
        progress=progress,
    )


Expand Down
12 changes: 7 additions & 5 deletions src/climatebenchpress/data_loader/datasets/cmip6/access_ocean.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ["Cmip6OceanAccessDataset"]

import xarray as xr
from pathlib import Path

from ... import (
open_downloaded_canonicalized_dataset,
Expand All @@ -16,15 +16,17 @@ class Cmip6OceanAccessDataset(Cmip6OceanDataset):
ssp_id = "ssp585"

@staticmethod
def open() -> xr.Dataset:
ds = Cmip6Dataset.open_with(
def download(download_path: Path, progress: bool = True):
    """Download the ACCESS ocean CMIP6 data into *download_path*."""
    cfg = Cmip6OceanAccessDataset
    Cmip6Dataset.download_with(
        download_path,
        cfg.model_id,
        cfg.ssp_id,
        cfg.variable_id,
        cfg.table_id,
        # Restrict the download to the actual sea surface temperature.
        variable_selector=["tos"],
        progress=progress,
    )
# Only keep the actual sea surface temperature.
return ds[["tos"]]


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ["Cmip6AtmosphereCanEsm5Dataset"]

import xarray as xr
from pathlib import Path

from ... import (
open_downloaded_canonicalized_dataset,
Expand All @@ -16,12 +16,14 @@ class Cmip6AtmosphereCanEsm5Dataset(Cmip6AtmosphereDataset):
ssp_id = "ssp585"

@staticmethod
def open() -> xr.Dataset:
return Cmip6Dataset.open_with(
def download(download_path: Path, progress: bool = True):
    """Download the CanESM5 atmosphere CMIP6 data into *download_path*."""
    cfg = Cmip6AtmosphereCanEsm5Dataset
    Cmip6Dataset.download_with(
        download_path,
        cfg.model_id,
        cfg.ssp_id,
        cfg.variable_id,
        cfg.table_id,
        progress=progress,
    )


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ["Cmip6OceanCanEsm5Dataset"]

import xarray as xr
from pathlib import Path

from ... import (
open_downloaded_canonicalized_dataset,
Expand All @@ -16,14 +16,17 @@ class Cmip6OceanCanEsm5Dataset(Cmip6OceanDataset):
ssp_id = "ssp585"

@staticmethod
def open() -> xr.Dataset:
ds = Cmip6Dataset.open_with(
def download(download_path: Path, progress: bool = True):
    """Download the CanESM5 ocean CMIP6 data into *download_path*."""
    cfg = Cmip6OceanCanEsm5Dataset
    Cmip6Dataset.download_with(
        download_path,
        cfg.model_id,
        cfg.ssp_id,
        cfg.variable_id,
        cfg.table_id,
        # Restrict the download to the actual sea surface temperature.
        variable_selector=["tos"],
        progress=progress,
    )
return ds[["tos"]]


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ["Cmip6AtmosphereUkEsmDataset"]

import xarray as xr
from pathlib import Path

from ... import (
open_downloaded_canonicalized_dataset,
Expand All @@ -16,12 +16,14 @@ class Cmip6AtmosphereUkEsmDataset(Cmip6AtmosphereDataset):
ssp_id = "ssp585"

@staticmethod
def open() -> xr.Dataset:
return Cmip6Dataset.open_with(
def download(download_path: Path, progress: bool = True):
    """Download the UKESM atmosphere CMIP6 data into *download_path*."""
    cfg = Cmip6AtmosphereUkEsmDataset
    Cmip6Dataset.download_with(
        download_path,
        cfg.model_id,
        cfg.ssp_id,
        cfg.variable_id,
        cfg.table_id,
        progress=progress,
    )


Expand Down
11 changes: 7 additions & 4 deletions src/climatebenchpress/data_loader/datasets/cmip6/ukesm_ocean.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ["Cmip6OceanUkEsmDataset"]

import xarray as xr
from pathlib import Path

from ... import (
open_downloaded_canonicalized_dataset,
Expand All @@ -16,14 +16,17 @@ class Cmip6OceanUkEsmDataset(Cmip6OceanDataset):
ssp_id = "ssp585"

@staticmethod
def open() -> xr.Dataset:
ds = Cmip6Dataset.open_with(
def download(download_path: Path, progress: bool = True):
    """Download the UKESM ocean CMIP6 data into *download_path*."""
    cfg = Cmip6OceanUkEsmDataset
    Cmip6Dataset.download_with(
        download_path,
        cfg.model_id,
        cfg.ssp_id,
        cfg.variable_id,
        cfg.table_id,
        # Restrict the download to the actual sea surface temperature.
        variable_selector=["tos"],
        progress=progress,
    )
return ds[["tos"]]


if __name__ == "__main__":
Expand Down
Loading