diff --git a/src/extremeweatherbench/metrics.py b/src/extremeweatherbench/metrics.py index e80061ed..21bed2b5 100644 --- a/src/extremeweatherbench/metrics.py +++ b/src/extremeweatherbench/metrics.py @@ -1192,22 +1192,30 @@ class DurationMeanError(MeanError): threshold_criteria: The criteria for event detection. Can be either a DataArray of a climatology with dimensions (dayofyear, hour, latitude, longitude) or a float value representing a fixed threshold. + reduce_spatial_dims: The spatial dimensions to reduce prior to applying threshold + criteria. Defaults to ["latitude", "longitude"]. op_func: Comparison operator or string (e.g., operator.ge for >=). name: Name of the metric. preserve_dims: Dimensions to preserve during aggregation. Defaults to "init_time". + product_time_resolution_hours: Whether to multiply the duration by the time + resolution of the forecast (in hours). Defaults to False. """ def __init__( self, threshold_criteria: xr.DataArray | float, + reduce_spatial_dims: list[str] = ["latitude", "longitude"], op_func: Union[Callable, Literal[">", ">=", "<", "<=", "==", "!="]] = ">=", - name: str = "duration_me", + name: str = "DurationMeanError", preserve_dims: str = "init_time", + product_time_resolution_hours: bool = False, ): super().__init__(name=name, preserve_dims=preserve_dims) + self.reduce_spatial_dims = reduce_spatial_dims self.threshold_criteria = threshold_criteria self.op_func = utils.maybe_get_operator(op_func) + self.product_time_resolution_hours = product_time_resolution_hours def _compute_metric( self, @@ -1218,22 +1226,21 @@ def _compute_metric( """Compute spatially averaged duration mean error. Args: - forecast: Forecast dataset with dims (init_time, lead_time, valid_time) - target: Target dataset with dims (valid_time) + forecast: the forecast DataArray. + target: the target DataArray. Returns: - Mean error between forecast and target event durations + The mean error between forecast and target event durations.
""" - spatial_dims = [ - dim - for dim in forecast.dims - if dim not in ["init_time", "lead_time", "valid_time"] - ] # Handle criteria - either climatology (xr.DataArray) or float threshold # Use local variable to avoid mutating self.threshold_criteria threshold_criteria = self.threshold_criteria + + # Need to get climatology into the correct format and interpolation for + # comparison if isinstance(threshold_criteria, xr.DataArray): - # Climatology case, convert from dayofyear/hour to valid_time + # Climatology case, convert from dayofyear/hour to valid_time. + # Note that unintended behavior may occur if the case spans multiple years. threshold_criteria = utils.convert_day_yearofday_to_time( threshold_criteria, forecast.valid_time.dt.year.values[0] ) @@ -1242,14 +1249,28 @@ def _compute_metric( threshold_criteria = utils.interp_climatology_to_target( target, threshold_criteria ) - forecast = utils.reduce_dataarray( - forecast, method="mean", reduce_dims=spatial_dims - ) - target = utils.reduce_dataarray(target, method="mean", reduce_dims=spatial_dims) - forecast = forecast.compute() - target = target.compute() + # Reduce spatial dimensions if specified (default is to reduce) + if len(self.reduce_spatial_dims) > 0: + target = utils.reduce_dataarray( + target, method="mean", reduce_dims=self.reduce_spatial_dims, skipna=True + ) + forecast = utils.reduce_dataarray( + forecast, + method="mean", + reduce_dims=self.reduce_spatial_dims, + skipna=True, + ) + + if isinstance(threshold_criteria, xr.DataArray): + threshold_criteria = utils.reduce_dataarray( + threshold_criteria, + method="mean", + reduce_dims=self.reduce_spatial_dims, + skipna=True, + ) forecast_mask = self.op_func(forecast, threshold_criteria) target_mask = self.op_func(target, threshold_criteria) + # Track NaN locations in forecast data forecast_valid_mask = ~forecast.isnull() @@ -1259,10 +1280,13 @@ def _compute_metric( target_mask_final = target_mask.where(forecast_valid_mask) # If sparse, will need to 
expand_dims first as transpose is not supported except AttributeError: - print("target_mask is sparse") - target_mask_final = target_mask.expand_dims(dim={"lead_time": 41}).where( - forecast_valid_mask + logger.info( + "Target mask is sparse, expanding dimensions to handle unsupported " + "transpose operation." ) + target_mask_final = target_mask.expand_dims( + dim={"lead_time": target.lead_time.size} + ).where(forecast_valid_mask) # Sum to get durations (NaN values are excluded by default) forecast_duration = forecast_mask_final.groupby(self.preserve_dims).sum( @@ -1270,7 +1294,11 @@ def _compute_metric( ) target_duration = target_mask_final.groupby(self.preserve_dims).sum(skipna=True) - # TODO: product of time resolution hours and duration + if self.product_time_resolution_hours: + time_resolution_hours = utils.determine_temporal_resolution(forecast) + forecast_duration = forecast_duration * time_resolution_hours + target_duration = target_duration * time_resolution_hours + return super()._compute_metric( forecast=forecast_duration, target=target_duration, diff --git a/tests/test_metrics.py b/tests/test_metrics.py index ef964005..a7056806 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1009,7 +1009,7 @@ def test_instantiation(self): climatology = self.create_climatology() metric = metrics.DurationMeanError(threshold_criteria=climatology) assert isinstance(metric, metrics.MeanError) - assert metric.name == "duration_me" + assert metric.name == "DurationMeanError" def test_base_metric_inheritance(self): """Test that DurationMeanError inherits from ME.""" @@ -1420,7 +1420,7 @@ def test_instantiation_with_float_threshold_criteria(self): criteria.""" metric = metrics.DurationMeanError(threshold_criteria=300.0) assert isinstance(metric, metrics.MeanError) - assert metric.name == "duration_me" + assert metric.name == "DurationMeanError" assert metric.threshold_criteria == 300.0 def test_me_with_float_threshold_all_forecast_exceeds(self):