From 8f110a0fcd52e4e8173b62b10d81065250ed9a7f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 22 Jan 2021 11:45:43 +0800 Subject: [PATCH 1/7] [DLMED] add IterationHandler refer to the EpochHandler in ignite Signed-off-by: Nic Ma --- docs/source/handlers.rst | 6 ++ monai/handlers/__init__.py | 1 + monai/handlers/confusion_matrix.py | 82 +++----------------- monai/handlers/hausdorff_distance.py | 55 ++----------- monai/handlers/iteration_metric.py | 98 ++++++++++++++++++++++++ monai/handlers/mean_dice.py | 56 ++------------ monai/handlers/surface_distance.py | 56 ++------------ monai/metrics/utils.py | 6 +- tests/test_handler_confusion_matrix.py | 13 ++-- tests/test_handler_hausdorff_distance.py | 3 +- tests/test_handler_surface_distance.py | 3 +- 11 files changed, 146 insertions(+), 233 deletions(-) create mode 100644 monai/handlers/iteration_metric.py diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 2962f725d8..d1ce257cb7 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -22,6 +22,12 @@ CSV saver :members: +Iteration Metric +---------------- +.. autoclass:: IterationMetric + :members: + + Mean Dice metrics handler ------------------------- .. autoclass:: MeanDice diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 1df516eaf0..a873cd8b15 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -14,6 +14,7 @@ from .classification_saver import ClassificationSaver from .confusion_matrix import ConfusionMatrix from .hausdorff_distance import HausdorffDistance +from .iteration_metric import IterationMetric from .lr_schedule_handler import LrScheduleHandler from .mean_dice import MeanDice from .metric_logger import MetricLogger diff --git a/monai/handlers/confusion_matrix.py b/monai/handlers/confusion_matrix.py index fe60b964a7..7ca10fa91a 100644 --- a/monai/handlers/confusion_matrix.py +++ b/monai/handlers/confusion_matrix.py @@ -9,21 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Sequence +from typing import Callable, Optional import torch from monai.metrics import ConfusionMatrixMetric, compute_confusion_matrix_metric from monai.metrics.utils import MetricReduction, do_metric_reduction -from monai.utils import exact_version, optional_import +from monai.handlers.iteration_metric import IterationMetric -NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") -Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") -reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") -sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce") - -class ConfusionMatrix(Metric): # type: ignore[valid-type, misc] # due to optional_import +class ConfusionMatrix(IterationMetric): """ Compute confusion matrix related metrics from full size Tensor and collects average over batch, class-channels, iterations. """ @@ -32,7 +27,6 @@ def __init__( self, include_background: bool = True, metric_name: str = "hit_rate", - compute_sample: bool = False, output_transform: Callable = lambda x: x, device: Optional[torch.device] = None, ) -> None: @@ -48,79 +42,21 @@ def __init__( ``"informedness"``, ``"markedness"``] Some of the metrics have multiple aliases (as shown in the wikipedia page aforementioned), and you can also input those names instead. 
- compute_sample: if ``True``, each sample's metric will be computed first. - If ``False``, the confusion matrix for all samples will be accumulated first. Defaults to ``False``. output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. device: device specification in case of distributed computation usage. See also: :py:meth:`monai.metrics.confusion_matrix` """ - super().__init__(output_transform, device=device) - self.confusion_matrix = ConfusionMatrixMetric( + metric_fn = ConfusionMatrixMetric( include_background=include_background, metric_name=metric_name, - compute_sample=compute_sample, - reduction=MetricReduction.MEAN, + compute_sample=False, + reduction=MetricReduction.NONE, ) - self._sum = 0.0 - self._num_examples = 0 - self.compute_sample = compute_sample self.metric_name = metric_name - self._total_tp = 0.0 - self._total_fp = 0.0 - self._total_tn = 0.0 - self._total_fn = 0.0 - - @reinit__is_reduced - def reset(self) -> None: - self._sum = 0.0 - self._num_examples = 0 - self._total_tp = 0.0 - self._total_fp = 0.0 - self._total_tn = 0.0 - self._total_fn = 0.0 - - @reinit__is_reduced - def update(self, output: Sequence[torch.Tensor]) -> None: - """ - Args: - output: sequence with contents [y_pred, y]. - - Raises: - ValueError: When ``output`` length is not 2. This metric can only support y_pred and y. + super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device) - """ - if len(output) != 2: - raise ValueError(f"output must have length 2, got {len(output)}.") - y_pred, y = output - if self.compute_sample is True: - score, not_nans = self.confusion_matrix(y_pred, y) - not_nans = int(not_nans.item()) - - # add all items in current batch - self._sum += score.item() * not_nans - self._num_examples += not_nans - else: - confusion_matrix = self.confusion_matrix(y_pred, y) - confusion_matrix, _ = do_metric_reduction(confusion_matrix, MetricReduction.SUM) - self._total_tp += confusion_matrix[0].item() - self._total_fp += confusion_matrix[1].item() - self._total_tn += confusion_matrix[2].item() - self._total_fn += confusion_matrix[3].item() - - @sync_all_reduce("_sum", "_num_examples", "_total_tp", "_total_fp", "_total_tn", "_total_fn") - def compute(self): - """ - Raises: - NotComputableError: When ``compute`` is called before an ``update`` occurs. - - """ - if self.compute_sample is True: - if self._num_examples == 0: - raise NotComputableError( - "ConfusionMatrix metric must have at least one example before it can be computed." - ) - return self._sum / self._num_examples - confusion_matrix = torch.tensor([self._total_tp, self._total_fp, self._total_tn, self._total_fn]) + def _reduce(self, scores) -> torch.Tensor: + confusion_matrix, _ = do_metric_reduction(scores, MetricReduction.MEAN) return compute_confusion_matrix_metric(self.metric_name, confusion_matrix) diff --git a/monai/handlers/hausdorff_distance.py b/monai/handlers/hausdorff_distance.py index 581550a703..f87ba8f3ea 100644 --- a/monai/handlers/hausdorff_distance.py +++ b/monai/handlers/hausdorff_distance.py @@ -9,20 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Optional, Sequence +from typing import Callable, Optional import torch from monai.metrics import HausdorffDistanceMetric -from monai.utils import MetricReduction, exact_version, optional_import +from monai.utils import MetricReduction +from monai.handlers.iteration_metric import IterationMetric -NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") -Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") -reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") -sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce") - -class HausdorffDistance(Metric): # type: ignore[valid-type, misc] # due to optional_import +class HausdorffDistance(IterationMetric): """ Computes Hausdorff distance from full size Tensor and collects average over batch, class-channels, iterations. """ @@ -52,48 +48,11 @@ def __init__( """ super().__init__(output_transform, device=device) - self.hd = HausdorffDistanceMetric( + metric_fn = HausdorffDistanceMetric( include_background=include_background, distance_metric=distance_metric, percentile=percentile, directed=directed, - reduction=MetricReduction.MEAN, + reduction=MetricReduction.NONE, ) - self._sum = 0.0 - self._num_examples = 0 - - @reinit__is_reduced - def reset(self) -> None: - self._sum = 0.0 - self._num_examples = 0 - - @reinit__is_reduced - def update(self, output: Sequence[torch.Tensor]) -> None: - """ - Args: - output: sequence with contents [y_pred, y]. - - Raises: - ValueError: When ``output`` length is not 2. The metric can only support y_pred and y. - - """ - if len(output) != 2: - raise ValueError(f"output must have length 2, got {len(output)}.") - y_pred, y = output - score, not_nans = self.hd(y_pred, y) - not_nans = int(not_nans.item()) - - # add all items in current batch - self._sum += score.item() * not_nans - self._num_examples += not_nans - - @sync_all_reduce("_sum", "_num_examples") - def compute(self) -> float: - """ - Raises: - NotComputableError: When ``compute`` is called before an ``update`` occurs. - - """ - if self._num_examples == 0: - raise NotComputableError("HausdorffDistance must have at least one example before it can be computed.") - return self._sum / self._num_examples + super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device) diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py new file mode 100644 index 0000000000..d44375d04e --- /dev/null +++ b/monai/handlers/iteration_metric.py @@ -0,0 +1,98 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Callable, Optional, Sequence + +import torch + +from monai.utils import MetricReduction, exact_version, optional_import +from monai.metrics import do_metric_reduction + +NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") +idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") +Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") +reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") + + +class IterationMetric(Metric): # type: ignore[valid-type, misc] # due to optional_import + """ + Class for metrics that should be computed on every iteration and compute final results when epoch completed. + Similar to the `EpochMetric` in ignite: + https://github.com/pytorch/ignite/blob/v0.4.2/ignite/metrics/epoch_metric.py#L13. + + Args: + metric_fn: callable function or class to compute raw metric results after every iteration. + expect to return a Tensor with shape (batch, channel, ...) or tuple (Tensor, not_nans). + output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. + device: device specification in case of distributed computation usage. + + """ + + def __init__( + self, + metric_fn: Callable, + output_transform: Callable = lambda x: x, + device: Optional[torch.device] = None, + ) -> None: + super().__init__(output_transform, device=device) + self.metric_fn = metric_fn + self._scores = [] + + @reinit__is_reduced + def reset(self) -> None: + self._scores = [] + + @reinit__is_reduced + def update(self, output: Sequence[torch.Tensor]) -> None: + """ + Args: + output: sequence with contents [y_pred, y]. + + Raises: + ValueError: When ``output`` length is not 2. metric_fn can only support y_pred and y. + + """ + if len(output) != 2: + raise ValueError(f"output must have length 2, got {len(output)}.") + y_pred, y = output + score = self.metric_fn(y_pred, y) + if isinstance(score, (tuple, list)): + score = score[0] + self._scores.append(score) + + def compute(self) -> None: + """ + Raises: + NotComputableError: When ``compute`` is called before an ``update`` occurs. + + """ + _scores = torch.cat(self._scores, dim=0) + + ws = idist.get_world_size() + + if ws > 1 and not self._is_reduced: + # all gather across all processes + _scores = idist.all_gather(_scores) + self._is_reduced = True + + result = 0.0 + if idist.get_rank() == 0: + # run compute_fn on zero rank only + result = self._reduce(_scores) + + if ws > 1: + # broadcast result to all processes + result = idist.broadcast(result, src=0) + + return result + + def _reduce(self, scores) -> torch.Tensor: + return do_metric_reduction(scores, MetricReduction.MEAN)[0].item() diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py index 3c34948604..df22d62f19 100644 --- a/monai/handlers/mean_dice.py +++ b/monai/handlers/mean_dice.py @@ -9,20 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Optional, Sequence +from typing import Callable, Optional import torch from monai.metrics import DiceMetric -from monai.utils import MetricReduction, exact_version, optional_import +from monai.utils import MetricReduction +from monai.handlers.iteration_metric import IterationMetric -NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") -Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") -reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") -sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce") - -class MeanDice(Metric): # type: ignore[valid-type, misc] # due to optional_import +class MeanDice(IterationMetric): """ Computes Dice score metric from full size Tensor and collects average over batch, class-channels, iterations. """ @@ -44,46 +40,8 @@ def __init__( See also: :py:meth:`monai.metrics.meandice.compute_meandice` """ - super().__init__(output_transform, device=device) - self.dice = DiceMetric( + metric_fn = DiceMetric( include_background=include_background, - reduction=MetricReduction.MEAN, + reduction=MetricReduction.NONE, ) - self._sum = 0.0 - self._num_examples = 0 - - @reinit__is_reduced - def reset(self) -> None: - self._sum = 0.0 - self._num_examples = 0 - - @reinit__is_reduced - def update(self, output: Sequence[torch.Tensor]) -> None: - """ - Args: - output: sequence with contents [y_pred, y]. - - Raises: - ValueError: When ``output`` length is not 2. MeanDice metric can only support y_pred and y. - - """ - if len(output) != 2: - raise ValueError(f"output must have length 2, got {len(output)}.") - y_pred, y = output - score, not_nans = self.dice(y_pred, y) - not_nans = int(not_nans.item()) - - # add all items in current batch - self._sum += score.item() * not_nans - self._num_examples += not_nans - - @sync_all_reduce("_sum", "_num_examples") - def compute(self) -> float: - """ - Raises: - NotComputableError: When ``compute`` is called before an ``update`` occurs. - - """ - if self._num_examples == 0: - raise NotComputableError("MeanDice must have at least one example before it can be computed.") - return self._sum / self._num_examples + super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device) diff --git a/monai/handlers/surface_distance.py b/monai/handlers/surface_distance.py index 514cf3e6c7..4e2366c666 100644 --- a/monai/handlers/surface_distance.py +++ b/monai/handlers/surface_distance.py @@ -9,20 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Optional, Sequence +from typing import Callable, Optional import torch from monai.metrics import SurfaceDistanceMetric -from monai.utils import MetricReduction, exact_version, optional_import +from monai.utils import MetricReduction +from monai.handlers.iteration_metric import IterationMetric -NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") -Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") -reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") -sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce") - -class SurfaceDistance(Metric): # type: ignore[valid-type, misc] # due to optional_import +class SurfaceDistance(IterationMetric): """ Computes surface distance from full size Tensor and collects average over batch, class-channels, iterations. """ @@ -48,48 +44,10 @@ def __init__( device: device specification in case of distributed computation usage. """ - super().__init__(output_transform, device=device) - self.hd = SurfaceDistanceMetric( + metric_fn = SurfaceDistanceMetric( include_background=include_background, symmetric=symmetric, distance_metric=distance_metric, - reduction=MetricReduction.MEAN, + reduction=MetricReduction.NONE, ) - self._sum = 0.0 - self._num_examples = 0 - - @reinit__is_reduced - def reset(self) -> None: - self._sum = 0.0 - self._num_examples = 0 - - @reinit__is_reduced - def update(self, output: Sequence[torch.Tensor]) -> None: - """ - Args: - output: sequence with contents [y_pred, y]. - - Raises: - ValueError: When ``output`` length is not 2. The metric can only support y_pred and y. - - """ - if len(output) != 2: - raise ValueError(f"output must have length 2, got {len(output)}.") - y_pred, y = output - score, not_nans = self.hd(y_pred, y) - not_nans = int(not_nans.item()) - - # add all items in current batch - self._sum += score.item() * not_nans - self._num_examples += not_nans - - @sync_all_reduce("_sum", "_num_examples") - def compute(self) -> float: - """ - Raises: - NotComputableError: When ``compute`` is called before an ``update`` occurs. - - """ - if self._num_examples == 0: - raise NotComputableError("SurfaceDistance must have at least one example before it can be computed.") - return self._sum / self._num_examples + super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 68f21f1613..cc7049ff81 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -53,7 +53,7 @@ def do_metric_reduction( f: a tensor that contains the calculated metric scores per batch and per class. The first two dims should be batch and class. reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} + ``"mean_channel"``, ``"sum_channel"``}, if "none", return the input f tensor and not_nans. Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. 
Raises: @@ -65,11 +65,13 @@ def do_metric_reduction( # we need to account for it nans = torch.isnan(f) not_nans = (~nans).float() - f[nans] = 0 t_zero = torch.zeros(1, device=f.device, dtype=f.dtype) reduction = MetricReduction(reduction) + if reduction == MetricReduction.NONE: + return f, not_nans + f[nans] = 0 if reduction == MetricReduction.MEAN: # 2 steps, first, mean by channel (accounting for nans), then by batch not_nans = not_nans.sum(dim=1) diff --git a/tests/test_handler_confusion_matrix.py b/tests/test_handler_confusion_matrix.py index ac5edb72e2..e533245536 100644 --- a/tests/test_handler_confusion_matrix.py +++ b/tests/test_handler_confusion_matrix.py @@ -17,11 +17,10 @@ from monai.handlers import ConfusionMatrix -TEST_CASE_1 = [{"include_background": True, "metric_name": "f1", "compute_sample": False}, 0.75] -TEST_CASE_2 = [{"include_background": False, "metric_name": "ppv", "compute_sample": False}, 1.0] +TEST_CASE_1 = [{"include_background": True, "metric_name": "f1"}, 0.75] +TEST_CASE_2 = [{"include_background": False, "metric_name": "ppv"}, 1.0] -TEST_CASE_SEG_1 = [{"include_background": True, "metric_name": "tpr", "compute_sample": True}, 0.8333] -TEST_CASE_SEG_2 = [{"include_background": True, "metric_name": "tpr", "compute_sample": False}, 0.7] +TEST_CASE_SEG_1 = [{"include_background": True, "metric_name": "tpr"}, 0.7] data_1: Dict[Any, Any] = { "y_pred": torch.tensor( @@ -70,7 +69,7 @@ def test_compute(self, input_params, expected_avg): avg_metric = metric.compute() self.assertAlmostEqual(avg_metric, expected_avg, places=4) - @parameterized.expand([TEST_CASE_SEG_1, TEST_CASE_SEG_2]) + @parameterized.expand([TEST_CASE_SEG_1]) def test_compute_seg(self, input_params, expected_avg): metric = ConfusionMatrix(**input_params) @@ -82,9 +81,7 @@ def test_compute_seg(self, input_params, expected_avg): y = data_2["y"] metric.update([y_pred, y]) - avg_metric = metric.compute() - if input_params["compute_sample"] is False: - avg_metric = avg_metric.item() + avg_metric = metric.compute().item() self.assertAlmostEqual(avg_metric, expected_avg, places=4) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) diff --git a/tests/test_handler_hausdorff_distance.py b/tests/test_handler_hausdorff_distance.py index ee30040cc8..edf59320ea 100644 --- a/tests/test_handler_hausdorff_distance.py +++ b/tests/test_handler_hausdorff_distance.py @@ -71,10 +71,9 @@ def test_compute(self): y_pred, y = TEST_SAMPLE_3 hd_metric.update([y_pred, y]) self.assertEqual(hd_metric.compute(), float("inf")) - self.assertEqual(hd_metric._num_examples, 3) y_pred, y = TEST_SAMPLE_4 hd_metric.update([y_pred, y]) - self.assertEqual(hd_metric._num_examples, 3) + self.assertEqual(hd_metric.compute(), float("inf")) def test_shape_mismatch(self): hd_metric = HausdorffDistance(include_background=True) diff --git a/tests/test_handler_surface_distance.py b/tests/test_handler_surface_distance.py index b4d9584289..656b0d64b2 100644 --- a/tests/test_handler_surface_distance.py +++ b/tests/test_handler_surface_distance.py @@ -71,10 +71,9 @@ def test_compute(self): y_pred, y = TEST_SAMPLE_3 sur_metric.update([y_pred, y]) self.assertAlmostEqual(sur_metric.compute(), float("inf")) - self.assertAlmostEqual(sur_metric._num_examples, 3) y_pred, y = TEST_SAMPLE_4 sur_metric.update([y_pred, y]) - self.assertAlmostEqual(sur_metric._num_examples, 3) + self.assertAlmostEqual(sur_metric.compute(), float("inf")) def test_shape_mismatch(self): sur_metric = SurfaceDistance(include_background=True) From 560802dfd4c8059794961a453fd0467d13ee794f 
Mon Sep 17 00:00:00 2001 From: monai-bot Date: Fri, 22 Jan 2021 10:25:50 +0000 Subject: [PATCH 2/7] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/handlers/confusion_matrix.py | 2 +- monai/handlers/hausdorff_distance.py | 2 +- monai/handlers/iteration_metric.py | 2 +- monai/handlers/mean_dice.py | 2 +- monai/handlers/surface_distance.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/handlers/confusion_matrix.py b/monai/handlers/confusion_matrix.py index 7ca10fa91a..eba75ef957 100644 --- a/monai/handlers/confusion_matrix.py +++ b/monai/handlers/confusion_matrix.py @@ -13,9 +13,9 @@ import torch +from monai.handlers.iteration_metric import IterationMetric from monai.metrics import ConfusionMatrixMetric, compute_confusion_matrix_metric from monai.metrics.utils import MetricReduction, do_metric_reduction -from monai.handlers.iteration_metric import IterationMetric class ConfusionMatrix(IterationMetric): diff --git a/monai/handlers/hausdorff_distance.py b/monai/handlers/hausdorff_distance.py index f87ba8f3ea..3e4a3d70ba 100644 --- a/monai/handlers/hausdorff_distance.py +++ b/monai/handlers/hausdorff_distance.py @@ -13,9 +13,9 @@ import torch +from monai.handlers.iteration_metric import IterationMetric from monai.metrics import HausdorffDistanceMetric from monai.utils import MetricReduction -from monai.handlers.iteration_metric import IterationMetric class HausdorffDistance(IterationMetric): diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index d44375d04e..76586bdb00 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -13,8 +13,8 @@ import torch -from monai.utils import MetricReduction, exact_version, optional_import from monai.metrics import do_metric_reduction +from monai.utils import MetricReduction, exact_version, optional_import NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py index df22d62f19..057acbee97 100644 --- a/monai/handlers/mean_dice.py +++ b/monai/handlers/mean_dice.py @@ -13,9 +13,9 @@ import torch +from monai.handlers.iteration_metric import IterationMetric from monai.metrics import DiceMetric from monai.utils import MetricReduction -from monai.handlers.iteration_metric import IterationMetric class MeanDice(IterationMetric): diff --git a/monai/handlers/surface_distance.py b/monai/handlers/surface_distance.py index 4e2366c666..17b667ab46 100644 --- a/monai/handlers/surface_distance.py +++ b/monai/handlers/surface_distance.py @@ -13,9 +13,9 @@ import torch +from monai.handlers.iteration_metric import IterationMetric from monai.metrics import SurfaceDistanceMetric from monai.utils import MetricReduction -from monai.handlers.iteration_metric import IterationMetric class SurfaceDistance(IterationMetric): From b2c317e95f44bc474fc9db6bd7007fb515632262 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 22 Jan 2021 18:44:44 +0800 Subject: [PATCH 3/7] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- monai/handlers/confusion_matrix.py | 4 ++-- monai/handlers/iteration_metric.py | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/monai/handlers/confusion_matrix.py b/monai/handlers/confusion_matrix.py index eba75ef957..46226f530b 100644 --- a/monai/handlers/confusion_matrix.py +++ b/monai/handlers/confusion_matrix.py @@ -9,7 +9,7 @@ # See 
the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Any, Callable, Optional import torch @@ -57,6 +57,6 @@ def __init__( self.metric_name = metric_name super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device) - def _reduce(self, scores) -> torch.Tensor: + def _reduce(self, scores) -> Any: confusion_matrix, _ = do_metric_reduction(scores, MetricReduction.MEAN) return compute_confusion_matrix_metric(self.metric_name, confusion_matrix) diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index 76586bdb00..c3dd0f9cb5 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Sequence +from typing import Any, Callable, List, Optional, Sequence import torch @@ -42,9 +42,10 @@ def __init__( output_transform: Callable = lambda x: x, device: Optional[torch.device] = None, ) -> None: - super().__init__(output_transform, device=device) + self._is_reduced: bool = False self.metric_fn = metric_fn - self._scores = [] + self._scores: List = [] + super().__init__(output_transform, device=device) @reinit__is_reduced def reset(self) -> None: @@ -68,7 +69,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None: score = score[0] self._scores.append(score) - def compute(self) -> None: + def compute(self) -> float: """ Raises: NotComputableError: When ``compute`` is called before an ``update`` occurs. @@ -83,7 +84,7 @@ def compute(self) -> None: _scores = idist.all_gather(_scores) self._is_reduced = True - result = 0.0 + result: float = 0.0 if idist.get_rank() == 0: # run compute_fn on zero rank only result = self._reduce(_scores) @@ -94,5 +95,5 @@ def compute(self) -> None: return result - def _reduce(self, scores) -> torch.Tensor: + def _reduce(self, scores) -> Any: return do_metric_reduction(scores, MetricReduction.MEAN)[0].item() From 6f2371e4a498865a84116e59a88f405b32e83558 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 23 Jan 2021 01:22:01 +0800 Subject: [PATCH 4/7] [DLMED] fix the multi-gpu issue Signed-off-by: Nic Ma --- monai/handlers/iteration_metric.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index c3dd0f9cb5..0715514702 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -75,9 +75,14 @@ def compute(self) -> float: NotComputableError: When ``compute`` is called before an ``update`` occurs. 
""" - _scores = torch.cat(self._scores, dim=0) - ws = idist.get_world_size() + if ws > 1 and not self._is_reduced: + # make sure the _scores is evenly-divisible on multi-GPUs + length = len(self._scores) + for _ in range(length, max(idist.all_gather(length)).item()): + self._scores.append(self._scores[0].new_full(self._scores[0].shape, float("NaN"))) + + _scores = torch.cat(self._scores, dim=0) if ws > 1 and not self._is_reduced: # all gather across all processes From 4b13a4e4b160b8bdc3855261fc7e114daab4dccc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 23 Jan 2021 02:04:07 +0800 Subject: [PATCH 5/7] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/handlers/iteration_metric.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index 0715514702..81bdfa87ca 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -75,16 +75,17 @@ def compute(self) -> float: NotComputableError: When ``compute`` is called before an ``update`` occurs. """ + _scores = torch.cat(self._scores, dim=0) + ws = idist.get_world_size() if ws > 1 and not self._is_reduced: # make sure the _scores is evenly-divisible on multi-GPUs - length = len(self._scores) - for _ in range(length, max(idist.all_gather(length)).item()): - self._scores.append(self._scores[0].new_full(self._scores[0].shape, float("NaN"))) + length = _scores.shape[0] + max_len = max(idist.all_gather(length)).item() + if length < max_len: + size = [max_len - length] + list(_scores.shape[1:]) + _scores = torch.cat([_scores, _scores.new_full(size, float("NaN"))], dim=0) - _scores = torch.cat(self._scores, dim=0) - - if ws > 1 and not self._is_reduced: # all gather across all processes _scores = idist.all_gather(_scores) self._is_reduced = True From af9fdc61bf7bc37c32ec3727fb65d98b797a40bc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 23 Jan 2021 08:28:09 +0800 Subject: [PATCH 6/7] [DLMED] fix distributed tests Signed-off-by: Nic Ma --- monai/handlers/iteration_metric.py | 4 ++-- tests/test_handler_confusion_matrix.py | 2 +- tests/test_handler_confusion_matrix_dist.py | 16 ++++------------ 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index 81bdfa87ca..08ef4362ea 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -99,7 +99,7 @@ def compute(self) -> float: # broadcast result to all processes result = idist.broadcast(result, src=0) - return result + return result.item() if torch.is_tensor(result) else result def _reduce(self, scores) -> Any: - return do_metric_reduction(scores, MetricReduction.MEAN)[0].item() + return do_metric_reduction(scores, MetricReduction.MEAN)[0] diff --git a/tests/test_handler_confusion_matrix.py b/tests/test_handler_confusion_matrix.py index e533245536..cc231b82db 100644 --- a/tests/test_handler_confusion_matrix.py +++ b/tests/test_handler_confusion_matrix.py @@ -81,7 +81,7 @@ def test_compute_seg(self, input_params, expected_avg): y = data_2["y"] metric.update([y_pred, y]) - avg_metric = metric.compute().item() + avg_metric = metric.compute() self.assertAlmostEqual(avg_metric, expected_avg, places=4) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) diff --git a/tests/test_handler_confusion_matrix_dist.py b/tests/test_handler_confusion_matrix_dist.py index 583ba716aa..ebe0eb9ca7 100644 --- a/tests/test_handler_confusion_matrix_dist.py +++ 
b/tests/test_handler_confusion_matrix_dist.py @@ -21,17 +21,13 @@ class DistributedConfusionMatrix(DistTestCase): - @DistCall(nnodes=1, nproc_per_node=2) - def test_compute_sample(self): - self._compute(True) - @DistCall(nnodes=1, nproc_per_node=2) def test_compute(self): - self._compute(False) + self._compute() - def _compute(self, compute_sample=True): + def _compute(self): device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu" - metric = ConfusionMatrix(include_background=True, metric_name="tpr", compute_sample=compute_sample) + metric = ConfusionMatrix(include_background=True, metric_name="tpr") if dist.get_rank() == 0: y_pred = torch.tensor( @@ -62,11 +58,7 @@ def _compute(self, compute_sample=True): metric.update([y_pred, y]) avg_metric = metric.compute() - if compute_sample is False: - avg_metric = avg_metric.item() - np.testing.assert_allclose(avg_metric, 0.7, rtol=1e-04, atol=1e-04) - else: - np.testing.assert_allclose(avg_metric, 0.8333, rtol=1e-04, atol=1e-04) + np.testing.assert_allclose(avg_metric, 0.7, rtol=1e-04, atol=1e-04) if __name__ == "__main__": From f586b0888e15c4feca45976c27d20533a11373c3 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 23 Jan 2021 08:43:26 +0800 Subject: [PATCH 7/7] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- monai/handlers/iteration_metric.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index 08ef4362ea..4d555b9dcb 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -69,7 +69,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None: score = score[0] self._scores.append(score) - def compute(self) -> float: + def compute(self) -> Any: """ Raises: NotComputableError: When ``compute`` is called before an ``update`` occurs. @@ -90,7 +90,7 @@ def compute(self) -> float: _scores = idist.all_gather(_scores) self._is_reduced = True - result: float = 0.0 + result: torch.Tensor = torch.zeros(1) if idist.get_rank() == 0: # run compute_fn on zero rank only result = self._reduce(_scores)
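
To illustrate how the refactored handlers in this series are meant to be used, below is a minimal, self-contained sketch of attaching the new `IterationMetric`-based `MeanDice` to an ignite engine. The toy evaluation step and synthetic batches are assumptions made for the example only; `MeanDice`, its constructor arguments, and the `[y_pred, y]` output contract come from the patches above.

    import torch
    from ignite.engine import Engine

    from monai.handlers import MeanDice


    def _eval_step(engine, batch):
        # illustrative step: the "network output" is already discretized/one-hot, shape (B, C, H, W)
        y_pred, y = batch
        return y_pred, y


    evaluator = Engine(_eval_step)

    # the default output_transform passes engine.state.output, i.e. (y_pred, y), through unchanged
    MeanDice(include_background=True).attach(evaluator, name="mean_dice")

    # two toy binary segmentation batches: a perfect prediction, then an all-background prediction
    data = [
        (torch.ones(2, 1, 4, 4), torch.ones(2, 1, 4, 4)),
        (torch.zeros(2, 1, 4, 4), torch.ones(2, 1, 4, 4)),
    ]
    state = evaluator.run(data, max_epochs=1)
    print(state.metrics["mean_dice"])  # per-iteration scores are reduced at epoch end

On the multi-GPU fix in patches 4 and 5: `idist.all_gather` requires tensors of equal size on every rank, so ranks that saw fewer samples pad their accumulated score tensor with NaN rows before gathering; `do_metric_reduction` with `MetricReduction.MEAN` already discounts NaN entries, so the padding does not bias the final average.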