diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py
index 9083b7733f..26ec823081 100644
--- a/monai/metrics/confusion_matrix.py
+++ b/monai/metrics/confusion_matrix.py
@@ -153,9 +153,6 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou
     if not include_background:
         y_pred, y = ignore_background(y_pred=y_pred, y=y)

-    y = y.float()
-    y_pred = y_pred.float()
-
     if y.shape != y_pred.shape:
         raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")

@@ -165,12 +162,12 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou
     # As for classification tasks, S equals to 1.
     y_pred = y_pred.reshape(batch_size, n_class, -1)
     y = y.reshape(batch_size, n_class, -1)
-    tp = ((y_pred + y) == 2).float()
-    tn = ((y_pred + y) == 0).float()
+    tp = (y_pred + y) == 2
+    tn = (y_pred + y) == 0

-    tp = tp.sum(dim=[2])
-    tn = tn.sum(dim=[2])
-    p = y.sum(dim=[2])
+    tp = tp.sum(dim=[2]).float()
+    tn = tn.sum(dim=[2]).float()
+    p = y.sum(dim=[2]).float()
     n = y.shape[-1] - p

     fn = p - tp
diff --git a/monai/metrics/f_beta_score.py b/monai/metrics/f_beta_score.py
index a5703105a2..61e4525662 100644
--- a/monai/metrics/f_beta_score.py
+++ b/monai/metrics/f_beta_score.py
@@ -63,9 +63,6 @@ def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background:
     if not include_background:
         y_pred, y = ignore_background(y_pred=y_pred, y=y)

-    y = y.float()
-    y_pred = y_pred.float()
-
     if y.shape != y_pred.shape:
         raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")

@@ -75,12 +72,12 @@ def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background:
     # As for classification tasks, S equals to 1.
     y_pred = y_pred.view(batch_size, n_class, -1)
     y = y.view(batch_size, n_class, -1)
-    tp = ((y_pred + y) == 2).float()
-    tn = ((y_pred + y) == 0).float()
+    tp = (y_pred + y) == 2
+    tn = (y_pred + y) == 0

-    tp = tp.sum(dim=[2])
-    tn = tn.sum(dim=[2])
-    p = y.sum(dim=[2])
+    tp = tp.sum(dim=[2]).float()
+    tn = tn.sum(dim=[2]).float()
+    p = y.sum(dim=[2]).float()
     n = y.shape[-1] - p

     fn = p - tp
diff --git a/monai/metrics/meaniou.py b/monai/metrics/meaniou.py
index 7d1ae49f25..65c53f7aa5 100644
--- a/monai/metrics/meaniou.py
+++ b/monai/metrics/meaniou.py
@@ -130,9 +130,6 @@ def compute_iou(
     if not include_background:
         y_pred, y = ignore_background(y_pred=y_pred, y=y)

-    y = y.float()
-    y_pred = y_pred.float()
-
     if y.shape != y_pred.shape:
         raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")
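Across confusion_matrix.py, f_beta_score.py, and meaniou.py the change follows one pattern: keep the element-wise results boolean and cast to float only after the reduction. The reason is that float32 represents consecutive integers exactly only up to 2**24 (16,777,216), so accumulating per-voxel float32 ones over a larger volume can silently drop counts, whereas summing a bool tensor accumulates in int64 and stays exact. A minimal standalone sketch, not part of the patch, using the 1x1x1024x1024x44 shape that mirrors the new test case below:

```python
import torch

# All-zero prediction and label: every one of the
# 1 * 1 * 1024 * 1024 * 44 = 46_137_344 voxels is a true negative.
y_pred = torch.zeros(1, 1, 1024, 1024, 44)
y = torch.zeros_like(y_pred)

y_pred = y_pred.reshape(1, 1, -1)
y = y.reshape(1, 1, -1)

# New code path: bool comparison, exact int64 accumulation, cast afterwards.
tn_exact = ((y_pred + y) == 0).sum(dim=[2]).float()
print(tn_exact)  # tensor([[46137344.]])

# Old code path: cast to float32 first, then sum. Whether this actually
# loses counts depends on the backend's reduction order, but float32
# cannot represent every integer above 2**24, so the result is not
# guaranteed to be exact.
tn_float = ((y_pred + y) == 0).float().sum(dim=[2])
print(tn_float)
```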
diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py
index f37230f09e..9d29654ee3 100644
--- a/monai/metrics/regression.py
+++ b/monai/metrics/regression.py
@@ -111,9 +111,6 @@ def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_
         self.sq_func = partial(torch.pow, exponent=2.0)

     def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        y_pred = y_pred.float()
-        y = y.float()
-
         return compute_mean_error_metrics(y_pred, y, func=self.sq_func)


@@ -143,9 +140,6 @@ def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_
         self.abs_func = torch.abs

     def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        y_pred = y_pred.float()
-        y = y.float()
-
         return compute_mean_error_metrics(y_pred, y, func=self.abs_func)


@@ -176,9 +170,6 @@ def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_
         self.sq_func = partial(torch.pow, exponent=2.0)

     def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        y_pred = y_pred.float()
-        y = y.float()
-
         mse_out = compute_mean_error_metrics(y_pred, y, func=self.sq_func)
         return torch.sqrt(mse_out)

@@ -218,9 +209,6 @@ def __init__(
         self.sq_func = partial(torch.pow, exponent=2.0)

     def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> Any:
-        y_pred = y_pred.float()
-        y = y.float()
-
         mse_out = compute_mean_error_metrics(y_pred, y, func=self.sq_func)
         return 20 * math.log10(self.max_val) - 10 * torch.log10(mse_out)

diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py
index f8c402a756..635eb1bc24 100644
--- a/monai/metrics/surface_dice.py
+++ b/monai/metrics/surface_dice.py
@@ -228,9 +228,6 @@ def compute_surface_dice(
             f"y_pred and y should have same shape, but instead, shapes are {y_pred.shape} (y_pred) and {y.shape} (y)."
         )

-    y = y.float()
-    y_pred = y_pred.float()
-
     batch_size, n_class = y_pred.shape[:2]

     if n_class != len(class_thresholds):
diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py
index fe145b0f50..c139fc35ed 100644
--- a/monai/metrics/utils.py
+++ b/monai/metrics/utils.py
@@ -95,37 +95,37 @@ def do_metric_reduction(
     # some elements might be Nan (if ground truth y was missing (zeros))
     # we need to account for it
     nans = torch.isnan(f)
-    not_nans = (~nans).float()
+    not_nans = ~nans

-    t_zero = torch.zeros(1, device=f.device, dtype=f.dtype)
+    t_zero = torch.zeros(1, device=f.device, dtype=torch.float)
     reduction = look_up_option(reduction, MetricReduction)
     if reduction == MetricReduction.NONE:
-        return f, not_nans
+        return f, not_nans.float()

     f[nans] = 0
     if reduction == MetricReduction.MEAN:
         # 2 steps, first, mean by channel (accounting for nans), then by batch
-        not_nans = not_nans.sum(dim=1)
-        f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero)  # channel average
+        not_nans = not_nans.sum(dim=1).float()
+        f = torch.where(not_nans > 0, f.sum(dim=1).float() / not_nans, t_zero)  # channel average

-        not_nans = (not_nans > 0).float().sum(dim=0)
-        f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero)  # batch average
+        not_nans = (not_nans > 0).sum(dim=0).float()
+        f = torch.where(not_nans > 0, f.sum(dim=0).float() / not_nans, t_zero)  # batch average

     elif reduction == MetricReduction.SUM:
-        not_nans = not_nans.sum(dim=[0, 1])
+        not_nans = not_nans.sum(dim=[0, 1]).float()
         f = torch.sum(f, dim=[0, 1])  # sum over the batch and channel dims
     elif reduction == MetricReduction.MEAN_BATCH:
-        not_nans = not_nans.sum(dim=0)
-        f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero)  # batch average
+        not_nans = not_nans.sum(dim=0).float()
+        f = torch.where(not_nans > 0, f.sum(dim=0).float() / not_nans, t_zero)  # batch average
     elif reduction == MetricReduction.SUM_BATCH:
-        not_nans = not_nans.sum(dim=0)
-        f = f.sum(dim=0)  # the batch sum
+        not_nans = not_nans.sum(dim=0).float()
+        f = f.sum(dim=0).float()  # the batch sum
     elif reduction == MetricReduction.MEAN_CHANNEL:
-        not_nans = not_nans.sum(dim=1)
-        f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero)  # channel average
+        not_nans = not_nans.sum(dim=1).float()
+        f = torch.where(not_nans > 0, f.sum(dim=1).float() / not_nans, t_zero)  # channel average
     elif reduction == MetricReduction.SUM_CHANNEL:
-        not_nans = not_nans.sum(dim=1)
-        f = f.sum(dim=1)  # the channel sum
+        not_nans = not_nans.sum(dim=1).float()
+        f = f.sum(dim=1).float()  # the channel sum
     elif reduction != MetricReduction.NONE:
         raise ValueError(
             f"Unsupported reduction: {reduction}, available options are "
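do_metric_reduction applies the same idea: the NaN mask stays boolean so that the valid-entry counts accumulate exactly in int64 before a single cast to float, and t_zero is pinned to torch.float because the reduced outputs are now explicitly float (the incoming f need no longer be, once the callers above stop up-casting). A small self-contained sketch of the MEAN_BATCH branch under those assumptions:

```python
import torch

# f: per-(batch, channel) metric values; one entry is NaN because the
# corresponding ground truth was empty.
f = torch.tensor([[0.5, float("nan")], [0.7, 0.9]])

nans = torch.isnan(f)
not_nans = ~nans  # stays boolean; no early cast
f = f.clone()
f[nans] = 0

# MEAN_BATCH branch, as revised: the valid count is accumulated exactly
# by the boolean sum (int64) and only then cast to float.
t_zero = torch.zeros(1, device=f.device, dtype=torch.float)
not_nans = not_nans.sum(dim=0).float()
f = torch.where(not_nans > 0, f.sum(dim=0).float() / not_nans, t_zero)

print(f)         # tensor([0.6000, 0.9000]) -- the NaN entry is excluded
print(not_nans)  # tensor([2., 1.])
```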
diff --git a/tests/test_compute_confusion_matrix.py b/tests/test_compute_confusion_matrix.py
index a886d8b7e4..e0a92aec67 100644
--- a/tests/test_compute_confusion_matrix.py
+++ b/tests/test_compute_confusion_matrix.py
@@ -210,6 +210,14 @@

 TEST_CASES_CLF = [data_clf.copy(), result_clf]

+TEST_CASE_PRECISION = [
+    {
+        "y_pred": torch.zeros([1, 1, 1024, 1024, 44], device=_device),
+        "y": torch.zeros([1, 1, 1024, 1024, 44], device=_device),
+    },
+    torch.tensor([[[0.0, 0.0, 46137344.0, 0.0]]]),
+]
+

 class TestConfusionMatrix(unittest.TestCase):
     @parameterized.expand([TEST_CASE_CONFUSION_MATRIX])
@@ -274,6 +282,13 @@ def test_clf_with_nan(self, input_data, expected_value):
         expected_value = compute_confusion_matrix_metric("tpr", expected_value)
         assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)

+    @parameterized.expand([TEST_CASE_PRECISION])
+    def test_precision(self, input_data, expected_value):
+        # all-zero prediction and label: every voxel is a true negative
+        result = get_confusion_matrix(**input_data)
+        assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)
+        np.testing.assert_equal(result.device, input_data["y_pred"].device)
+

 if __name__ == "__main__":
     unittest.main()
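A quick sanity check of the regression test's expected value: get_confusion_matrix stacks the per-class results as [tp, fp, tn, fn] along the last dimension, and with an all-zero prediction and label every voxel is a true negative, so the only nonzero entry is tn = 1024 * 1024 * 44 = 46,137,344, deliberately beyond float32's exact-integer range. A sketch, assuming a MONAI install with this patch applied:

```python
import torch

from monai.metrics import get_confusion_matrix

n_voxels = 1024 * 1024 * 44
print(n_voxels, n_voxels > 2**24)  # 46137344 True -> exercises the precision fix

y = torch.zeros(1, 1, 1024, 1024, 44)
result = get_confusion_matrix(y_pred=y.clone(), y=y)
print(result)  # tensor([[[0., 0., 46137344., 0.]]]), i.e. [tp, fp, tn, fn]
```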