Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/extremeweatherbench/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,55 @@ def _open_data_from_source(self) -> IncomingDataInput:
)


@dataclasses.dataclass
class XarrayForecast(ForecastBase):
"""Forecast class for datasets that were previously constructed and opened using xarray.

This class is intended for situations where the user has to manually prepare a dataset to
use in their evaluation. This can happen when the user is manually constructed such a
dataset from a collection of NetCDF or Zarr archives which need to be assembled into a
single, master dataset.

Attributes:
ds: The xarray dataset containing the forecast data.
source: The source of the data, defaults to "memory" for in-memory datasets.
name: The name of the input data source, defaults to "in-memory dataset".
"""

#: The xarray dataset containing the forecast data. This is required for the class to be instantiated
#: because we inherit from ForecastBase, which has its own set of required attributes.
ds: Optional[xr.Dataset] = None # type: ignore[assignment]
source: str = "memory"
name: str = "in-memory dataset"

def __post_init__(self):
"""Validate that ds is provided and normalize None values to defaults.

This ensures backwards compatibility with the ForecastBase's __init__ behavior
where None values for variables and variable_mapping were converted to empty
containers. If the user does not provide a ds, we raise a ValueError.
"""
if self.ds is None:
raise ValueError(
"The 'ds' parameter is required for XarrayForecast. "
"Please provide an xarray.Dataset."
)

# Convert None to empty containers for backwards compatibility
if self.variables is None:
object.__setattr__(self, "variables", [])
if self.variable_mapping is None:
object.__setattr__(self, "variable_mapping", {})

def _open_data_from_source(self) -> xr.Dataset:
"""Open the input data from the source.

Returns:
The xarray dataset that was provided during initialization.
"""
return self.ds


@dataclasses.dataclass
class TargetBase(InputBase):
"""An abstract base class for target data.
Expand Down
Loading
Loading