From f3e493615e44bd486b7cbd52d41680699995d1e5 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Thu, 7 May 2026 16:22:16 +0200 Subject: [PATCH] Display mode for R2 and MAPE metrics --- src/models/components/metrics/mape.py | 6 ++++-- src/models/components/metrics/r2.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/models/components/metrics/mape.py b/src/models/components/metrics/mape.py index 2df630d..af2fbe3 100644 --- a/src/models/components/metrics/mape.py +++ b/src/models/components/metrics/mape.py @@ -29,11 +29,13 @@ def __init__(self) -> None: def forward( self, pred: torch.Tensor, - mode: str, + mode: str | None = None, labels: torch.Tensor | None = None, batch: Dict[str, torch.Tensor] | None = None, **kwargs, ) -> Dict[str, torch.Tensor]: + if mode not in _MODES: + raise ValueError(f"MAPE.forward: mode must be one of {_MODES}, got '{mode}'") if labels is None: labels = batch.get("target") if batch is not None else None if labels is None: @@ -43,4 +45,4 @@ def forward( metric = self._mape[f"mode_{mode}"] metric.update(pred.squeeze(-1), labels.squeeze(-1)) - return {self.name: metric} + return {f"{mode}_{self.name}": metric} diff --git a/src/models/components/metrics/r2.py b/src/models/components/metrics/r2.py index 1d3bf2f..27f6584 100644 --- a/src/models/components/metrics/r2.py +++ b/src/models/components/metrics/r2.py @@ -29,11 +29,13 @@ def __init__(self) -> None: def forward( self, pred: torch.Tensor, - mode: str, + mode: str | None = None, labels: torch.Tensor | None = None, batch: Dict[str, torch.Tensor] | None = None, **kwargs, ) -> Dict[str, torch.Tensor]: + if mode not in _MODES: + raise ValueError(f"RSquared.forward: mode must be one of {_MODES}, got '{mode}'") if labels is None: labels = batch.get("target") if batch is not None else None if labels is None: @@ -43,4 +45,4 @@ def forward( metric = self._r2[f"mode_{mode}"] metric.update(pred.squeeze(-1), labels.squeeze(-1)) - return {self.name: metric} + return {f"{mode}_{self.name}": metric}