Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,24 @@ ROC AUC metrics handler
:members:


Confusion Matrix metrics handler
Confusion matrix metrics handler
--------------------------------
.. autoclass:: ConfusionMatrix
:members:


Hausdorff distance metrics handler
----------------------------------
.. autoclass:: HausdorffDistance
:members:


Surface distance metrics handler
--------------------------------
.. autoclass:: SurfaceDistance
:members:


Metric logger
-------------
.. autoclass:: MetricLogger
Expand Down
12 changes: 9 additions & 3 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,27 @@ Metrics
--------------------------
.. autofunction:: compute_roc_auc

`Confusion Matrix`
`Confusion matrix`
------------------
.. autofunction:: get_confusion_matrix

.. autoclass:: ConfusionMatrixMetric
:members:

`Hausdorff Distance`
`Hausdorff distance`
--------------------
.. autofunction:: compute_hausdorff_distance

`Average Surface Distance`
.. autoclass:: HausdorffDistanceMetric
:members:

`Average surface distance`
--------------------------
.. autofunction:: compute_average_surface_distance

.. autoclass:: SurfaceDistanceMetric
:members:

`Occlusion sensitivity`
-----------------------
.. autofunction:: compute_occlusion_sensitivity
23 changes: 12 additions & 11 deletions monai/apps/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from monai.transforms import LoadImaged, Randomizable
from monai.utils import ensure_tuple

__all__ = ["MedNISTDataset", "DecathlonDataset", "CrossValidation"]


class MedNISTDataset(Randomizable, CacheDataset):
"""
Expand Down Expand Up @@ -121,7 +123,7 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]:
image_class.extend([i] * num_each[i])
num_total = len(image_class)

data = list()
data = []

for i in range(num_total):
self.randomize()
Expand Down Expand Up @@ -302,18 +304,17 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]:
def _split_datalist(self, datalist: List[Dict]) -> List[Dict]:
if self.section == "test":
return datalist
else:
length = len(datalist)
indices = np.arange(length)
self.randomize(indices)
length = len(datalist)
indices = np.arange(length)
self.randomize(indices)

val_length = int(length * self.val_frac)
if self.section == "training":
self.indices = indices[val_length:]
else:
self.indices = indices[:val_length]
val_length = int(length * self.val_frac)
if self.section == "training":
self.indices = indices[val_length:]
else:
self.indices = indices[:val_length]

return [datalist[i] for i in self.indices]
return [datalist[i] for i in self.indices]


class CrossValidation:
Expand Down
7 changes: 7 additions & 0 deletions monai/apps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@
else:
tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")

__all__ = [
"check_hash",
"download_url",
"extractall",
"download_and_extract",
]


def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5") -> bool:
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/config/deviceconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def set_visible_devices(*dev_inds):

def _dict_append(in_dict, key, fn):
try:
in_dict[key] = fn()
in_dict[key] = fn() if callable(fn) else fn
except BaseException:
in_dict[key] = "UNKNOWN for given OS"

Expand Down Expand Up @@ -197,7 +197,7 @@ def get_gpu_info() -> OrderedDict:
_dict_append(output, "Current device", lambda: torch.cuda.current_device())
_dict_append(output, "Library compiled for CUDA architectures", lambda: torch.cuda.get_arch_list())
for gpu in range(num_gpus):
_dict_append(output, "Info for GPU", lambda: gpu)
_dict_append(output, "Info for GPU", gpu)
gpu_info = torch.cuda.get_device_properties(gpu)
_dict_append(output, "\tName", lambda: gpu_info.name)
_dict_append(output, "\tIs integrated", lambda: bool(gpu_info.is_integrated))
Expand Down
2 changes: 2 additions & 0 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image")
PILImage, has_pil = optional_import("PIL.Image")

__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader"]


class ImageReader(ABC):
"""Abstract class to define interface APIs to load image files.
Expand Down
14 changes: 6 additions & 8 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ def dense_patch_slices(
dim_starts.append(start_idx)
starts.append(dim_starts)
out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T
slices = [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out]
return slices
return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out]


def iter_patch(
Expand Down Expand Up @@ -550,7 +549,7 @@ def is_supported_format(filename: Union[Sequence[str], str], suffixes: Sequence[
filenames: Sequence[str] = ensure_tuple(filename)
for name in filenames:
tokens: Sequence[str] = PurePath(name).suffixes
if len(tokens) == 0 or not any(("." + s.lower()) in "".join(tokens) for s in suffixes):
if len(tokens) == 0 or all("." + s.lower() not in "".join(tokens) for s in suffixes):
return False

return True
Expand Down Expand Up @@ -598,7 +597,7 @@ def partition_dataset(

"""
data_len = len(data)
datasets = list()
datasets = []

indices = list(range(data_len))
if shuffle:
Expand Down Expand Up @@ -682,7 +681,7 @@ def partition_dataset_classes(
"""
if not classes or len(classes) != len(data):
raise ValueError(f"length of classes {classes} must match the dataset length {len(data)}.")
datasets = list()
datasets = []
class_indices = defaultdict(list)
for i, c in enumerate(classes):
class_indices[c].append(i)
Expand All @@ -698,7 +697,7 @@ def partition_dataset_classes(
drop_last=drop_last,
even_divisible=even_divisible,
)
if len(class_partition_indices) == 0:
if not class_partition_indices:
class_partition_indices = per_class_partition_indices
else:
for part, data_indices in zip(class_partition_indices, per_class_partition_indices):
Expand Down Expand Up @@ -735,8 +734,7 @@ def select_cross_validation_folds(partitions: Sequence[Iterable], folds: Union[S
>>> select_cross_validation_folds(partitions, [-1, 2])
[9, 10, 5, 6]
"""
data_list = [data_item for fold_id in ensure_tuple(folds) for data_item in partitions[fold_id]]
return data_list
return [data_item for fold_id in ensure_tuple(folds) for data_item in partitions[fold_id]]


class DistributedSampler(_TorchDistributedSampler):
Expand Down
2 changes: 2 additions & 0 deletions monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")
Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")

__all__ = ["Evaluator", "SupervisedEvaluator", "EnsembleEvaluator"]


class Evaluator(Workflow):
"""
Expand Down
5 changes: 5 additions & 0 deletions monai/engines/multi_gpu_supervised_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")
Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")

__all__ = [
"create_multigpu_supervised_trainer",
"create_multigpu_supervised_evaluator",
]


def _default_transform(_x: torch.Tensor, _y: torch.Tensor, _y_pred: torch.Tensor, loss: torch.Tensor) -> float:
return loss.item()
Expand Down
2 changes: 2 additions & 0 deletions monai/engines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")
Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")

__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer"]


class Trainer(Workflow):
"""
Expand Down
2 changes: 2 additions & 0 deletions monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
from .checkpoint_saver import CheckpointSaver
from .classification_saver import ClassificationSaver
from .confusion_matrix import ConfusionMatrix
from .hausdorff_distance import HausdorffDistance
from .lr_schedule_handler import LrScheduleHandler
from .mean_dice import MeanDice
from .metric_logger import MetricLogger
from .roc_auc import ROCAUC
from .segmentation_saver import SegmentationSaver
from .smartcache_handler import SmartCacheHandler
from .stats_handler import StatsHandler
from .surface_distance import SurfaceDistance
from .tensorboard_handlers import TensorBoardImageHandler, TensorBoardStatsHandler
from .utils import *
from .validation_handler import ValidationHandler
99 changes: 99 additions & 0 deletions monai/handlers/hausdorff_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Sequence

import torch

from monai.metrics import HausdorffDistanceMetric
from monai.utils import MetricReduction, exact_version, optional_import

NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError")
Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")
reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced")
sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce")


class HausdorffDistance(Metric): # type: ignore[valid-type, misc] # due to optional_import
"""
Computes Hausdorff distance from full size Tensor and collects average over batch, class-channels, iterations.
"""

def __init__(
self,
include_background: bool = False,
distance_metric: str = "euclidean",
percentile: Optional[float] = None,
directed: bool = False,
output_transform: Callable = lambda x: x,
device: Optional[torch.device] = None,
) -> None:
"""

Args:
include_background: whether to include distance computation on the first channel of the predicted output.
Defaults to ``False``.
distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``]
the metric used to compute surface distance. Defaults to ``"euclidean"``.
percentile: an optional float number between 0 and 100. If specified, the corresponding
percentile of the Hausdorff Distance rather than the maximum result will be achieved.
Defaults to ``None``.
directed: whether to calculate directed Hausdorff distance. Defaults to ``False``.
output_transform: transform the ignite.engine.state.output into [y_pred, y] pair.
device: device specification in case of distributed computation usage.

"""
super().__init__(output_transform, device=device)
self.hd = HausdorffDistanceMetric(
include_background=include_background,
distance_metric=distance_metric,
percentile=percentile,
directed=directed,
reduction=MetricReduction.MEAN,
)
self._sum = 0.0
self._num_examples = 0

@reinit__is_reduced
def reset(self) -> None:
self._sum = 0.0
self._num_examples = 0

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
"""
Args:
output: sequence with contents [y_pred, y].

Raises:
ValueError: When ``output`` length is not 2. The metric can only support y_pred and y.

"""
if len(output) != 2:
raise ValueError(f"output must have length 2, got {len(output)}.")
y_pred, y = output
score, not_nans = self.hd(y_pred, y)
not_nans = int(not_nans.item())

# add all items in current batch
self._sum += score.item() * not_nans
self._num_examples += not_nans

@sync_all_reduce("_sum", "_num_examples")
def compute(self) -> float:
"""
Raises:
NotComputableError: When ``compute`` is called before an ``update`` occurs.

"""
if self._num_examples == 0:
raise NotComputableError("HausdorffDistance must have at least one example before it can be computed.")
return self._sum / self._num_examples
Loading