Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 48 additions & 20 deletions src/extremeweatherbench/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,22 +1192,30 @@ class DurationMeanError(MeanError):
threshold_criteria: The criteria for event detection. Can be either a DataArray
of a climatology with dimensions (dayofyear, hour, latitude, longitude) or a
float value representing a fixed threshold.
reduce_spatial_dims: The spatial dimensions to reduce prior to applying threshold
criteria. Defaults to ["latitude", "longitude"].
op_func: Comparison operator or string (e.g., operator.ge for >=).
name: Name of the metric.
preserve_dims: Dimensions to preserve during aggregation. Defaults to
"init_time".
product_time_resolution_hours: Whether to multiply the duration by the time
resolution of the forecast (in hours). Defaults to False.
"""

def __init__(
self,
threshold_criteria: xr.DataArray | float,
reduce_spatial_dims: list[str] = ["latitude", "longitude"],
op_func: Union[Callable, Literal[">", ">=", "<", "<=", "==", "!="]] = ">=",
name: str = "duration_me",
name: str = "DurationMeanError",
preserve_dims: str = "init_time",
product_time_resolution_hours: bool = False,
):
super().__init__(name=name, preserve_dims=preserve_dims)
self.reduce_spatial_dims = reduce_spatial_dims
self.threshold_criteria = threshold_criteria
self.op_func = utils.maybe_get_operator(op_func)
self.product_time_resolution_hours = product_time_resolution_hours

def _compute_metric(
self,
Expand All @@ -1218,22 +1226,21 @@ def _compute_metric(
"""Compute spatially averaged duration mean error.

Args:
forecast: Forecast dataset with dims (init_time, lead_time, valid_time)
target: Target dataset with dims (valid_time)
forecast: the forecast DataArray.
target: the target DataArray.

Returns:
Mean error between forecast and target event durations
The mean error between forecast and target event durations.
"""
spatial_dims = [
dim
for dim in forecast.dims
if dim not in ["init_time", "lead_time", "valid_time"]
]
# Handle criteria - either climatology (xr.DataArray) or float threshold
# Use local variable to avoid mutating self.threshold_criteria
threshold_criteria = self.threshold_criteria

# Need to get climatology into the correct format and interpolation for
# comparison
if isinstance(threshold_criteria, xr.DataArray):
# Climatology case, convert from dayofyear/hour to valid_time
# Climatology case, convert from dayofyear/hour to valid_time.
# Note: only the year of the first forecast valid_time is used, so unintended
# behavior may occur if the case spans multiple years.
threshold_criteria = utils.convert_day_yearofday_to_time(
threshold_criteria, forecast.valid_time.dt.year.values[0]
)
Expand All @@ -1242,14 +1249,28 @@ def _compute_metric(
threshold_criteria = utils.interp_climatology_to_target(
target, threshold_criteria
)
forecast = utils.reduce_dataarray(
forecast, method="mean", reduce_dims=spatial_dims
)
target = utils.reduce_dataarray(target, method="mean", reduce_dims=spatial_dims)
forecast = forecast.compute()
target = target.compute()
# Reduce spatial dimensions if specified (default is to reduce)
if len(self.reduce_spatial_dims) > 0:
target = utils.reduce_dataarray(
target, method="mean", reduce_dims=self.reduce_spatial_dims, skipna=True
)
forecast = utils.reduce_dataarray(
forecast,
method="mean",
reduce_dims=self.reduce_spatial_dims,
skipna=True,
)

if isinstance(threshold_criteria, xr.DataArray):
threshold_criteria = utils.reduce_dataarray(
threshold_criteria,
method="mean",
reduce_dims=self.reduce_spatial_dims,
skipna=True,
)
forecast_mask = self.op_func(forecast, threshold_criteria)
target_mask = self.op_func(target, threshold_criteria)

# Track NaN locations in forecast data
forecast_valid_mask = ~forecast.isnull()

Expand All @@ -1259,18 +1280,25 @@ def _compute_metric(
target_mask_final = target_mask.where(forecast_valid_mask)
# If sparse, will need to expand_dims first as transpose is not supported
except AttributeError:
print("target_mask is sparse")
target_mask_final = target_mask.expand_dims(dim={"lead_time": 41}).where(
forecast_valid_mask
logger.info(
"Target mask is sparse, expanding dimensions to handle unsupported "
"transpose operation."
)
target_mask_final = target_mask.expand_dims(
dim={"lead_time": target.lead_time.size}
).where(forecast_valid_mask)

# Sum to get durations (NaN values are excluded by default)
forecast_duration = forecast_mask_final.groupby(self.preserve_dims).sum(
skipna=True
)
target_duration = target_mask_final.groupby(self.preserve_dims).sum(skipna=True)

# TODO: product of time resolution hours and duration
if self.product_time_resolution_hours:
time_resolution_hours = utils.determine_temporal_resolution(forecast)
forecast_duration = forecast_duration * time_resolution_hours
target_duration = target_duration * time_resolution_hours

return super()._compute_metric(
forecast=forecast_duration,
target=target_duration,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ def test_instantiation(self):
climatology = self.create_climatology()
metric = metrics.DurationMeanError(threshold_criteria=climatology)
assert isinstance(metric, metrics.MeanError)
assert metric.name == "duration_me"
assert metric.name == "DurationMeanError"

def test_base_metric_inheritance(self):
"""Test that DurationMeanError inherits from ME."""
Expand Down Expand Up @@ -1420,7 +1420,7 @@ def test_instantiation_with_float_threshold_criteria(self):
criteria."""
metric = metrics.DurationMeanError(threshold_criteria=300.0)
assert isinstance(metric, metrics.MeanError)
assert metric.name == "duration_me"
assert metric.name == "DurationMeanError"
assert metric.threshold_criteria == 300.0

def test_me_with_float_threshold_all_forecast_exceeds(self):
Expand Down
Loading