Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ CSV saver
:members:


Iteration Metric
----------------
.. autoclass:: IterationMetric
:members:


Mean Dice metrics handler
-------------------------
.. autoclass:: MeanDice
Expand Down
1 change: 1 addition & 0 deletions monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .classification_saver import ClassificationSaver
from .confusion_matrix import ConfusionMatrix
from .hausdorff_distance import HausdorffDistance
from .iteration_metric import IterationMetric
from .lr_schedule_handler import LrScheduleHandler
from .mean_dice import MeanDice
from .metric_logger import MetricLogger
Expand Down
82 changes: 9 additions & 73 deletions monai/handlers/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Sequence
from typing import Any, Callable, Optional

import torch

from monai.handlers.iteration_metric import IterationMetric
from monai.metrics import ConfusionMatrixMetric, compute_confusion_matrix_metric
from monai.metrics.utils import MetricReduction, do_metric_reduction
from monai.utils import exact_version, optional_import

NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError")
Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")
reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced")
sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce")


class ConfusionMatrix(Metric): # type: ignore[valid-type, misc] # due to optional_import
class ConfusionMatrix(IterationMetric):
"""
Compute confusion matrix related metrics from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand All @@ -32,7 +27,6 @@ def __init__(
self,
include_background: bool = True,
metric_name: str = "hit_rate",
compute_sample: bool = False,
output_transform: Callable = lambda x: x,
device: Optional[torch.device] = None,
) -> None:
Expand All @@ -48,79 +42,21 @@ def __init__(
``"informedness"``, ``"markedness"``]
Some of the metrics have multiple aliases (as shown in the wikipedia page aforementioned),
and you can also input those names instead.
compute_sample: if ``True``, each sample's metric will be computed first.
If ``False``, the confusion matrix for all samples will be accumulated first. Defaults to ``False``.
output_transform: transform the ignite.engine.state.output into [y_pred, y] pair.
device: device specification in case of distributed computation usage.

See also:
:py:meth:`monai.metrics.confusion_matrix`
"""
super().__init__(output_transform, device=device)
self.confusion_matrix = ConfusionMatrixMetric(
metric_fn = ConfusionMatrixMetric(
include_background=include_background,
metric_name=metric_name,
compute_sample=compute_sample,
reduction=MetricReduction.MEAN,
compute_sample=False,
reduction=MetricReduction.NONE,
)
self._sum = 0.0
self._num_examples = 0
self.compute_sample = compute_sample
self.metric_name = metric_name
self._total_tp = 0.0
self._total_fp = 0.0
self._total_tn = 0.0
self._total_fn = 0.0

@reinit__is_reduced
def reset(self) -> None:
self._sum = 0.0
self._num_examples = 0
self._total_tp = 0.0
self._total_fp = 0.0
self._total_tn = 0.0
self._total_fn = 0.0

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
"""
Args:
output: sequence with contents [y_pred, y].

Raises:
ValueError: When ``output`` length is not 2. This metric can only support y_pred and y.
super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device)

"""
if len(output) != 2:
raise ValueError(f"output must have length 2, got {len(output)}.")
y_pred, y = output
if self.compute_sample is True:
score, not_nans = self.confusion_matrix(y_pred, y)
not_nans = int(not_nans.item())

# add all items in current batch
self._sum += score.item() * not_nans
self._num_examples += not_nans
else:
confusion_matrix = self.confusion_matrix(y_pred, y)
confusion_matrix, _ = do_metric_reduction(confusion_matrix, MetricReduction.SUM)
self._total_tp += confusion_matrix[0].item()
self._total_fp += confusion_matrix[1].item()
self._total_tn += confusion_matrix[2].item()
self._total_fn += confusion_matrix[3].item()

@sync_all_reduce("_sum", "_num_examples", "_total_tp", "_total_fp", "_total_tn", "_total_fn")
def compute(self):
"""
Raises:
NotComputableError: When ``compute`` is called before an ``update`` occurs.

"""
if self.compute_sample is True:
if self._num_examples == 0:
raise NotComputableError(
"ConfusionMatrix metric must have at least one example before it can be computed."
)
return self._sum / self._num_examples
confusion_matrix = torch.tensor([self._total_tp, self._total_fp, self._total_tn, self._total_fn])
def _reduce(self, scores) -> Any:
confusion_matrix, _ = do_metric_reduction(scores, MetricReduction.MEAN)
return compute_confusion_matrix_metric(self.metric_name, confusion_matrix)
55 changes: 7 additions & 48 deletions monai/handlers/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Sequence
from typing import Callable, Optional

import torch

from monai.handlers.iteration_metric import IterationMetric
from monai.metrics import HausdorffDistanceMetric
from monai.utils import MetricReduction, exact_version, optional_import
from monai.utils import MetricReduction

NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError")
Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")
reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced")
sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce")


class HausdorffDistance(Metric): # type: ignore[valid-type, misc] # due to optional_import
class HausdorffDistance(IterationMetric):
"""
Computes Hausdorff distance from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down Expand Up @@ -52,48 +48,11 @@ def __init__(

"""
super().__init__(output_transform, device=device)
self.hd = HausdorffDistanceMetric(
metric_fn = HausdorffDistanceMetric(
include_background=include_background,
distance_metric=distance_metric,
percentile=percentile,
directed=directed,
reduction=MetricReduction.MEAN,
reduction=MetricReduction.NONE,
)
self._sum = 0.0
self._num_examples = 0

@reinit__is_reduced
def reset(self) -> None:
self._sum = 0.0
self._num_examples = 0

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
"""
Args:
output: sequence with contents [y_pred, y].

Raises:
ValueError: When ``output`` length is not 2. The metric can only support y_pred and y.

"""
if len(output) != 2:
raise ValueError(f"output must have length 2, got {len(output)}.")
y_pred, y = output
score, not_nans = self.hd(y_pred, y)
not_nans = int(not_nans.item())

# add all items in current batch
self._sum += score.item() * not_nans
self._num_examples += not_nans

@sync_all_reduce("_sum", "_num_examples")
def compute(self) -> float:
"""
Raises:
NotComputableError: When ``compute`` is called before an ``update`` occurs.

"""
if self._num_examples == 0:
raise NotComputableError("HausdorffDistance must have at least one example before it can be computed.")
return self._sum / self._num_examples
super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device)
105 changes: 105 additions & 0 deletions monai/handlers/iteration_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, List, Optional, Sequence

import torch

from monai.metrics import do_metric_reduction
from monai.utils import MetricReduction, exact_version, optional_import

NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError")
idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed")
Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")
reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced")


class IterationMetric(Metric): # type: ignore[valid-type, misc] # due to optional_import
"""
Class for metrics that should be computed on every iteration and compute final results when epoch completed.
Similar to the `EpochMetric` in ignite:
https://github.com/pytorch/ignite/blob/v0.4.2/ignite/metrics/epoch_metric.py#L13.

Args:
metric_fn: callable function or class to compute raw metric results after every iteration.
expect to return a Tensor with shape (batch, channel, ...) or tuple (Tensor, not_nans).
output_transform: transform the ignite.engine.state.output into [y_pred, y] pair.
device: device specification in case of distributed computation usage.

"""

def __init__(
self,
metric_fn: Callable,
output_transform: Callable = lambda x: x,
device: Optional[torch.device] = None,
) -> None:
self._is_reduced: bool = False
self.metric_fn = metric_fn
self._scores: List = []
super().__init__(output_transform, device=device)

@reinit__is_reduced
def reset(self) -> None:
self._scores = []

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
"""
Args:
output: sequence with contents [y_pred, y].

Raises:
ValueError: When ``output`` length is not 2. metric_fn can only support y_pred and y.

"""
if len(output) != 2:
raise ValueError(f"output must have length 2, got {len(output)}.")
y_pred, y = output
score = self.metric_fn(y_pred, y)
if isinstance(score, (tuple, list)):
score = score[0]
self._scores.append(score)

def compute(self) -> Any:
"""
Raises:
NotComputableError: When ``compute`` is called before an ``update`` occurs.

"""
_scores = torch.cat(self._scores, dim=0)

ws = idist.get_world_size()
if ws > 1 and not self._is_reduced:
# make sure the _scores is evenly-divisible on multi-GPUs
length = _scores.shape[0]
max_len = max(idist.all_gather(length)).item()
if length < max_len:
size = [max_len - length] + list(_scores.shape[1:])
_scores = torch.cat([_scores, _scores.new_full(size, float("NaN"))], dim=0)

# all gather across all processes
_scores = idist.all_gather(_scores)
self._is_reduced = True

result: torch.Tensor = torch.zeros(1)
if idist.get_rank() == 0:
# run compute_fn on zero rank only
result = self._reduce(_scores)

if ws > 1:
# broadcast result to all processes
result = idist.broadcast(result, src=0)

return result.item() if torch.is_tensor(result) else result

def _reduce(self, scores) -> Any:
return do_metric_reduction(scores, MetricReduction.MEAN)[0]
56 changes: 7 additions & 49 deletions monai/handlers/mean_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Sequence
from typing import Callable, Optional

import torch

from monai.handlers.iteration_metric import IterationMetric
from monai.metrics import DiceMetric
from monai.utils import MetricReduction, exact_version, optional_import
from monai.utils import MetricReduction

NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError")
Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")
reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced")
sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce")


class MeanDice(Metric): # type: ignore[valid-type, misc] # due to optional_import
class MeanDice(IterationMetric):
"""
Computes Dice score metric from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand All @@ -44,46 +40,8 @@ def __init__(
See also:
:py:meth:`monai.metrics.meandice.compute_meandice`
"""
super().__init__(output_transform, device=device)
self.dice = DiceMetric(
metric_fn = DiceMetric(
include_background=include_background,
reduction=MetricReduction.MEAN,
reduction=MetricReduction.NONE,
)
self._sum = 0.0
self._num_examples = 0

@reinit__is_reduced
def reset(self) -> None:
self._sum = 0.0
self._num_examples = 0

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
"""
Args:
output: sequence with contents [y_pred, y].

Raises:
ValueError: When ``output`` length is not 2. MeanDice metric can only support y_pred and y.

"""
if len(output) != 2:
raise ValueError(f"output must have length 2, got {len(output)}.")
y_pred, y = output
score, not_nans = self.dice(y_pred, y)
not_nans = int(not_nans.item())

# add all items in current batch
self._sum += score.item() * not_nans
self._num_examples += not_nans

@sync_all_reduce("_sum", "_num_examples")
def compute(self) -> float:
"""
Raises:
NotComputableError: When ``compute`` is called before an ``update`` occurs.

"""
if self._num_examples == 0:
raise NotComputableError("MeanDice must have at least one example before it can be computed.")
return self._sum / self._num_examples
super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device)
Loading