From 1de4b5f1205c033f315042b8826a2f45c3d8475b Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 2 Nov 2023 17:51:30 +0800 Subject: [PATCH 1/8] fix #7186 Signed-off-by: KumoLiu --- monai/losses/dice.py | 6 +++--- monai/losses/focal_loss.py | 2 -- monai/metrics/confusion_matrix.py | 7 ++----- monai/metrics/f_beta_score.py | 7 ++----- monai/metrics/meaniou.py | 3 --- monai/metrics/regression.py | 12 ------------ monai/metrics/surface_dice.py | 3 --- monai/metrics/utils.py | 4 ++-- 8 files changed, 9 insertions(+), 35 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index b3c0f57c6e..f29865fa8e 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -377,7 +377,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: denominator = ground_o + pred_o - w = self.w_func(ground_o.float()) + w = self.w_func(ground_o.int()) infs = torch.isinf(w) if self.batch: w[infs] = 0.0 @@ -623,11 +623,11 @@ def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) - Args: flat_target: the target tensor. """ - alpha: torch.Tensor = torch.ones((flat_target.size(0), self.num_classes)).float().to(flat_target.device) + alpha: torch.Tensor = torch.ones((flat_target.size(0), self.num_classes)).int().to(flat_target.device) if self.alpha_mode == "GDL": # GDL style # Define alpha like in the generalized dice loss # i.e. the inverse of the volume of each class. - one_hot_f = F.one_hot(flat_target, num_classes=self.num_classes).permute(0, 2, 1).float() + one_hot_f = F.one_hot(flat_target, num_classes=self.num_classes).permute(0, 2, 1).int() volumes = torch.sum(one_hot_f, dim=2) alpha = 1.0 / (volumes + 1.0) else: # default, i.e. like in the original paper diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index 98c1a071b6..778c2fb27b 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -154,8 +154,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") loss: Optional[torch.Tensor] = None - input = input.float() - target = target.float() if self.use_softmax: if not self.include_background and self.alpha is not None: self.alpha = None diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index 9083b7733f..2a883348e1 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -153,9 +153,6 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) - y = y.float() - y_pred = y_pred.float() - if y.shape != y_pred.shape: raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") @@ -165,8 +162,8 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou # As for classification tasks, S equals to 1. y_pred = y_pred.reshape(batch_size, n_class, -1) y = y.reshape(batch_size, n_class, -1) - tp = ((y_pred + y) == 2).float() - tn = ((y_pred + y) == 0).float() + tp = ((y_pred + y) == 2).int() + tn = ((y_pred + y) == 0).int() tp = tp.sum(dim=[2]) tn = tn.sum(dim=[2]) diff --git a/monai/metrics/f_beta_score.py b/monai/metrics/f_beta_score.py index a5703105a2..bc35ca9534 100644 --- a/monai/metrics/f_beta_score.py +++ b/monai/metrics/f_beta_score.py @@ -63,9 +63,6 @@ def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background: if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) - y = y.float() - y_pred = y_pred.float() - if y.shape != y_pred.shape: raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") @@ -75,8 +72,8 @@ def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background: # As for classification tasks, S equals to 1. y_pred = y_pred.view(batch_size, n_class, -1) y = y.view(batch_size, n_class, -1) - tp = ((y_pred + y) == 2).float() - tn = ((y_pred + y) == 0).float() + tp = ((y_pred + y) == 2).int() + tn = ((y_pred + y) == 0).int() tp = tp.sum(dim=[2]) tn = tn.sum(dim=[2]) diff --git a/monai/metrics/meaniou.py b/monai/metrics/meaniou.py index 7d1ae49f25..65c53f7aa5 100644 --- a/monai/metrics/meaniou.py +++ b/monai/metrics/meaniou.py @@ -130,9 +130,6 @@ def compute_iou( if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) - y = y.float() - y_pred = y_pred.float() - if y.shape != y_pred.shape: raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index f37230f09e..9d29654ee3 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -111,9 +111,6 @@ def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_ self.sq_func = partial(torch.pow, exponent=2.0) def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - y_pred = y_pred.float() - y = y.float() - return compute_mean_error_metrics(y_pred, y, func=self.sq_func) @@ -143,9 +140,6 @@ def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_ self.abs_func = torch.abs def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - y_pred = y_pred.float() - y = y.float() - return compute_mean_error_metrics(y_pred, y, func=self.abs_func) @@ -176,9 +170,6 @@ def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_ self.sq_func = partial(torch.pow, exponent=2.0) def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - y_pred = y_pred.float() - y = y.float() - mse_out = compute_mean_error_metrics(y_pred, y, func=self.sq_func) return torch.sqrt(mse_out) @@ -218,9 +209,6 @@ def __init__( self.sq_func = partial(torch.pow, exponent=2.0) def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> Any: - y_pred = y_pred.float() - y = y.float() - mse_out = compute_mean_error_metrics(y_pred, y, func=self.sq_func) return 20 * math.log10(self.max_val) - 10 * torch.log10(mse_out) diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index f8c402a756..635eb1bc24 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -228,9 +228,6 @@ def compute_surface_dice( f"y_pred and y should have same shape, but instead, shapes are {y_pred.shape} (y_pred) and {y.shape} (y)." ) - y = y.float() - y_pred = y_pred.float() - batch_size, n_class = y_pred.shape[:2] if n_class != len(class_thresholds): diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index fe145b0f50..d4a0ec441d 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -95,7 +95,7 @@ def do_metric_reduction( # some elements might be Nan (if ground truth y was missing (zeros)) # we need to account for it nans = torch.isnan(f) - not_nans = (~nans).float() + not_nans = (~nans).int() t_zero = torch.zeros(1, device=f.device, dtype=f.dtype) reduction = look_up_option(reduction, MetricReduction) @@ -108,7 +108,7 @@ def do_metric_reduction( not_nans = not_nans.sum(dim=1) f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero) # channel average - not_nans = (not_nans > 0).float().sum(dim=0) + not_nans = (not_nans > 0).int().sum(dim=0) f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average elif reduction == MetricReduction.SUM: From 21157f60c5c3817d46285644924645761a9156ca Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 2 Nov 2023 18:04:26 +0800 Subject: [PATCH 2/8] add unit test Signed-off-by: KumoLiu --- tests/test_compute_confusion_matrix.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_compute_confusion_matrix.py b/tests/test_compute_confusion_matrix.py index a886d8b7e4..2a8ae37745 100644 --- a/tests/test_compute_confusion_matrix.py +++ b/tests/test_compute_confusion_matrix.py @@ -210,6 +210,10 @@ TEST_CASES_CLF = [data_clf.copy(), result_clf] +TEST_CASE_PRECISION = [ + {"y_pred": torch.zeros([1, 1, 1024, 1024, 44], device=_device), "y": torch.zeros([1, 1, 1024, 1024, 44], device=_device)}, + torch.tensor([[[0.0, 0.0, 46137344.0, 0.0]]]) +] class TestConfusionMatrix(unittest.TestCase): @parameterized.expand([TEST_CASE_CONFUSION_MATRIX]) @@ -274,6 +278,13 @@ def test_clf_with_nan(self, input_data, expected_value): expected_value = compute_confusion_matrix_metric("tpr", expected_value) assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4) + @parameterized.expand([TEST_CASE_PRECISION]) + def test_precision(self, input_data, expected_value): + # include or ignore background + result = get_confusion_matrix(**input_data) + assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4) + np.testing.assert_equal(result.device, input_data["y_pred"].device) + if __name__ == "__main__": unittest.main() From 3c650d0841d360e5f471c255bdf53644908385da Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 2 Nov 2023 18:05:27 +0800 Subject: [PATCH 3/8] fix flake8 Signed-off-by: KumoLiu --- tests/test_compute_confusion_matrix.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_compute_confusion_matrix.py b/tests/test_compute_confusion_matrix.py index 2a8ae37745..e0a92aec67 100644 --- a/tests/test_compute_confusion_matrix.py +++ b/tests/test_compute_confusion_matrix.py @@ -211,10 +211,14 @@ TEST_CASES_CLF = [data_clf.copy(), result_clf] TEST_CASE_PRECISION = [ - {"y_pred": torch.zeros([1, 1, 1024, 1024, 44], device=_device), "y": torch.zeros([1, 1, 1024, 1024, 44], device=_device)}, - torch.tensor([[[0.0, 0.0, 46137344.0, 0.0]]]) + { + "y_pred": torch.zeros([1, 1, 1024, 1024, 44], device=_device), + "y": torch.zeros([1, 1, 1024, 1024, 44], device=_device), + }, + torch.tensor([[[0.0, 0.0, 46137344.0, 0.0]]]), ] + class TestConfusionMatrix(unittest.TestCase): @parameterized.expand([TEST_CASE_CONFUSION_MATRIX]) def test_value(self, input_data, expected_value): From 166a62864aa210d8e7819e152d01a00e93a20862 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 2 Nov 2023 18:32:35 +0800 Subject: [PATCH 4/8] revert loss Signed-off-by: KumoLiu --- monai/losses/dice.py | 6 +++--- monai/losses/focal_loss.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index f29865fa8e..b3c0f57c6e 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -377,7 +377,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: denominator = ground_o + pred_o - w = self.w_func(ground_o.int()) + w = self.w_func(ground_o.float()) infs = torch.isinf(w) if self.batch: w[infs] = 0.0 @@ -623,11 +623,11 @@ def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) - Args: flat_target: the target tensor. """ - alpha: torch.Tensor = torch.ones((flat_target.size(0), self.num_classes)).int().to(flat_target.device) + alpha: torch.Tensor = torch.ones((flat_target.size(0), self.num_classes)).float().to(flat_target.device) if self.alpha_mode == "GDL": # GDL style # Define alpha like in the generalized dice loss # i.e. the inverse of the volume of each class. - one_hot_f = F.one_hot(flat_target, num_classes=self.num_classes).permute(0, 2, 1).int() + one_hot_f = F.one_hot(flat_target, num_classes=self.num_classes).permute(0, 2, 1).float() volumes = torch.sum(one_hot_f, dim=2) alpha = 1.0 / (volumes + 1.0) else: # default, i.e. like in the original paper diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index 778c2fb27b..98c1a071b6 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -154,6 +154,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") loss: Optional[torch.Tensor] = None + input = input.float() + target = target.float() if self.use_softmax: if not self.include_background and self.alpha is not None: self.alpha = None From 4f51f20ffdf13d80b0f3e5a3e17aad0b99dabcde Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 2 Nov 2023 18:40:44 +0800 Subject: [PATCH 5/8] remove int Signed-off-by: KumoLiu --- monai/metrics/confusion_matrix.py | 4 ++-- monai/metrics/f_beta_score.py | 4 ++-- monai/metrics/utils.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index 2a883348e1..0189b33955 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -162,8 +162,8 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou # As for classification tasks, S equals to 1. y_pred = y_pred.reshape(batch_size, n_class, -1) y = y.reshape(batch_size, n_class, -1) - tp = ((y_pred + y) == 2).int() - tn = ((y_pred + y) == 0).int() + tp = ((y_pred + y) == 2) + tn = ((y_pred + y) == 0) tp = tp.sum(dim=[2]) tn = tn.sum(dim=[2]) diff --git a/monai/metrics/f_beta_score.py b/monai/metrics/f_beta_score.py index bc35ca9534..a92fa09a10 100644 --- a/monai/metrics/f_beta_score.py +++ b/monai/metrics/f_beta_score.py @@ -72,8 +72,8 @@ def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background: # As for classification tasks, S equals to 1. y_pred = y_pred.view(batch_size, n_class, -1) y = y.view(batch_size, n_class, -1) - tp = ((y_pred + y) == 2).int() - tn = ((y_pred + y) == 0).int() + tp = ((y_pred + y) == 2) + tn = ((y_pred + y) == 0) tp = tp.sum(dim=[2]) tn = tn.sum(dim=[2]) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index d4a0ec441d..eed230b3bf 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -95,12 +95,12 @@ def do_metric_reduction( # some elements might be Nan (if ground truth y was missing (zeros)) # we need to account for it nans = torch.isnan(f) - not_nans = (~nans).int() + not_nans = (~nans) t_zero = torch.zeros(1, device=f.device, dtype=f.dtype) reduction = look_up_option(reduction, MetricReduction) if reduction == MetricReduction.NONE: - return f, not_nans + return f, not_nans.float() f[nans] = 0 if reduction == MetricReduction.MEAN: @@ -108,7 +108,7 @@ def do_metric_reduction( not_nans = not_nans.sum(dim=1) f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero) # channel average - not_nans = (not_nans > 0).int().sum(dim=0) + not_nans = (not_nans > 0).sum(dim=0) f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average elif reduction == MetricReduction.SUM: From bbc0f08b6ff89663cb759d6f55d9f50f724e7815 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 2 Nov 2023 18:58:21 +0800 Subject: [PATCH 6/8] bc fix Signed-off-by: KumoLiu --- monai/metrics/confusion_matrix.py | 6 +++--- monai/metrics/f_beta_score.py | 6 +++--- monai/metrics/utils.py | 26 +++++++++++++------------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index 0189b33955..f5da1b6b14 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -165,9 +165,9 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou tp = ((y_pred + y) == 2) tn = ((y_pred + y) == 0) - tp = tp.sum(dim=[2]) - tn = tn.sum(dim=[2]) - p = y.sum(dim=[2]) + tp = tp.sum(dim=[2]).float() + tn = tn.sum(dim=[2]).float() + p = y.sum(dim=[2]).float() n = y.shape[-1] - p fn = p - tp diff --git a/monai/metrics/f_beta_score.py b/monai/metrics/f_beta_score.py index a92fa09a10..79af7a5acd 100644 --- a/monai/metrics/f_beta_score.py +++ b/monai/metrics/f_beta_score.py @@ -75,9 +75,9 @@ def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background: tp = ((y_pred + y) == 2) tn = ((y_pred + y) == 0) - tp = tp.sum(dim=[2]) - tn = tn.sum(dim=[2]) - p = y.sum(dim=[2]) + tp = tp.sum(dim=[2]).float() + tn = tn.sum(dim=[2]).float() + p = y.sum(dim=[2]).float() n = y.shape[-1] - p fn = p - tp diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index eed230b3bf..77e6510ab6 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -105,27 +105,27 @@ def do_metric_reduction( f[nans] = 0 if reduction == MetricReduction.MEAN: # 2 steps, first, mean by channel (accounting for nans), then by batch - not_nans = not_nans.sum(dim=1) - f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero) # channel average + not_nans = not_nans.sum(dim=1).float() + f = torch.where(not_nans > 0, f.sum(dim=1).float() / not_nans, t_zero) # channel average - not_nans = (not_nans > 0).sum(dim=0) - f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average + not_nans = (not_nans > 0).sum(dim=0).float() + f = torch.where(not_nans > 0, f.sum(dim=0).float() / not_nans, t_zero) # batch average elif reduction == MetricReduction.SUM: - not_nans = not_nans.sum(dim=[0, 1]) + not_nans = not_nans.sum(dim=[0, 1]).float() f = torch.sum(f, dim=[0, 1]) # sum over the batch and channel dims elif reduction == MetricReduction.MEAN_BATCH: - not_nans = not_nans.sum(dim=0) - f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average + not_nans = not_nans.sum(dim=0).float() + f = torch.where(not_nans > 0, f.sum(dim=0).float() / not_nans, t_zero) # batch average elif reduction == MetricReduction.SUM_BATCH: - not_nans = not_nans.sum(dim=0) - f = f.sum(dim=0) # the batch sum + not_nans = not_nans.sum(dim=0).float() + f = f.sum(dim=0).float() # the batch sum elif reduction == MetricReduction.MEAN_CHANNEL: - not_nans = not_nans.sum(dim=1) - f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero) # channel average + not_nans = not_nans.sum(dim=1).float() + f = torch.where(not_nans > 0, f.sum(dim=1).float() / not_nans, t_zero) # channel average elif reduction == MetricReduction.SUM_CHANNEL: - not_nans = not_nans.sum(dim=1) - f = f.sum(dim=1) # the channel sum + not_nans = not_nans.sum(dim=1).float() + f = f.sum(dim=1).float() # the channel sum elif reduction != MetricReduction.NONE: raise ValueError( f"Unsupported reduction: {reduction}, available options are " From a45ce671df65f7ab60a8c90ab2b9cd08df59a7c7 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 2 Nov 2023 19:07:03 +0800 Subject: [PATCH 7/8] fix flake8 Signed-off-by: KumoLiu --- monai/metrics/confusion_matrix.py | 4 ++-- monai/metrics/f_beta_score.py | 4 ++-- monai/metrics/utils.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index f5da1b6b14..26ec823081 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -162,8 +162,8 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou # As for classification tasks, S equals to 1. y_pred = y_pred.reshape(batch_size, n_class, -1) y = y.reshape(batch_size, n_class, -1) - tp = ((y_pred + y) == 2) - tn = ((y_pred + y) == 0) + tp = (y_pred + y) == 2 + tn = (y_pred + y) == 0 tp = tp.sum(dim=[2]).float() tn = tn.sum(dim=[2]).float() diff --git a/monai/metrics/f_beta_score.py b/monai/metrics/f_beta_score.py index 79af7a5acd..61e4525662 100644 --- a/monai/metrics/f_beta_score.py +++ b/monai/metrics/f_beta_score.py @@ -72,8 +72,8 @@ def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background: # As for classification tasks, S equals to 1. y_pred = y_pred.view(batch_size, n_class, -1) y = y.view(batch_size, n_class, -1) - tp = ((y_pred + y) == 2) - tn = ((y_pred + y) == 0) + tp = (y_pred + y) == 2 + tn = (y_pred + y) == 0 tp = tp.sum(dim=[2]).float() tn = tn.sum(dim=[2]).float() diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 77e6510ab6..774ebad47f 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -95,7 +95,7 @@ def do_metric_reduction( # some elements might be Nan (if ground truth y was missing (zeros)) # we need to account for it nans = torch.isnan(f) - not_nans = (~nans) + not_nans = ~nans t_zero = torch.zeros(1, device=f.device, dtype=f.dtype) reduction = look_up_option(reduction, MetricReduction) From 8bd7518487668de9b073e94d72323bb3893ae92f Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 2 Nov 2023 21:13:09 +0800 Subject: [PATCH 8/8] fix ci Signed-off-by: KumoLiu --- monai/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 774ebad47f..c139fc35ed 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -97,7 +97,7 @@ def do_metric_reduction( nans = torch.isnan(f) not_nans = ~nans - t_zero = torch.zeros(1, device=f.device, dtype=f.dtype) + t_zero = torch.zeros(1, device=f.device, dtype=torch.float) reduction = look_up_option(reduction, MetricReduction) if reduction == MetricReduction.NONE: return f, not_nans.float()