diff --git a/src/extremeweatherbench/inputs.py b/src/extremeweatherbench/inputs.py index 89763b54..eb1c8425 100644 --- a/src/extremeweatherbench/inputs.py +++ b/src/extremeweatherbench/inputs.py @@ -423,6 +423,55 @@ def _open_data_from_source(self) -> IncomingDataInput: ) +@dataclasses.dataclass +class XarrayForecast(ForecastBase): + """Forecast class for datasets that were previously constructed and opened using xarray. + + This class is intended for situations where the user has to manually prepare a dataset to + use in their evaluation. This can happen when the user has manually constructed such a + dataset from a collection of NetCDF or Zarr archives which need to be assembled into a + single, master dataset. + + Attributes: + ds: The xarray dataset containing the forecast data. + source: The source of the data, defaults to "memory" for in-memory datasets. + name: The name of the input data source, defaults to "in-memory dataset". + """ + + #: The xarray dataset containing the forecast data. This is required for the class to be instantiated + #: because we inherit from ForecastBase, which has its own set of required attributes. + ds: Optional[xr.Dataset] = None # type: ignore[assignment] + source: str = "memory" + name: str = "in-memory dataset" + + def __post_init__(self): + """Validate that ds is provided and normalize None values to defaults. + + This ensures backwards compatibility with the ForecastBase's __init__ behavior + where None values for variables and variable_mapping were converted to empty + containers. If the user does not provide a ds, we raise a ValueError. + """ + if self.ds is None: + raise ValueError( + "The 'ds' parameter is required for XarrayForecast. " + "Please provide an xarray.Dataset." 
+ ) + + # Convert None to empty containers for backwards compatibility + if self.variables is None: + object.__setattr__(self, "variables", []) + if self.variable_mapping is None: + object.__setattr__(self, "variable_mapping", {}) + + def _open_data_from_source(self) -> xr.Dataset: + """Open the input data from the source. + + Returns: + The xarray dataset that was provided during initialization. + """ + return self.ds + + @dataclasses.dataclass class TargetBase(InputBase): """An abstract base class for target data. diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 2dc9dadd..4531d150 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -1871,6 +1871,445 @@ def test_era5_alignment_comprehensive( # have different time coverage +class TestXarrayForecast: + """Test the XarrayForecast class for in-memory datasets.""" + + def test_xarray_forecast_instantiation_with_defaults( + self, sample_forecast_with_valid_time + ): + """Test creating XarrayForecast with default parameters.""" + forecast = inputs.XarrayForecast(ds=sample_forecast_with_valid_time) + + assert forecast.ds is sample_forecast_with_valid_time + assert forecast.source == "memory" + assert forecast.name == "in-memory dataset" + assert forecast.variables == [] + assert forecast.variable_mapping == {} + assert forecast.preprocess == inputs._default_preprocess + + def test_xarray_forecast_instantiation_with_custom_params( + self, sample_forecast_with_valid_time + ): + """Test creating XarrayForecast with custom parameters.""" + custom_mapping = {"temp": "surface_air_temperature"} + custom_variables = ["surface_air_temperature"] + + def custom_preprocess(ds): + return ds * 2 + + forecast = inputs.XarrayForecast( + ds=sample_forecast_with_valid_time, + variables=custom_variables, + variable_mapping=custom_mapping, + source="custom_source", + name="custom_forecast", + preprocess=custom_preprocess, + ) + + assert forecast.ds is sample_forecast_with_valid_time + assert forecast.source == 
"custom_source" + assert forecast.name == "custom_forecast" + assert forecast.variables == custom_variables + assert forecast.variable_mapping == custom_mapping + assert forecast.preprocess == custom_preprocess + + def test_xarray_forecast_open_data_from_source( + self, sample_forecast_with_valid_time + ): + """Test that _open_data_from_source returns the stored dataset.""" + forecast = inputs.XarrayForecast(ds=sample_forecast_with_valid_time) + result = forecast._open_data_from_source() + + assert result is sample_forecast_with_valid_time + xr.testing.assert_identical(result, sample_forecast_with_valid_time) + + def test_xarray_forecast_open_and_preprocess(self, sample_forecast_with_valid_time): + """Test open_and_maybe_preprocess_data_from_source with preprocessing.""" + + def multiply_by_two(ds): + return ds * 2 + + forecast = inputs.XarrayForecast( + ds=sample_forecast_with_valid_time, preprocess=multiply_by_two + ) + + result = forecast.open_and_maybe_preprocess_data_from_source() + + # Should apply preprocessing + assert isinstance(result, xr.Dataset) + # Values should be doubled + expected = sample_forecast_with_valid_time * 2 + xr.testing.assert_allclose(result, expected) + + def test_xarray_forecast_no_preprocessing(self, sample_forecast_with_valid_time): + """Test that default preprocessing does nothing.""" + forecast = inputs.XarrayForecast(ds=sample_forecast_with_valid_time) + result = forecast.open_and_maybe_preprocess_data_from_source() + + # Should return unchanged dataset + xr.testing.assert_identical(result, sample_forecast_with_valid_time) + + def test_xarray_forecast_with_chunks_parameter( + self, sample_forecast_with_valid_time + ): + """Test XarrayForecast accepts chunks parameter from ForecastBase.""" + custom_chunks = {"valid_time": 10, "latitude": 45, "longitude": 90} + + forecast = inputs.XarrayForecast( + ds=sample_forecast_with_valid_time, chunks=custom_chunks + ) + + # ForecastBase should have chunks attribute + assert forecast.chunks == 
custom_chunks + + def test_xarray_forecast_inherits_subset_data_to_case( + self, sample_forecast_dataset + ): + """Test that XarrayForecast inherits subset_data_to_case from ForecastBase.""" + # Convert forecast to have init_time and lead_time like ForecastBase expects + forecast = inputs.XarrayForecast( + ds=sample_forecast_dataset, variables=["surface_air_temperature"] + ) + + # Create mock case metadata + mock_case = mock.Mock() + mock_case.start_date = pd.Timestamp("2021-06-20") + mock_case.end_date = pd.Timestamp("2021-06-22") + mock_case.location.mask.return_value = sample_forecast_dataset + + with ( + mock.patch( + "extremeweatherbench.utils.derive_indices_from_init_time_and_lead_time" + ) as mock_derive, + mock.patch( + "extremeweatherbench.utils.convert_init_time_to_valid_time" + ) as mock_convert, + ): + # Setup mocks + mock_derive.return_value = (np.array([0, 1]), np.array([0, 1])) + result_data = xr.Dataset( + { + "surface_air_temperature": ( + ["valid_time", "latitude", "longitude"], + np.random.randn(3, 3, 3), + ) + }, + coords={ + "valid_time": pd.date_range("2021-06-20", periods=3, freq="6h"), + "latitude": [40, 41, 42], + "longitude": [-100, -101, -102], + }, + ) + mock_convert.return_value = result_data + + result = forecast.subset_data_to_case(sample_forecast_dataset, mock_case) + + # Should use inherited method from ForecastBase + assert isinstance(result, xr.Dataset) + + def test_xarray_forecast_with_variable_mapping( + self, sample_forecast_with_valid_time + ): + """Test XarrayForecast with variable mapping.""" + # Create dataset with original variable names + test_data = sample_forecast_with_valid_time.copy() + + forecast = inputs.XarrayForecast( + ds=test_data, + variables=["temp"], + variable_mapping={"surface_air_temperature": "temp"}, + ) + + # Test that variable mapping is stored + assert forecast.variable_mapping == {"surface_air_temperature": "temp"} + + # Test that maybe_map_variable_names works + mapped_data = 
forecast.maybe_map_variable_names(test_data) + assert "temp" in mapped_data.data_vars + assert "surface_air_temperature" not in mapped_data.data_vars + + def test_xarray_forecast_empty_dataset(self): + """Test XarrayForecast with empty dataset.""" + empty_ds = xr.Dataset() + forecast = inputs.XarrayForecast(ds=empty_ds) + + assert forecast.ds.equals(empty_ds) + result = forecast._open_data_from_source() + assert result.equals(empty_ds) + + def test_xarray_forecast_with_multiple_variables(self): + """Test XarrayForecast with dataset containing multiple variables.""" + # Create dataset with multiple variables + ds = xr.Dataset( + { + "surface_air_temperature": ( + ["valid_time", "latitude", "longitude"], + np.random.randn(10, 5, 5), + ), + "surface_pressure": ( + ["valid_time", "latitude", "longitude"], + np.random.randn(10, 5, 5), + ), + "surface_eastward_wind": ( + ["valid_time", "latitude", "longitude"], + np.random.randn(10, 5, 5), + ), + }, + coords={ + "valid_time": pd.date_range("2021-06-20", periods=10, freq="6h"), + "latitude": [40, 41, 42, 43, 44], + "longitude": [-100, -101, -102, -103, -104], + }, + ) + + forecast = inputs.XarrayForecast( + ds=ds, + variables=["surface_air_temperature", "surface_pressure"], + name="multi_var_forecast", + ) + + assert forecast.name == "multi_var_forecast" + assert set(forecast.variables) == { + "surface_air_temperature", + "surface_pressure", + } + result = forecast._open_data_from_source() + assert "surface_air_temperature" in result.data_vars + assert "surface_pressure" in result.data_vars + assert "surface_eastward_wind" in result.data_vars + + def test_xarray_forecast_add_source_to_attrs(self, sample_forecast_with_valid_time): + """Test adding source name to dataset attributes.""" + forecast = inputs.XarrayForecast( + ds=sample_forecast_with_valid_time, name="test_in_memory" + ) + + result = forecast.add_source_to_dataset_attrs(sample_forecast_with_valid_time) + + assert result.attrs["source"] == "test_in_memory" + + 
def test_xarray_forecast_preserves_dataset_attrs(self): + """Test that XarrayForecast preserves existing dataset attributes.""" + ds = xr.Dataset( + { + "temp": (["time", "x"], np.random.randn(5, 3)), + }, + coords={"time": pd.date_range("2021-06-20", periods=5), "x": [1, 2, 3]}, + ) + ds.attrs["existing_attr"] = "test_value" + ds.attrs["description"] = "Test dataset" + + forecast = inputs.XarrayForecast(ds=ds) + result = forecast._open_data_from_source() + + # Original attributes should be preserved + assert result.attrs["existing_attr"] == "test_value" + assert result.attrs["description"] == "Test dataset" + + def test_xarray_forecast_complex_preprocessing_pipeline(self): + """Test XarrayForecast with complex preprocessing pipeline.""" + ds = xr.Dataset( + { + "surface_air_temperature": ( + ["valid_time", "latitude", "longitude"], + np.random.randn(10, 5, 5) + 273.15, + ), + }, + coords={ + "valid_time": pd.date_range("2021-06-20", periods=10, freq="6h"), + "latitude": [40, 41, 42, 43, 44], + "longitude": [-100, -101, -102, -103, -104], + }, + ) + + def complex_preprocess(dataset): + """Convert Kelvin to Celsius and add metadata.""" + result = dataset.copy() + result["surface_air_temperature"] = ( + result["surface_air_temperature"] - 273.15 + ) + result.attrs["units_converted"] = "K_to_C" + return result + + forecast = inputs.XarrayForecast(ds=ds, preprocess=complex_preprocess) + + result = forecast.open_and_maybe_preprocess_data_from_source() + + # Check preprocessing was applied + assert result.attrs.get("units_converted") == "K_to_C" + # Check values were converted (approximately) + assert result["surface_air_temperature"].mean() < 100 # Should be in Celsius + + def test_xarray_forecast_with_storage_options( + self, sample_forecast_with_valid_time + ): + """Test XarrayForecast with storage_options parameter.""" + storage_opts = {"anon": True, "project": "test-project"} + + forecast = inputs.XarrayForecast( + ds=sample_forecast_with_valid_time, 
storage_options=storage_opts + ) + + # Storage options should be stored (from InputBase) + assert forecast.storage_options == storage_opts + + def test_xarray_forecast_comparison_with_zarr_forecast( + self, sample_forecast_with_valid_time + ): + """Test that XarrayForecast behaves similarly to ZarrForecast.""" + # Create XarrayForecast + xarray_fc = inputs.XarrayForecast( + ds=sample_forecast_with_valid_time, + name="test_forecast", + variables=["surface_air_temperature"], + ) + + # Both should have same inherited methods + assert hasattr(xarray_fc, "subset_data_to_case") + assert hasattr(xarray_fc, "maybe_convert_to_dataset") + assert hasattr(xarray_fc, "maybe_map_variable_names") + assert hasattr(xarray_fc, "add_source_to_dataset_attrs") + + # XarrayForecast should return data directly + result = xarray_fc._open_data_from_source() + assert result is sample_forecast_with_valid_time + + def test_xarray_forecast_integration_with_evaluation_object( + self, sample_forecast_with_valid_time + ): + """Test XarrayForecast can be used in EvaluationObject.""" + mock_metric = mock.Mock() + mock_target = mock.Mock() + + forecast = inputs.XarrayForecast( + ds=sample_forecast_with_valid_time, + name="in_memory_forecast", + variables=["surface_air_temperature"], + ) + + # Should be able to create EvaluationObject + eval_obj = inputs.EvaluationObject( + event_type="test_event", + metric_list=[mock_metric], + target=mock_target, + forecast=forecast, + ) + + assert eval_obj.forecast is forecast + assert isinstance(eval_obj.forecast, inputs.ForecastBase) + assert isinstance(eval_obj.forecast, inputs.XarrayForecast) + + def test_xarray_forecast_maybe_convert_to_dataset_passthrough( + self, sample_forecast_with_valid_time + ): + """Test that maybe_convert_to_dataset works correctly.""" + forecast = inputs.XarrayForecast(ds=sample_forecast_with_valid_time) + + result = forecast.maybe_convert_to_dataset(sample_forecast_with_valid_time) + + # Should return the dataset unchanged (already 
a Dataset) + assert isinstance(result, xr.Dataset) + assert result is sample_forecast_with_valid_time + + def test_xarray_forecast_with_derived_variables_placeholder( + self, sample_forecast_with_valid_time + ): + """Test XarrayForecast can specify derived variables (placeholder).""" + # This tests that the variables parameter can include derived variable + # specifications. Note: actual derived variable calculation would happen + # in the pipeline + forecast = inputs.XarrayForecast( + ds=sample_forecast_with_valid_time, + variables=["surface_air_temperature", "surface_wind_speed"], + ) + + assert len(forecast.variables) == 2 + assert "surface_air_temperature" in forecast.variables + assert "surface_wind_speed" in forecast.variables + + def test_xarray_forecast_none_handling_for_optional_params( + self, sample_forecast_with_valid_time + ): + """Test that None values are properly converted to empty defaults.""" + # Explicitly pass None to test the None handling in __post_init__ + forecast = inputs.XarrayForecast( + ds=sample_forecast_with_valid_time, variables=None, variable_mapping=None + ) + + # Should be converted to empty containers + assert forecast.variables == [] + assert forecast.variable_mapping == {} + + def test_xarray_forecast_requires_ds_parameter(self): + """Test that XarrayForecast raises ValueError when ds is None.""" + with pytest.raises( + ValueError, match="The 'ds' parameter is required for XarrayForecast" + ): + inputs.XarrayForecast(ds=None) + + def test_xarray_forecast_repr_includes_key_info( + self, sample_forecast_with_valid_time + ): + """Test that XarrayForecast has proper repr from dataclass.""" + forecast = inputs.XarrayForecast( + ds=sample_forecast_with_valid_time, + name="test_forecast", + source="memory", + ) + + repr_str = repr(forecast) + # Should include class name and key attributes + assert "XarrayForecast" in repr_str + assert "name=" in repr_str or "test_forecast" in repr_str + + def test_xarray_forecast_dataclass_equality(self, 
sample_forecast_with_valid_time): + """Test that XarrayForecast supports dataclass equality comparison.""" + forecast1 = inputs.XarrayForecast( + ds=sample_forecast_with_valid_time, + name="test", + source="memory", + ) + forecast2 = inputs.XarrayForecast( + ds=sample_forecast_with_valid_time, + name="test", + source="memory", + ) + + # Same dataset and parameters should be equal + assert forecast1 == forecast2 + + # Different name should be not equal + forecast3 = inputs.XarrayForecast( + ds=sample_forecast_with_valid_time, + name="different", + source="memory", + ) + assert forecast1 != forecast3 + + def test_xarray_forecast_default_source_and_name( + self, sample_forecast_with_valid_time + ): + """Test that XarrayForecast provides sensible defaults for source and name.""" + forecast = inputs.XarrayForecast(ds=sample_forecast_with_valid_time) + + assert forecast.source == "memory" + assert forecast.name == "in-memory dataset" + + def test_xarray_forecast_can_override_parent_defaults( + self, sample_forecast_with_valid_time + ): + """Test that XarrayForecast can override defaults from parent classes.""" + forecast = inputs.XarrayForecast( + ds=sample_forecast_with_valid_time, + chunks={"time": 10}, + storage_options={"anon": True}, + ) + + # Should be able to override ForecastBase defaults + assert forecast.chunks == {"time": 10} + assert forecast.storage_options == {"anon": True} + + def test_default_preprocess(): """Test default preprocess function.""" # Import the function from inputs module since it was moved there