From 6f23f38609802a089119ba14ee1254e48b626c9a Mon Sep 17 00:00:00 2001 From: yiheng-wang-nv Date: Fri, 11 Dec 2020 17:44:43 +0800 Subject: [PATCH 1/7] Update hausdorff metric Signed-off-by: yiheng-wang-nv --- monai/metrics/__init__.py | 2 +- monai/metrics/hausdorff_distance.py | 145 ++++++++++++++++++++++------ monai/metrics/meandice.py | 2 +- monai/metrics/surface_distance.py | 9 +- monai/metrics/utils.py | 21 +--- tests/test_hausdorff_distance.py | 42 ++++---- 6 files changed, 142 insertions(+), 79 deletions(-) diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index 61e288cf0c..3ae46addde 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -10,7 +10,7 @@ # limitations under the License. from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix -from .hausdorff_distance import compute_hausdorff_distance +from .hausdorff_distance import HausdorffDistance, compute_hausdorff_distance from .meandice import DiceMetric, compute_meandice from .occlusion_sensitivity import compute_occlusion_sensitivity from .rocauc import compute_roc_auc diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 1cfaea2449..a019a3b107 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -9,73 +9,153 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import Optional, Union import numpy as np import torch -from .utils import get_mask_edges, get_surface_distance +from monai.metrics.utils import * + + +class HausdorffDistance: + """ + Compute Hausdorff Distance between two tensors. It can support both multi-classes and multi-labels tasks. + It supports both directed and non-directed Hausdorff distance calculation. In addition, specify the `percentile` + parameter can get the percentile of the distance. + Input `y_pred` (BNHW[D] where N is number of classes) is compared with ground truth `y` (BNHW[D]). + `y_preds` is expected to have binarized predictions and `y` should be in one-hot format. + You can use suitable transforms in ``monai.transforms.post`` first to achieve binarized values. + + Args: + include_background: whether to skip distance computation on the first channel of + the predicted output. Defaults to ``True``. + distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] + the metric used to compute surface distance. Defaults to ``"euclidean"``. + percentile: an optional float number between 0 and 100. If specified, the corresponding + percentile of the Hausdorff Distance rather than the maximum result will be achieved. + Defaults to ``None``. + directed: whether to calculate directed Hausdorff distance. Defaults to ``False``. + reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``} + Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. + + """ + + def __init__( + self, + include_background: bool = True, + distance_metric: str = "euclidean", + percentile: Optional[float] = None, + directed: bool = False, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + ) -> None: + super().__init__() + self.include_background = include_background + self.distance_metric = distance_metric + self.percentile = percentile + self.directed = directed + self.reduction = reduction + + def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): + """ + Args: + y_pred: input data to compute, typical segmentation model output. + It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values + should be binarized. + y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch. + The values should be binarized. + + Raises: + ValueError: when `y` is not a binarized tensor. + ValueError: when `y_pred` has less than three dimensions. + """ + if not torch.all(y_pred.byte() == y_pred): + warnings.warn("y_pred is not a binarized tensor here!") + if not torch.all(y.byte() == y): + raise ValueError("y should be a binarized tensor.") + dims = y_pred.ndimension() + if dims < 3: + raise ValueError("y_pred should have at least three dimensions.") + # compute dice (BxC) for each channel for each batch + f = compute_hausdorff_distance( + y_pred=y_pred, + y=y, + include_background=self.include_background, + distance_metric=self.distance_metric, + percentile=self.percentile, + directed=self.directed, + ) + + # do metric reduction + f, not_nans = do_metric_reduction(f, self.reduction) + return f, not_nans def compute_hausdorff_distance( - seg_pred: Union[np.ndarray, torch.Tensor], - seg_gt: Union[np.ndarray, torch.Tensor], - label_idx: int, + y_pred: Union[np.ndarray, torch.Tensor], + y: Union[np.ndarray, torch.Tensor], + include_background: bool = True, distance_metric: str = "euclidean", percentile: Optional[float] = None, directed: bool = False, ): """ - Compute the Hausdorff distance. The user has the option to calculate the - directed or non-directed Hausdorff distance. By default, the non-directed - Hausdorff distance is calculated. In addition, specify the `percentile` - parameter can get the percentile of the distance. + Compute the Hausdorff distance. Args: - seg_pred: the predicted binary or labelfield image. - seg_gt: the actual binary or labelfield image. - label_idx: for labelfield images, convert to binary with - `seg_pred = seg_pred == label_idx`. + y_pred: input data to compute, typical segmentation model output. + It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values + should be binarized. + y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch. + The values should be binarized. + include_background: whether to skip distance computation on the first channel of + the predicted output. Defaults to ``True``. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. percentile: an optional float number between 0 and 100. If specified, the corresponding percentile of the Hausdorff Distance rather than the maximum result will be achieved. Defaults to ``None``. - directed: calculate directed Hausdorff distance. Defaults to ``False``. + directed: whether to calculate directed Hausdorff distance. Defaults to ``False``. """ - (edges_pred, edges_gt) = get_mask_edges(seg_pred, seg_gt, label_idx) - hd = compute_percent_hausdorff_distance(edges_pred, edges_gt, label_idx, distance_metric, percentile) - if directed: - return hd + if not include_background: + y_pred, y = ignore_background( + y_pred=y_pred, + y=y, + ) - hd2 = compute_percent_hausdorff_distance(edges_gt, edges_pred, label_idx, distance_metric, percentile) - return max(hd, hd2) + y = y.float() + y_pred = y_pred.float() + + if y.shape != y_pred.shape: + raise ValueError("y_pred and y should have same shapes.") + + batch_size, n_class = y_pred.shape[:2] + hd = np.empty((batch_size, n_class)) + for b, c in np.ndindex(batch_size, n_class): + (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c]) + distance_1 = compute_percent_hausdorff_distance(edges_pred, edges_gt, distance_metric, percentile) + if directed: + hd[b, c] = distance_1 + else: + distance_2 = compute_percent_hausdorff_distance(edges_gt, edges_pred, distance_metric, percentile) + hd[b, c] = max(distance_1, distance_2) + hd = torch.from_numpy(hd).double() + return hd def compute_percent_hausdorff_distance( edges_pred: np.ndarray, edges_gt: np.ndarray, - label_idx: int, distance_metric: str = "euclidean", percentile: Optional[float] = None, ): """ This function is used to compute the directed Hausdorff distance. - - Args: - edges_pred: the edge of the predictions. - edges_gt: the edge of the ground truth. - label_idx: for labelfield images, convert to binary with - `seg_pred = seg_pred == label_idx`. - distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] - the metric used to compute surface distance. Defaults to ``"euclidean"``. - percentile: an optional float number between 0 and 100. If specified, the corresponding - percentile of the Hausdorff Distance rather than the maximum result will be achieved. - Defaults to ``None``. """ - surface_distance = get_surface_distance(edges_pred, edges_gt, label_idx, distance_metric=distance_metric) + surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) # for input without foreground if surface_distance.shape == (0,): @@ -83,6 +163,7 @@ def compute_percent_hausdorff_distance( if not percentile: return surface_distance.max() + elif 0 <= percentile <= 100: return np.percentile(surface_distance, percentile) else: diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 18382e7849..b530c425ee 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -30,7 +30,7 @@ class DiceMetric: Args: include_background: whether to skip Dice computation on the first channel of - the predicted output. Defaults to True. + the predicted output. Defaults to ``True``. reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, ``"sum_channel"``} Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 7914364b9c..2a7de0b9b7 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -20,7 +20,6 @@ def compute_average_surface_distance( seg_pred: Union[np.ndarray, torch.Tensor], seg_gt: Union[np.ndarray, torch.Tensor], - label_idx: int, symmetric: bool = False, distance_metric: str = "euclidean", ): @@ -33,15 +32,13 @@ def compute_average_surface_distance( Args: seg_pred: first binary or labelfield image. seg_gt: second binary or labelfield image. - label_idx: for labelfield images, convert to binary with - `seg_pred = seg_pred == label_idx`. symmetric: if calculate the symmetric average surface distance between `seg_pred` and `seg_gt`. Defaults to ``False``. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. """ - (edges_pred, edges_gt) = get_mask_edges(seg_pred, seg_gt, label_idx) - surface_distance = get_surface_distance(edges_pred, edges_gt, label_idx, distance_metric=distance_metric) + (edges_pred, edges_gt) = get_mask_edges(seg_pred, seg_gt) + surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) if surface_distance.shape == (0,): return np.inf @@ -49,7 +46,7 @@ def compute_average_surface_distance( if not symmetric: return avg_surface_distance - surface_distance_2 = get_surface_distance(edges_gt, edges_pred, label_idx, distance_metric=distance_metric) + surface_distance_2 = get_surface_distance(edges_gt, edges_pred, distance_metric=distance_metric) if surface_distance_2.shape == (0,): return np.inf diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 08450fa355..30c9a6fb3c 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -65,7 +65,7 @@ def do_metric_reduction( not_nans = (~nans).float() f[nans] = 0 - t_zero = torch.zeros(1, device=f.device, dtype=torch.float) + t_zero = torch.zeros(1, device=f.device, dtype=torch.float64) reduction = MetricReduction(reduction) if reduction == MetricReduction.MEAN: @@ -104,7 +104,7 @@ def do_metric_reduction( def get_mask_edges( seg_pred: Union[np.ndarray, torch.Tensor], seg_gt: Union[np.ndarray, torch.Tensor], - label_idx: int, + label_idx: int = 1, crop: bool = True, ) -> Tuple[np.ndarray, np.ndarray]: """ @@ -141,9 +141,8 @@ def get_mask_edges( if torch.is_tensor(seg_gt): seg_gt = seg_gt.detach().cpu().numpy() - # Check non-zero number of elements and same shape - if seg_pred.size == 0 or seg_pred.shape != seg_gt.shape: - raise ValueError("Labelfields should have same shape (and non-zero number of elements)") + if seg_pred.shape != seg_gt.shape: + raise ValueError("seg_pred and seg_gt should have same shapes.") # If not binary images, convert them if seg_pred.dtype != bool: @@ -170,26 +169,14 @@ def get_mask_edges( def get_surface_distance( edges_pred: np.ndarray, edges_gt: np.ndarray, - label_idx: int, - crop: bool = True, distance_metric: str = "euclidean", ) -> np.ndarray: """ This function is used to compute the surface distances from `seg_pred` to `seg_gt`. - In order to improve the computing efficiency, before getting the edges, - the images can be cropped and only keep the foreground if not specifies - ``crop = False``. - Args: edges_pred: the edge of the predictions. edges_gt: the edge of the ground truth. - label_idx: for labelfield images, convert to binary with - `seg_pred = seg_pred == label_idx`. - crop: crop input images and only keep the foregrounds. In order to - maintain two inputs' shapes, here the bounding box is achieved - by ``(seg_pred | seg_gt)`` which represents the union set of two - images. Defaults to ``True``. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. diff --git a/tests/test_hausdorff_distance.py b/tests/test_hausdorff_distance.py index dda1186612..969e1103bf 100644 --- a/tests/test_hausdorff_distance.py +++ b/tests/test_hausdorff_distance.py @@ -13,27 +13,24 @@ from typing import Tuple import numpy as np +import torch from parameterized import parameterized -from monai.metrics import compute_hausdorff_distance +from monai.metrics import HausdorffDistance def create_spherical_seg_3d( radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), - labelfield_value: int = 1, - background_value: int = 0, im_shape: Tuple[int, int, int] = (99, 99, 99), ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be - `labelfield_value` inside the sphere, and `background_value` elsewhere. + 1 inside the sphere, and 0 elsewhere. Args: radius: radius of sphere (in terms of number of voxels, can be partial) centre: location of sphere centre. - labelfield_value: index of labelfield. - background_value: index of background. im_shape: shape of image to create See also: @@ -46,8 +43,8 @@ def create_spherical_seg_3d( ] circle = (spx * spx + spy * spy + spz * spz) <= radius * radius - image[circle] = labelfield_value - image[~circle] = background_value + image[circle] = 1 + image[~circle] = 0 return image @@ -60,15 +57,13 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=20, centre=(20, 20, 20)), create_spherical_seg_3d(radius=20, centre=(19, 19, 19)), - 1, ], [1.7320508075688772, 1.7320508075688772, 1, 1, 3, 3], ], [ [ - create_spherical_seg_3d(radius=33, labelfield_value=2, centre=(19, 33, 22)), - create_spherical_seg_3d(radius=33, labelfield_value=2, centre=(20, 33, 22)), - 2, + create_spherical_seg_3d(radius=33, centre=(19, 33, 22)), + create_spherical_seg_3d(radius=33, centre=(20, 33, 22)), ], [1, 1, 1, 1, 1, 1], ], @@ -76,7 +71,6 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=20, centre=(20, 33, 22)), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - 1, ], [20.09975124224178, 20.223748416156685, 15, 20, 24, 35], ], @@ -84,7 +78,6 @@ def create_spherical_seg_3d( [ np.zeros([99, 99, 99]), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - 1, ], [np.inf, np.inf, np.inf, np.inf, np.inf, np.inf], ], @@ -92,7 +85,6 @@ def create_spherical_seg_3d( [ np.zeros([99, 99, 99]), np.zeros([99, 99, 99]), - 1, ], [np.inf, np.inf, np.inf, np.inf, np.inf, np.inf], ], @@ -100,7 +92,6 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(), np.zeros([99, 99, 99]), - 1, ], [np.inf, np.inf, np.inf, np.inf, np.inf, np.inf], ], @@ -108,7 +99,6 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=20, centre=(20, 33, 22)), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - 1, 95, ], [19.924858845171276, 20.09975124224178, 14, 18, 22, 33], @@ -120,17 +110,25 @@ class TestHausdorffDistance(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_value(self, input_data, expected_value): percentile = None - if len(input_data) == 4: - [seg_1, seg_2, label_idx, percentile] = input_data + if len(input_data) == 3: + [seg_1, seg_2, percentile] = input_data else: - [seg_1, seg_2, label_idx] = input_data + [seg_1, seg_2] = input_data ct = 0 + seg_1 = torch.tensor(seg_1) + seg_2 = torch.tensor(seg_2) for metric in ["euclidean", "chessboard", "taxicab"]: for directed in [True, False]: - result = compute_hausdorff_distance( - seg_1, seg_2, label_idx, distance_metric=metric, percentile=percentile, directed=directed + hd_metric = HausdorffDistance( + include_background=False, distance_metric=metric, percentile=percentile, directed=directed ) + # shape of seg_1, seg_2 are: HWD, converts to BNHWD + batch, n_class = 2, 3 + batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) + batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) + result, _ = hd_metric(batch_seg_1, batch_seg_2) expected_value_curr = expected_value[ct] + batch_expected_value_curr = np.tile(np.asarray(expected_value_curr), [batch, n_class]) np.testing.assert_allclose(expected_value_curr, result, rtol=1e-7) ct += 1 From 3448fd3ed639490eed62689bc8124b8e18857aeb Mon Sep 17 00:00:00 2001 From: yiheng-wang-nv Date: Fri, 11 Dec 2020 18:39:03 +0800 Subject: [PATCH 2/7] Update surface distance Signed-off-by: yiheng-wang-nv --- docs/source/metrics.rst | 6 ++ monai/metrics/__init__.py | 2 +- monai/metrics/hausdorff_distance.py | 9 +- monai/metrics/surface_distance.py | 132 ++++++++++++++++++++++++---- tests/test_hausdorff_distance.py | 1 - tests/test_surface_distance.py | 44 ++++------ 6 files changed, 143 insertions(+), 51 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index d3f5a347c7..bdb02bd0b4 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -28,10 +28,16 @@ Metrics -------------------- .. autofunction:: compute_hausdorff_distance +.. autoclass:: HausdorffDistance + :members: + `Average Surface Distance` -------------------------- .. autofunction:: compute_average_surface_distance +.. autoclass:: SurfaceDistance + :members: + `Occlusion sensitivity` ----------------------- .. autofunction:: compute_occlusion_sensitivity \ No newline at end of file diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index 3ae46addde..01aaec631d 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -14,5 +14,5 @@ from .meandice import DiceMetric, compute_meandice from .occlusion_sensitivity import compute_occlusion_sensitivity from .rocauc import compute_roc_auc -from .surface_distance import compute_average_surface_distance +from .surface_distance import SurfaceDistance, compute_average_surface_distance from .utils import * diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index a019a3b107..23bc462783 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -63,7 +63,7 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): y_pred: input data to compute, typical segmentation model output. It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values should be binarized. - y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch. + y: ground truth to compute the distance. It must be one-hot format and first dim is batch. The values should be binarized. Raises: @@ -77,7 +77,7 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): dims = y_pred.ndimension() if dims < 3: raise ValueError("y_pred should have at least three dimensions.") - # compute dice (BxC) for each channel for each batch + # compute (BxC) for each channel for each batch f = compute_hausdorff_distance( y_pred=y_pred, y=y, @@ -107,7 +107,7 @@ def compute_hausdorff_distance( y_pred: input data to compute, typical segmentation model output. It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values should be binarized. - y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch. + y: ground truth to compute mean the distance. It must be one-hot format and first dim is batch. The values should be binarized. include_background: whether to skip distance computation on the first channel of the predicted output. Defaults to ``True``. @@ -141,8 +141,7 @@ def compute_hausdorff_distance( else: distance_2 = compute_percent_hausdorff_distance(edges_gt, edges_pred, distance_metric, percentile) hd[b, c] = max(distance_1, distance_2) - hd = torch.from_numpy(hd).double() - return hd + return torch.from_numpy(hd) def compute_percent_hausdorff_distance( diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 2a7de0b9b7..100c1cb2b3 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -9,46 +9,140 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import Union import numpy as np import torch -from .utils import get_mask_edges, get_surface_distance +from monai.metrics.utils import * + + +class SurfaceDistance: + """ + Compute Surface Distance between two tensors. It can support both multi-classes and multi-labels tasks. + It supports both symmetric and asymmetric surface distance calculation. + Input `y_pred` (BNHW[D] where N is number of classes) is compared with ground truth `y` (BNHW[D]). + `y_preds` is expected to have binarized predictions and `y` should be in one-hot format. + You can use suitable transforms in ``monai.transforms.post`` first to achieve binarized values. + + Args: + include_background: whether to skip distance computation on the first channel of + the predicted output. Defaults to ``True``. + symmetric: whether to calculate the symmetric average surface distance between + `seg_pred` and `seg_gt`. Defaults to ``False``. + distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] + the metric used to compute surface distance. Defaults to ``"euclidean"``. + reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``} + Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. + + """ + + def __init__( + self, + include_background: bool = True, + symmetric: bool = False, + distance_metric: str = "euclidean", + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + ) -> None: + super().__init__() + self.include_background = include_background + self.distance_metric = distance_metric + self.symmetric = symmetric + self.reduction = reduction + + def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): + """ + Args: + y_pred: input data to compute, typical segmentation model output. + It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values + should be binarized. + y: ground truth to compute the distance. It must be one-hot format and first dim is batch. + The values should be binarized. + + Raises: + ValueError: when `y` is not a binarized tensor. + ValueError: when `y_pred` has less than three dimensions. + """ + if not torch.all(y_pred.byte() == y_pred): + warnings.warn("y_pred is not a binarized tensor here!") + if not torch.all(y.byte() == y): + raise ValueError("y should be a binarized tensor.") + dims = y_pred.ndimension() + if dims < 3: + raise ValueError("y_pred should have at least three dimensions.") + # compute (BxC) for each channel for each batch + f = compute_average_surface_distance( + y_pred=y_pred, + y=y, + include_background=self.include_background, + symmetric=self.symmetric, + distance_metric=self.distance_metric, + ) + + # do metric reduction + f, not_nans = do_metric_reduction(f, self.reduction) + return f, not_nans def compute_average_surface_distance( - seg_pred: Union[np.ndarray, torch.Tensor], - seg_gt: Union[np.ndarray, torch.Tensor], + y_pred: Union[np.ndarray, torch.Tensor], + y: Union[np.ndarray, torch.Tensor], + include_background: bool = True, symmetric: bool = False, distance_metric: str = "euclidean", ): """ - This function is used to compute the Average Surface Distance from `seg_pred` to `seg_gt` + This function is used to compute the Average Surface Distance from `y_pred` to `y` under the default setting. In addition, if sets ``symmetric = True``, the average symmetric surface distance between these two inputs will be returned. Args: - seg_pred: first binary or labelfield image. - seg_gt: second binary or labelfield image. - symmetric: if calculate the symmetric average surface distance between + y_pred: input data to compute, typical segmentation model output. + It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values + should be binarized. + y: ground truth to compute mean the distance. It must be one-hot format and first dim is batch. + The values should be binarized. + include_background: whether to skip distance computation on the first channel of + the predicted output. Defaults to ``True``. + symmetric: whether to calculate the symmetric average surface distance between `seg_pred` and `seg_gt`. Defaults to ``False``. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. """ - (edges_pred, edges_gt) = get_mask_edges(seg_pred, seg_gt) - surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) - if surface_distance.shape == (0,): - return np.inf - avg_surface_distance = surface_distance.mean() - if not symmetric: - return avg_surface_distance + if not include_background: + y_pred, y = ignore_background( + y_pred=y_pred, + y=y, + ) + + y = y.float() + y_pred = y_pred.float() + + if y.shape != y_pred.shape: + raise ValueError("y_pred and y should have same shapes.") + + batch_size, n_class = y_pred.shape[:2] + asd = np.empty((batch_size, n_class)) - surface_distance_2 = get_surface_distance(edges_gt, edges_pred, distance_metric=distance_metric) - if surface_distance_2.shape == (0,): - return np.inf + for b, c in np.ndindex(batch_size, n_class): + (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c]) + surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) + if surface_distance.shape == (0,): + avg_surface_distance = np.inf + else: + avg_surface_distance = surface_distance.mean() + if not symmetric: + asd[b, c] = avg_surface_distance + else: + surface_distance_2 = get_surface_distance(edges_gt, edges_pred, distance_metric=distance_metric) + if surface_distance_2.shape == (0,): + avg_surface_distance_2 = np.inf + else: + avg_surface_distance_2 = surface_distance_2.mean() + asd[b, c] = np.mean((avg_surface_distance, avg_surface_distance_2)) - avg_surface_distance_2 = surface_distance_2.mean() - return np.mean((avg_surface_distance, avg_surface_distance_2)) + return torch.from_numpy(asd) diff --git a/tests/test_hausdorff_distance.py b/tests/test_hausdorff_distance.py index 969e1103bf..8d37acb577 100644 --- a/tests/test_hausdorff_distance.py +++ b/tests/test_hausdorff_distance.py @@ -128,7 +128,6 @@ def test_value(self, input_data, expected_value): batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) result, _ = hd_metric(batch_seg_1, batch_seg_2) expected_value_curr = expected_value[ct] - batch_expected_value_curr = np.tile(np.asarray(expected_value_curr), [batch, n_class]) np.testing.assert_allclose(expected_value_curr, result, rtol=1e-7) ct += 1 diff --git a/tests/test_surface_distance.py b/tests/test_surface_distance.py index 8b16dc4f35..5548e7907f 100644 --- a/tests/test_surface_distance.py +++ b/tests/test_surface_distance.py @@ -13,27 +13,24 @@ from typing import Tuple import numpy as np +import torch from parameterized import parameterized -from monai.metrics import compute_average_surface_distance +from monai.metrics import SurfaceDistance def create_spherical_seg_3d( radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), - labelfield_value: int = 1, - background_value: int = 0, im_shape: Tuple[int, int, int] = (99, 99, 99), ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be - `labelfield_value` inside the sphere, and `background_value` elsewhere. + 1 inside the sphere, and 0 elsewhere. Args: radius: radius of sphere (in terms of number of voxels, can be partial) centre: location of sphere centre. - labelfield_value: index of labelfield. - background_value: index of background. im_shape: shape of image to create See also: @@ -46,30 +43,28 @@ def create_spherical_seg_3d( ] circle = (spx * spx + spy * spy + spz * spz) <= radius * radius - image[circle] = labelfield_value - image[~circle] = background_value + image[circle] = 1 + image[~circle] = 0 return image TEST_CASES = [ [ - [create_spherical_seg_3d(), create_spherical_seg_3d(), 1], + [create_spherical_seg_3d(), create_spherical_seg_3d()], [0, 0], ], [ [ create_spherical_seg_3d(radius=20, centre=(20, 20, 20)), create_spherical_seg_3d(radius=20, centre=(19, 19, 19)), - 1, "taxicab", ], [1.0380029806259314, 1.0380029806259314], ], [ [ - create_spherical_seg_3d(radius=33, labelfield_value=2, centre=(19, 33, 22)), - create_spherical_seg_3d(radius=33, labelfield_value=2, centre=(20, 33, 22)), - 2, + create_spherical_seg_3d(radius=33, centre=(19, 33, 22)), + create_spherical_seg_3d(radius=33, centre=(20, 33, 22)), ], [0.35021200688332677, 0.3483278807706289], ], @@ -77,7 +72,6 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=20, centre=(20, 33, 22)), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - 1, ], [13.975673696300824, 12.040033513150455], ], @@ -85,7 +79,6 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=20, centre=(20, 33, 22)), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - 1, "chessboard", ], [10.792254295459173, 9.605067064083457], @@ -94,7 +87,6 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=20, centre=(20, 33, 22)), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - 1, "taxicab", ], [17.32691760951026, 12.432687531048186], @@ -103,7 +95,6 @@ def create_spherical_seg_3d( [ np.zeros([99, 99, 99]), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - 1, ], [np.inf, np.inf], ], @@ -111,7 +102,6 @@ def create_spherical_seg_3d( [ np.zeros([99, 99, 99]), np.zeros([99, 99, 99]), - 1, ], [np.inf, np.inf], ], @@ -119,7 +109,6 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(), np.zeros([99, 99, 99]), - 1, "taxicab", ], [np.inf, np.inf], @@ -130,17 +119,22 @@ def create_spherical_seg_3d( class TestAllSurfaceMetrics(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_value(self, input_data, expected_value): - if len(input_data) == 4: - [seg_1, seg_2, label_idx, metric] = input_data + if len(input_data) == 3: + [seg_1, seg_2, metric] = input_data else: - [seg_1, seg_2, label_idx] = input_data + [seg_1, seg_2] = input_data metric = "euclidean" ct = 0 + seg_1 = torch.tensor(seg_1) + seg_2 = torch.tensor(seg_2) for symmetric in [True, False]: + sur_metric = SurfaceDistance(include_background=False, symmetric=symmetric, distance_metric=metric) + # shape of seg_1, seg_2 are: HWD, converts to BNHWD + batch, n_class = 2, 3 + batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) + batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) + result, _ = sur_metric(batch_seg_1, batch_seg_2) expected_value_curr = expected_value[ct] - result = compute_average_surface_distance( - seg_1, seg_2, label_idx, symmetric=symmetric, distance_metric=metric - ) np.testing.assert_allclose(expected_value_curr, result, rtol=1e-7) ct += 1 From 31a6dfea16cbfadf211e6cbf0fa6cf51b1a01630 Mon Sep 17 00:00:00 2001 From: yiheng-wang-nv Date: Fri, 11 Dec 2020 22:21:11 +0800 Subject: [PATCH 3/7] fix reduction type error Signed-off-by: yiheng-wang-nv --- monai/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 30c9a6fb3c..e3eb2eff05 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -65,7 +65,7 @@ def do_metric_reduction( not_nans = (~nans).float() f[nans] = 0 - t_zero = torch.zeros(1, device=f.device, dtype=torch.float64) + t_zero = torch.zeros(1, device=f.device, dtype=f.dtype) reduction = MetricReduction(reduction) if reduction == MetricReduction.MEAN: From 15ee69ebc771b6353ac9a4c0b6a1ae0941eeaddf Mon Sep 17 00:00:00 2001 From: yiheng-wang-nv Date: Mon, 14 Dec 2020 20:34:27 +0800 Subject: [PATCH 4/7] Add handlers for new metrics Signed-off-by: yiheng-wang-nv --- docs/source/handlers.rst | 12 +++ docs/source/metrics.rst | 4 +- monai/handlers/__init__.py | 2 + monai/handlers/hausdorff_distance.py | 99 ++++++++++++++++++++++++ monai/handlers/surface_distance.py | 95 +++++++++++++++++++++++ monai/metrics/__init__.py | 4 +- monai/metrics/hausdorff_distance.py | 16 ++-- monai/metrics/surface_distance.py | 14 ++-- monai/metrics/utils.py | 29 +++---- tests/min_tests.py | 2 + tests/test_handler_hausdorff_distance.py | 88 +++++++++++++++++++++ tests/test_handler_surface_distance.py | 88 +++++++++++++++++++++ tests/test_hausdorff_distance.py | 35 ++++++--- tests/test_surface_distance.py | 27 +++++-- 14 files changed, 468 insertions(+), 47 deletions(-) create mode 100644 monai/handlers/hausdorff_distance.py create mode 100644 monai/handlers/surface_distance.py create mode 100644 tests/test_handler_hausdorff_distance.py create mode 100644 tests/test_handler_surface_distance.py diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 475a44de64..2eb171c00f 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -40,6 +40,18 @@ Confusion Matrix metrics handler :members: +Hausdorff Distance metrics handler +---------------------------------- +.. autoclass:: HausdorffDistance + :members: + + +Surface Distance metrics handler +-------------------------------- +.. autoclass:: SurfaceDistance + :members: + + Metric logger ------------- .. autoclass:: MetricLogger diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index bdb02bd0b4..b280b5461d 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -28,14 +28,14 @@ Metrics -------------------- .. autofunction:: compute_hausdorff_distance -.. autoclass:: HausdorffDistance +.. autoclass:: HausdorffDistanceMetric :members: `Average Surface Distance` -------------------------- .. autofunction:: compute_average_surface_distance -.. autoclass:: SurfaceDistance +.. autoclass:: SurfaceDistanceMetric :members: `Occlusion sensitivity` diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 9becd5c5f6..37715cad52 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -13,6 +13,7 @@ from .checkpoint_saver import CheckpointSaver from .classification_saver import ClassificationSaver from .confusion_matrix import ConfusionMatrix +from .hausdorff_distance import HausdorffDistance from .lr_schedule_handler import LrScheduleHandler from .mean_dice import MeanDice from .metric_logger import MetricLogger @@ -20,6 +21,7 @@ from .segmentation_saver import SegmentationSaver from .smartcache_handler import SmartCacheHandler from .stats_handler import StatsHandler +from .surface_distance import SurfaceDistance from .tensorboard_handlers import TensorBoardImageHandler, TensorBoardStatsHandler from .utils import * from .validation_handler import ValidationHandler diff --git a/monai/handlers/hausdorff_distance.py b/monai/handlers/hausdorff_distance.py new file mode 100644 index 0000000000..56b8b341ff --- /dev/null +++ b/monai/handlers/hausdorff_distance.py @@ -0,0 +1,99 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Sequence + +import torch + +from monai.metrics import HausdorffDistanceMetric +from monai.utils import MetricReduction, exact_version, optional_import + +NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") +Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") +reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") +sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce") + + +class HausdorffDistance(Metric): # type: ignore[valid-type, misc] # due to optional_import + """ + Computes Hausdorff distance from full size Tensor and collects average over batch, class-channels, iterations. + """ + + def __init__( + self, + include_background: bool = False, + distance_metric: str = "euclidean", + percentile: Optional[float] = None, + directed: bool = False, + output_transform: Callable = lambda x: x, + device: Optional[torch.device] = None, + ) -> None: + """ + + Args: + include_background: whether to include distance computation on the first channel of the predicted output. + Defaults to ``False``. + distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] + the metric used to compute surface distance. Defaults to ``"euclidean"``. + percentile: an optional float number between 0 and 100. If specified, the corresponding + percentile of the Hausdorff Distance rather than the maximum result will be achieved. + Defaults to ``None``. + directed: whether to calculate directed Hausdorff distance. Defaults to ``False``. + output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. + device: device specification in case of distributed computation usage. + + """ + super().__init__(output_transform, device=device) + self.hd = HausdorffDistanceMetric( + include_background=include_background, + distance_metric=distance_metric, + percentile=percentile, + directed=directed, + reduction=MetricReduction.MEAN, + ) + self._sum = 0.0 + self._num_examples = 0 + + @reinit__is_reduced + def reset(self) -> None: + self._sum = 0.0 + self._num_examples = 0 + + @reinit__is_reduced + def update(self, output: Sequence[torch.Tensor]) -> None: + """ + Args: + output: sequence with contents [y_pred, y]. + + Raises: + ValueError: When ``output`` length is not 2. The metric can only support y_pred and y. + + """ + if len(output) != 2: + raise ValueError(f"output must have length 2, got {len(output)}.") + y_pred, y = output + score, not_nans = self.hd(y_pred, y) + not_nans = int(not_nans.item()) + + # add all items in current batch + self._sum += score.item() * not_nans + self._num_examples += not_nans + + @sync_all_reduce("_sum", "_num_examples") + def compute(self) -> float: + """ + Raises: + NotComputableError: When ``compute`` is called before an ``update`` occurs. + + """ + if self._num_examples == 0: + raise NotComputableError("HausdorffDistance must have at least one example before it can be computed.") + return self._sum / self._num_examples diff --git a/monai/handlers/surface_distance.py b/monai/handlers/surface_distance.py new file mode 100644 index 0000000000..b35089423c --- /dev/null +++ b/monai/handlers/surface_distance.py @@ -0,0 +1,95 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Sequence + +import torch + +from monai.metrics import SurfaceDistanceMetric +from monai.utils import MetricReduction, exact_version, optional_import + +NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") +Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") +reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") +sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce") + + +class SurfaceDistance(Metric): # type: ignore[valid-type, misc] # due to optional_import + """ + Computes surface distance from full size Tensor and collects average over batch, class-channels, iterations. + """ + + def __init__( + self, + include_background: bool = False, + symmetric: bool = False, + distance_metric: str = "euclidean", + output_transform: Callable = lambda x: x, + device: Optional[torch.device] = None, + ) -> None: + """ + + Args: + include_background: whether to include distance computation on the first channel of the predicted output. + Defaults to ``False``. + symmetric: whether to calculate the symmetric average surface distance between + `seg_pred` and `seg_gt`. Defaults to ``False``. + distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] + the metric used to compute surface distance. Defaults to ``"euclidean"``. + output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. + device: device specification in case of distributed computation usage. + + """ + super().__init__(output_transform, device=device) + self.hd = SurfaceDistanceMetric( + include_background=include_background, + symmetric=symmetric, + distance_metric=distance_metric, + reduction=MetricReduction.MEAN, + ) + self._sum = 0.0 + self._num_examples = 0 + + @reinit__is_reduced + def reset(self) -> None: + self._sum = 0.0 + self._num_examples = 0 + + @reinit__is_reduced + def update(self, output: Sequence[torch.Tensor]) -> None: + """ + Args: + output: sequence with contents [y_pred, y]. + + Raises: + ValueError: When ``output`` length is not 2. The metric can only support y_pred and y. + + """ + if len(output) != 2: + raise ValueError(f"output must have length 2, got {len(output)}.") + y_pred, y = output + score, not_nans = self.hd(y_pred, y) + not_nans = int(not_nans.item()) + + # add all items in current batch + self._sum += score.item() * not_nans + self._num_examples += not_nans + + @sync_all_reduce("_sum", "_num_examples") + def compute(self) -> float: + """ + Raises: + NotComputableError: When ``compute`` is called before an ``update`` occurs. + + """ + if self._num_examples == 0: + raise NotComputableError("SurfaceDistance must have at least one example before it can be computed.") + return self._sum / self._num_examples diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index 01aaec631d..345f451258 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -10,9 +10,9 @@ # limitations under the License. from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix -from .hausdorff_distance import HausdorffDistance, compute_hausdorff_distance +from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance from .meandice import DiceMetric, compute_meandice from .occlusion_sensitivity import compute_occlusion_sensitivity from .rocauc import compute_roc_auc -from .surface_distance import SurfaceDistance, compute_average_surface_distance +from .surface_distance import SurfaceDistanceMetric, compute_average_surface_distance from .utils import * diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 23bc462783..cb9fc25f57 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -18,7 +18,7 @@ from monai.metrics.utils import * -class HausdorffDistance: +class HausdorffDistanceMetric: """ Compute Hausdorff Distance between two tensors. It can support both multi-classes and multi-labels tasks. It supports both directed and non-directed Hausdorff distance calculation. In addition, specify the `percentile` @@ -28,8 +28,8 @@ class HausdorffDistance: You can use suitable transforms in ``monai.transforms.post`` first to achieve binarized values. Args: - include_background: whether to skip distance computation on the first channel of - the predicted output. Defaults to ``True``. + include_background: whether to include distance computation on the first channel of + the predicted output. Defaults to ``False``. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. percentile: an optional float number between 0 and 100. If specified, the corresponding @@ -44,7 +44,7 @@ class HausdorffDistance: def __init__( self, - include_background: bool = True, + include_background: bool = False, distance_metric: str = "euclidean", percentile: Optional[float] = None, directed: bool = False, @@ -95,7 +95,7 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): def compute_hausdorff_distance( y_pred: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor], - include_background: bool = True, + include_background: bool = False, distance_metric: str = "euclidean", percentile: Optional[float] = None, directed: bool = False, @@ -110,7 +110,7 @@ def compute_hausdorff_distance( y: ground truth to compute mean the distance. It must be one-hot format and first dim is batch. The values should be binarized. include_background: whether to skip distance computation on the first channel of - the predicted output. Defaults to ``True``. + the predicted output. Defaults to ``False``. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. percentile: an optional float number between 0 and 100. If specified, the corresponding @@ -156,9 +156,9 @@ def compute_percent_hausdorff_distance( surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) - # for input without foreground + # for both pred and gt do not have foreground if surface_distance.shape == (0,): - return np.inf + return np.nan if not percentile: return surface_distance.max() diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 100c1cb2b3..f6fd38d8b6 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -18,7 +18,7 @@ from monai.metrics.utils import * -class SurfaceDistance: +class SurfaceDistanceMetric: """ Compute Surface Distance between two tensors. It can support both multi-classes and multi-labels tasks. It supports both symmetric and asymmetric surface distance calculation. @@ -28,7 +28,7 @@ class SurfaceDistance: Args: include_background: whether to skip distance computation on the first channel of - the predicted output. Defaults to ``True``. + the predicted output. Defaults to ``False``. symmetric: whether to calculate the symmetric average surface distance between `seg_pred` and `seg_gt`. Defaults to ``False``. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] @@ -41,7 +41,7 @@ class SurfaceDistance: def __init__( self, - include_background: bool = True, + include_background: bool = False, symmetric: bool = False, distance_metric: str = "euclidean", reduction: Union[MetricReduction, str] = MetricReduction.MEAN, @@ -89,7 +89,7 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): def compute_average_surface_distance( y_pred: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor], - include_background: bool = True, + include_background: bool = False, symmetric: bool = False, distance_metric: str = "euclidean", ): @@ -106,7 +106,7 @@ def compute_average_surface_distance( y: ground truth to compute mean the distance. It must be one-hot format and first dim is batch. The values should be binarized. include_background: whether to skip distance computation on the first channel of - the predicted output. Defaults to ``True``. + the predicted output. Defaults to ``False``. symmetric: whether to calculate the symmetric average surface distance between `seg_pred` and `seg_gt`. Defaults to ``False``. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] @@ -132,7 +132,7 @@ def compute_average_surface_distance( (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c]) surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) if surface_distance.shape == (0,): - avg_surface_distance = np.inf + avg_surface_distance = np.nan else: avg_surface_distance = surface_distance.mean() if not symmetric: @@ -140,7 +140,7 @@ def compute_average_surface_distance( else: surface_distance_2 = get_surface_distance(edges_gt, edges_pred, distance_metric=distance_metric) if surface_distance_2.shape == (0,): - avg_surface_distance_2 = np.inf + avg_surface_distance_2 = np.nan else: avg_surface_distance_2 = surface_distance_2.mean() asd[b, c] = np.mean((avg_surface_distance, avg_surface_distance_2)) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index e3eb2eff05..d1fcc5e723 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -167,16 +167,16 @@ def get_mask_edges( def get_surface_distance( - edges_pred: np.ndarray, - edges_gt: np.ndarray, + seg_pred: np.ndarray, + seg_gt: np.ndarray, distance_metric: str = "euclidean", ) -> np.ndarray: """ This function is used to compute the surface distances from `seg_pred` to `seg_gt`. Args: - edges_pred: the edge of the predictions. - edges_gt: the edge of the ground truth. + seg_pred: the edge of the predictions. + seg_gt: the edge of the ground truth. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. @@ -185,17 +185,20 @@ def get_surface_distance( - ``"taxicab"``, uses `taxicab` metric in chamfer type of transform. """ - if not np.any(edges_pred): - return np.array([]) - - if not np.any(edges_gt): - dis = np.inf * np.ones_like(edges_gt) + if not np.any(seg_gt): + if not np.any(seg_pred): + return np.array([]) + else: + dis = np.inf * np.ones_like(seg_gt) + return dis[seg_pred] else: + if not np.any(seg_pred): + dis = np.inf * np.ones_like(seg_gt) + return dis[seg_gt] if distance_metric == "euclidean": - dis = distance_transform_edt(~edges_gt) + dis = distance_transform_edt(~seg_gt) elif distance_metric == "chessboard" or distance_metric == "taxicab": - dis = distance_transform_cdt(~edges_gt, metric=distance_metric) + dis = distance_transform_cdt(~seg_gt, metric=distance_metric) else: raise ValueError(f"distance_metric {distance_metric} is not implemented.") - surface_distance = dis[edges_pred] - return surface_distance + return dis[seg_pred] diff --git a/tests/min_tests.py b/tests/min_tests.py index e22d94bc57..ba8e138afd 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -40,12 +40,14 @@ def run_testsuit(): "test_handler_lr_scheduler", "test_handler_confusion_matrix", "test_handler_confusion_matrix_dist", + "test_handler_hausdorff_distance", "test_handler_mean_dice", "test_handler_rocauc", "test_handler_rocauc_dist", "test_handler_segmentation_saver", "test_handler_smartcache", "test_handler_stats", + "test_handler_surface_distance", "test_handler_tb_image", "test_handler_tb_stats", "test_handler_validation", diff --git a/tests/test_handler_hausdorff_distance.py b/tests/test_handler_hausdorff_distance.py new file mode 100644 index 0000000000..67322718b1 --- /dev/null +++ b/tests/test_handler_hausdorff_distance.py @@ -0,0 +1,88 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Tuple + +import numpy as np +import torch + +from monai.handlers import HausdorffDistance + + +def create_spherical_seg_3d( + radius: float = 20.0, + centre: Tuple[int, int, int] = (49, 49, 49), + im_shape: Tuple[int, int, int] = (99, 99, 99), +) -> np.ndarray: + """ + Return a 3D image with a sphere inside. Voxel values will be + 1 inside the sphere, and 0 elsewhere. + + Args: + radius: radius of sphere (in terms of number of voxels, can be partial) + centre: location of sphere centre. + im_shape: shape of image to create + + See also: + :py:meth:`~create_test_image_3d` + """ + # Create image + image = np.zeros(im_shape, dtype=np.int32) + spy, spx, spz = np.ogrid[ + -centre[0] : im_shape[0] - centre[0], -centre[1] : im_shape[1] - centre[1], -centre[2] : im_shape[2] - centre[2] + ] + circle = (spx * spx + spy * spy + spz * spz) <= radius * radius + + image[circle] = 1 + image[~circle] = 0 + return image + + +sampler_sphere = torch.Tensor(create_spherical_seg_3d(radius=20, centre=(20, 20, 20))).unsqueeze(0).unsqueeze(0) +sampler_sphere_gt = torch.Tensor(create_spherical_seg_3d(radius=20, centre=(10, 20, 20))).unsqueeze(0).unsqueeze(0) +sampler_sphere_zeros = torch.zeros_like(sampler_sphere) + +TEST_SAMPLE_1 = [sampler_sphere, sampler_sphere_gt] +TEST_SAMPLE_2 = [sampler_sphere_gt, sampler_sphere_gt] +TEST_SAMPLE_3 = [sampler_sphere_zeros, sampler_sphere_gt] +TEST_SAMPLE_4 = [sampler_sphere_zeros, sampler_sphere_zeros] + + +class TestHandlerHausdorffDistance(unittest.TestCase): + # TODO test multi node Hausdorff Distance + + def test_compute(self): + hd_metric = HausdorffDistance(include_background=True) + y_pred, y = TEST_SAMPLE_1 + hd_metric.update([y_pred, y]) + self.assertEqual(hd_metric.compute(), 10) + y_pred, y = TEST_SAMPLE_2 + hd_metric.update([y_pred, y]) + self.assertEqual(hd_metric.compute(), 5) + y_pred, y = TEST_SAMPLE_3 + hd_metric.update([y_pred, y]) + self.assertEqual(hd_metric.compute(), float("inf")) + self.assertEqual(hd_metric._num_examples, 3) + y_pred, y = TEST_SAMPLE_4 + hd_metric.update([y_pred, y]) + self.assertEqual(hd_metric._num_examples, 3) + + def test_shape_mismatch(self): + hd_metric = HausdorffDistance(include_background=True) + with self.assertRaises((AssertionError, ValueError)): + y_pred = TEST_SAMPLE_1[0] + y = torch.ones((1, 1, 10, 10, 10)) + hd_metric.update([y_pred, y]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_surface_distance.py b/tests/test_handler_surface_distance.py new file mode 100644 index 0000000000..02898769f6 --- /dev/null +++ b/tests/test_handler_surface_distance.py @@ -0,0 +1,88 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Tuple + +import numpy as np +import torch + +from monai.handlers import SurfaceDistance + + +def create_spherical_seg_3d( + radius: float = 20.0, + centre: Tuple[int, int, int] = (49, 49, 49), + im_shape: Tuple[int, int, int] = (99, 99, 99), +) -> np.ndarray: + """ + Return a 3D image with a sphere inside. Voxel values will be + 1 inside the sphere, and 0 elsewhere. + + Args: + radius: radius of sphere (in terms of number of voxels, can be partial) + centre: location of sphere centre. + im_shape: shape of image to create + + See also: + :py:meth:`~create_test_image_3d` + """ + # Create image + image = np.zeros(im_shape, dtype=np.int32) + spy, spx, spz = np.ogrid[ + -centre[0] : im_shape[0] - centre[0], -centre[1] : im_shape[1] - centre[1], -centre[2] : im_shape[2] - centre[2] + ] + circle = (spx * spx + spy * spy + spz * spz) <= radius * radius + + image[circle] = 1 + image[~circle] = 0 + return image + + +sampler_sphere = torch.Tensor(create_spherical_seg_3d(radius=20, centre=(20, 20, 20))).unsqueeze(0).unsqueeze(0) +sampler_sphere_gt = torch.Tensor(create_spherical_seg_3d(radius=20, centre=(10, 20, 20))).unsqueeze(0).unsqueeze(0) +sampler_sphere_zeros = torch.zeros_like(sampler_sphere) + +TEST_SAMPLE_1 = [sampler_sphere, sampler_sphere_gt] +TEST_SAMPLE_2 = [sampler_sphere_gt, sampler_sphere_gt] +TEST_SAMPLE_3 = [sampler_sphere_zeros, sampler_sphere_gt] +TEST_SAMPLE_4 = [sampler_sphere_zeros, sampler_sphere_zeros] + + +class TestHandlerSurfaceDistance(unittest.TestCase): + # TODO test multi node Surface Distance + + def test_compute(self): + sur_metric = SurfaceDistance(include_background=True) + y_pred, y = TEST_SAMPLE_1 + sur_metric.update([y_pred, y]) + self.assertAlmostEqual(sur_metric.compute(), 4.17133, places=4) + y_pred, y = TEST_SAMPLE_2 + sur_metric.update([y_pred, y]) + self.assertAlmostEqual(sur_metric.compute(), 2.08566, places=4) + y_pred, y = TEST_SAMPLE_3 + sur_metric.update([y_pred, y]) + self.assertAlmostEqual(sur_metric.compute(), float("inf")) + self.assertAlmostEqual(sur_metric._num_examples, 3) + y_pred, y = TEST_SAMPLE_4 + sur_metric.update([y_pred, y]) + self.assertAlmostEqual(sur_metric._num_examples, 3) + + def test_shape_mismatch(self): + sur_metric = SurfaceDistance(include_background=True) + with self.assertRaises((AssertionError, ValueError)): + y_pred = TEST_SAMPLE_1[0] + y = torch.ones((1, 1, 10, 10, 10)) + sur_metric.update([y_pred, y]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hausdorff_distance.py b/tests/test_hausdorff_distance.py index 8d37acb577..96c52cbb68 100644 --- a/tests/test_hausdorff_distance.py +++ b/tests/test_hausdorff_distance.py @@ -16,7 +16,7 @@ import torch from parameterized import parameterized -from monai.metrics import HausdorffDistance +from monai.metrics import HausdorffDistanceMetric def create_spherical_seg_3d( @@ -76,6 +76,7 @@ def create_spherical_seg_3d( ], [ [ + # pred does not have foreground (but gt has), the metric should be inf np.zeros([99, 99, 99]), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), ], @@ -83,13 +84,7 @@ def create_spherical_seg_3d( ], [ [ - np.zeros([99, 99, 99]), - np.zeros([99, 99, 99]), - ], - [np.inf, np.inf, np.inf, np.inf, np.inf, np.inf], - ], - [ - [ + # gt does not have foreground (but pred has), the metric should be inf create_spherical_seg_3d(), np.zeros([99, 99, 99]), ], @@ -105,6 +100,16 @@ def create_spherical_seg_3d( ], ] +TEST_CASES_NANS = [ + [ + [ + # both pred and gt do not have foreground, metric and not_nans should be 0 + np.zeros([99, 99, 99]), + np.zeros([99, 99, 99]), + ], + ], +] + class TestHausdorffDistance(unittest.TestCase): @parameterized.expand(TEST_CASES) @@ -119,7 +124,7 @@ def test_value(self, input_data, expected_value): seg_2 = torch.tensor(seg_2) for metric in ["euclidean", "chessboard", "taxicab"]: for directed in [True, False]: - hd_metric = HausdorffDistance( + hd_metric = HausdorffDistanceMetric( include_background=False, distance_metric=metric, percentile=percentile, directed=directed ) # shape of seg_1, seg_2 are: HWD, converts to BNHWD @@ -131,6 +136,18 @@ def test_value(self, input_data, expected_value): np.testing.assert_allclose(expected_value_curr, result, rtol=1e-7) ct += 1 + @parameterized.expand(TEST_CASES_NANS) + def test_nans(self, input_data): + [seg_1, seg_2] = input_data + seg_1 = torch.tensor(seg_1) + seg_2 = torch.tensor(seg_2) + hd_metric = HausdorffDistanceMetric(include_background=False) + batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0) + batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0) + result, not_nans = hd_metric(batch_seg_1, batch_seg_2) + np.testing.assert_allclose(0, result, rtol=1e-7) + np.testing.assert_allclose(0, not_nans, rtol=1e-7) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_surface_distance.py b/tests/test_surface_distance.py index 5548e7907f..dca3aaec12 100644 --- a/tests/test_surface_distance.py +++ b/tests/test_surface_distance.py @@ -16,7 +16,7 @@ import torch from parameterized import parameterized -from monai.metrics import SurfaceDistance +from monai.metrics import SurfaceDistanceMetric def create_spherical_seg_3d( @@ -100,18 +100,21 @@ def create_spherical_seg_3d( ], [ [ + create_spherical_seg_3d(), np.zeros([99, 99, 99]), - np.zeros([99, 99, 99]), + "taxicab", ], [np.inf, np.inf], ], +] + +TEST_CASES_NANS = [ [ [ - create_spherical_seg_3d(), + # both pred and gt do not have foreground, metric and not_nans should be 0 + np.zeros([99, 99, 99]), np.zeros([99, 99, 99]), - "taxicab", ], - [np.inf, np.inf], ], ] @@ -128,7 +131,7 @@ def test_value(self, input_data, expected_value): seg_1 = torch.tensor(seg_1) seg_2 = torch.tensor(seg_2) for symmetric in [True, False]: - sur_metric = SurfaceDistance(include_background=False, symmetric=symmetric, distance_metric=metric) + sur_metric = SurfaceDistanceMetric(include_background=False, symmetric=symmetric, distance_metric=metric) # shape of seg_1, seg_2 are: HWD, converts to BNHWD batch, n_class = 2, 3 batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) @@ -138,6 +141,18 @@ def test_value(self, input_data, expected_value): np.testing.assert_allclose(expected_value_curr, result, rtol=1e-7) ct += 1 + @parameterized.expand(TEST_CASES_NANS) + def test_nans(self, input_data): + [seg_1, seg_2] = input_data + seg_1 = torch.tensor(seg_1) + seg_2 = torch.tensor(seg_2) + sur_metric = SurfaceDistanceMetric(include_background=False) + batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0) + batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0) + result, not_nans = sur_metric(batch_seg_1, batch_seg_2) + np.testing.assert_allclose(0, result, rtol=1e-7) + np.testing.assert_allclose(0, not_nans, rtol=1e-7) + if __name__ == "__main__": unittest.main() From ac4ff55223ab9b90863e7828f6826cba2c35b5af Mon Sep 17 00:00:00 2001 From: yiheng-wang-nv Date: Tue, 15 Dec 2020 00:16:57 +0800 Subject: [PATCH 5/7] Simplify distance code Signed-off-by: yiheng-wang-nv --- monai/metrics/utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index d1fcc5e723..58461bdbb7 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -186,11 +186,8 @@ def get_surface_distance( """ if not np.any(seg_gt): - if not np.any(seg_pred): - return np.array([]) - else: - dis = np.inf * np.ones_like(seg_gt) - return dis[seg_pred] + dis = np.inf * np.ones_like(seg_gt) + return dis[seg_pred] else: if not np.any(seg_pred): dis = np.inf * np.ones_like(seg_gt) From 2dc510ae52f86432a5d296f2368e2fd70e70d27f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 14 Dec 2020 18:49:59 +0000 Subject: [PATCH 6/7] autofixes and typo fixes Signed-off-by: Wenqi Li --- docs/source/handlers.rst | 6 ++--- docs/source/metrics.rst | 6 ++--- monai/apps/datasets.py | 23 +++++++++-------- monai/apps/utils.py | 7 ++++++ monai/data/image_reader.py | 2 ++ monai/data/utils.py | 14 +++++------ monai/engines/evaluator.py | 2 ++ monai/engines/multi_gpu_supervised_trainer.py | 5 ++++ monai/engines/trainer.py | 2 ++ monai/handlers/utils.py | 2 ++ monai/inferers/inferer.py | 2 ++ monai/inferers/utils.py | 2 ++ monai/metrics/__init__.py | 2 +- monai/metrics/confusion_matrix.py | 6 ++--- monai/metrics/hausdorff_distance.py | 3 +++ monai/metrics/meandice.py | 1 + monai/metrics/rocauc.py | 25 +++++++++---------- monai/metrics/surface_distance.py | 1 + monai/metrics/utils.py | 12 ++++----- monai/networks/utils.py | 23 +++++++++++------ monai/optimizers/utils.py | 2 ++ monai/transforms/compose.py | 2 ++ monai/transforms/utils.py | 12 +++------ monai/utils/module.py | 5 +--- monai/visualize/img2tensorboard.py | 10 ++++---- 25 files changed, 105 insertions(+), 72 deletions(-) diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 2eb171c00f..2962f725d8 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -34,19 +34,19 @@ ROC AUC metrics handler :members: -Confusion Matrix metrics handler +Confusion matrix metrics handler -------------------------------- .. autoclass:: ConfusionMatrix :members: -Hausdorff Distance metrics handler +Hausdorff distance metrics handler ---------------------------------- .. autoclass:: HausdorffDistance :members: -Surface Distance metrics handler +Surface distance metrics handler -------------------------------- .. autoclass:: SurfaceDistance :members: diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index b280b5461d..0bcfbd4240 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -17,21 +17,21 @@ Metrics -------------------------- .. autofunction:: compute_roc_auc -`Confusion Matrix` +`Confusion matrix` ------------------ .. autofunction:: get_confusion_matrix .. autoclass:: ConfusionMatrixMetric :members: -`Hausdorff Distance` +`Hausdorff distance` -------------------- .. autofunction:: compute_hausdorff_distance .. autoclass:: HausdorffDistanceMetric :members: -`Average Surface Distance` +`Average surface distance` -------------------------- .. autofunction:: compute_average_surface_distance diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index 99643fd4db..6272b50b4c 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -26,6 +26,8 @@ from monai.transforms import LoadImaged, Randomizable from monai.utils import ensure_tuple +__all__ = ["MedNISTDataset", "DecathlonDataset", "CrossValidation"] + class MedNISTDataset(Randomizable, CacheDataset): """ @@ -121,7 +123,7 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: image_class.extend([i] * num_each[i]) num_total = len(image_class) - data = list() + data = [] for i in range(num_total): self.randomize() @@ -302,18 +304,17 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: def _split_datalist(self, datalist: List[Dict]) -> List[Dict]: if self.section == "test": return datalist - else: - length = len(datalist) - indices = np.arange(length) - self.randomize(indices) + length = len(datalist) + indices = np.arange(length) + self.randomize(indices) - val_length = int(length * self.val_frac) - if self.section == "training": - self.indices = indices[val_length:] - else: - self.indices = indices[:val_length] + val_length = int(length * self.val_frac) + if self.section == "training": + self.indices = indices[val_length:] + else: + self.indices = indices[:val_length] - return [datalist[i] for i in self.indices] + return [datalist[i] for i in self.indices] class CrossValidation: diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 8461bf4a29..e48dfb63f2 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -31,6 +31,13 @@ else: tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm") +__all__ = [ + "check_hash", + "download_url", + "extractall", + "download_and_extract", +] + def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5") -> bool: """ diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 5b0450ab8a..32d03115ed 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -37,6 +37,8 @@ Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image") PILImage, has_pil = optional_import("PIL.Image") +__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader"] + class ImageReader(ABC): """Abstract class to define interface APIs to load image files. diff --git a/monai/data/utils.py b/monai/data/utils.py index b63ff6e66b..c5fcbf3c86 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -134,8 +134,7 @@ def dense_patch_slices( dim_starts.append(start_idx) starts.append(dim_starts) out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T - slices = [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out] - return slices + return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out] def iter_patch( @@ -550,7 +549,7 @@ def is_supported_format(filename: Union[Sequence[str], str], suffixes: Sequence[ filenames: Sequence[str] = ensure_tuple(filename) for name in filenames: tokens: Sequence[str] = PurePath(name).suffixes - if len(tokens) == 0 or not any(("." + s.lower()) in "".join(tokens) for s in suffixes): + if len(tokens) == 0 or all("." + s.lower() not in "".join(tokens) for s in suffixes): return False return True @@ -598,7 +597,7 @@ def partition_dataset( """ data_len = len(data) - datasets = list() + datasets = [] indices = list(range(data_len)) if shuffle: @@ -682,7 +681,7 @@ def partition_dataset_classes( """ if not classes or len(classes) != len(data): raise ValueError(f"length of classes {classes} must match the dataset length {len(data)}.") - datasets = list() + datasets = [] class_indices = defaultdict(list) for i, c in enumerate(classes): class_indices[c].append(i) @@ -698,7 +697,7 @@ def partition_dataset_classes( drop_last=drop_last, even_divisible=even_divisible, ) - if len(class_partition_indices) == 0: + if not class_partition_indices: class_partition_indices = per_class_partition_indices else: for part, data_indices in zip(class_partition_indices, per_class_partition_indices): @@ -735,8 +734,7 @@ def select_cross_validation_folds(partitions: Sequence[Iterable], folds: Union[S >>> select_cross_validation_folds(partitions, [-1, 2]) [9, 10, 5, 6] """ - data_list = [data_item for fold_id in ensure_tuple(folds) for data_item in partitions[fold_id]] - return data_list + return [data_item for fold_id in ensure_tuple(folds) for data_item in partitions[fold_id]] class DistributedSampler(_TorchDistributedSampler): diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 930747edfb..306be5f2db 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -28,6 +28,8 @@ Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") +__all__ = ["Evaluator", "SupervisedEvaluator", "EnsembleEvaluator"] + class Evaluator(Workflow): """ diff --git a/monai/engines/multi_gpu_supervised_trainer.py b/monai/engines/multi_gpu_supervised_trainer.py index 7110a09c0f..33268308e5 100644 --- a/monai/engines/multi_gpu_supervised_trainer.py +++ b/monai/engines/multi_gpu_supervised_trainer.py @@ -29,6 +29,11 @@ Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") +__all__ = [ + "create_multigpu_supervised_trainer", + "create_multigpu_supervised_evaluator", +] + def _default_transform(_x: torch.Tensor, _y: torch.Tensor, _y_pred: torch.Tensor, loss: torch.Tensor) -> float: return loss.item() diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index c625d1b669..64b38e2646 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -29,6 +29,8 @@ Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") +__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer"] + class Trainer(Workflow): """ diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index e401e18b0c..e96521f47e 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -21,6 +21,8 @@ else: Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") +__all__ = ["stopping_fn_from_metric", "stopping_fn_from_loss", "all_gather"] + def stopping_fn_from_metric(metric_name: str) -> Callable[[Engine], Any]: """ diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index eea56d3d45..36cc3de478 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -17,6 +17,8 @@ from monai.inferers.utils import sliding_window_inference from monai.utils import BlendMode, PytorchPadMode +__all__ = ["Inferer", "SimpleInferer", "SlidingWindowInferer"] + class Inferer(ABC): """ diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 48bd334061..c7db520cb2 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -17,6 +17,8 @@ from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple +__all__ = ["sliding_window_inference"] + def sliding_window_inference( inputs: torch.Tensor, diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index 345f451258..a0d626f45b 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -10,7 +10,7 @@ # limitations under the License. from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix -from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance +from .hausdorff_distance import * from .meandice import DiceMetric, compute_meandice from .occlusion_sensitivity import compute_occlusion_sensitivity from .rocauc import compute_roc_auc diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index 8d2304cea3..abe077f85e 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -15,6 +15,7 @@ import torch from monai.metrics.utils import * +from monai.utils import MetricReduction class ConfusionMatrixMetric: @@ -263,10 +264,9 @@ def compute_confusion_matrix_metric(metric_name: str, confusion_matrix: torch.Te raise NotImplementedError("the metric is not implemented.") if isinstance(denominator, torch.Tensor): - result = torch.where(denominator != 0, numerator / denominator, nan_tensor) + return torch.where(denominator != 0, numerator / denominator, nan_tensor) else: - result = numerator / denominator - return result + return numerator / denominator def check_confusion_matrix_metric_name(metric_name: str): diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index cb9fc25f57..c649cd3a04 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -16,6 +16,9 @@ import torch from monai.metrics.utils import * +from monai.utils import MetricReduction + +__all__ = ["HausdorffDistanceMetric", "compute_hausdorff_distance", "compute_percent_hausdorff_distance"] class HausdorffDistanceMetric: diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index b530c425ee..53716909fe 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -15,6 +15,7 @@ import torch from monai.metrics.utils import * +from monai.utils import MetricReduction class DiceMetric: diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index d5c1cf20d2..7b26560d57 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -132,16 +132,15 @@ def compute_roc_auc( average = Average(average) if average == Average.MICRO: return _calculate(y.flatten(), y_pred.flatten()) - else: - y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1) - auc_values = [_calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)] - if average == Average.NONE: - return auc_values - if average == Average.MACRO: - return np.mean(auc_values) - if average == Average.WEIGHTED: - weights = [sum(y_) for y_ in y] - return np.average(auc_values, weights=weights) - raise ValueError( - f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].' - ) + y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1) + auc_values = [_calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)] + if average == Average.NONE: + return auc_values + if average == Average.MACRO: + return np.mean(auc_values) + if average == Average.WEIGHTED: + weights = [sum(y_) for y_ in y] + return np.average(auc_values, weights=weights) + raise ValueError( + f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].' + ) diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index f6fd38d8b6..8dcbe4d9f6 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -16,6 +16,7 @@ import torch from monai.metrics.utils import * +from monai.utils import MetricReduction class SurfaceDistanceMetric: diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 58461bdbb7..ffe6093621 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -22,6 +22,8 @@ distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt") distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") +__all__ = ["ignore_background", "do_metric_reduction", "get_mask_edges", "get_surface_distance"] + def ignore_background( y_pred: torch.Tensor, @@ -91,9 +93,7 @@ def do_metric_reduction( elif reduction == MetricReduction.SUM_CHANNEL: not_nans = not_nans.sum(dim=1) f = f.sum(dim=1) # the channel sum - elif reduction == MetricReduction.NONE: - pass - else: + elif reduction != MetricReduction.NONE: raise ValueError( f"Unsupported reduction: {reduction}, available options are " '["mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel" "none"].' @@ -187,15 +187,15 @@ def get_surface_distance( if not np.any(seg_gt): dis = np.inf * np.ones_like(seg_gt) - return dis[seg_pred] else: if not np.any(seg_pred): dis = np.inf * np.ones_like(seg_gt) return dis[seg_gt] if distance_metric == "euclidean": dis = distance_transform_edt(~seg_gt) - elif distance_metric == "chessboard" or distance_metric == "taxicab": + elif distance_metric in ["chessboard", "taxicab"]: dis = distance_transform_cdt(~seg_gt, metric=distance_metric) else: raise ValueError(f"distance_metric {distance_metric} is not implemented.") - return dis[seg_pred] + + return dis[seg_pred] diff --git a/monai/networks/utils.py b/monai/networks/utils.py index a46e8e66d7..1bcccd084c 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -20,6 +20,17 @@ from monai.utils import ensure_tuple_size +__all__ = [ + "one_hot", + "slice_channels", + "predict_segmentation", + "normalize_transform", + "to_norm_affine", + "normal_init", + "icnr_init", + "pixelshuffle", +] + def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor: """ @@ -72,11 +83,10 @@ def predict_segmentation( """ if not mutually_exclusive: return (cast(torch.Tensor, logits >= threshold)).int() - else: - if logits.shape[1] == 1: - warnings.warn("single channel prediction, `mutually_exclusive=True` ignored, use threshold instead.") - return (cast(torch.Tensor, logits >= threshold)).int() - return logits.argmax(1, keepdim=True) + if logits.shape[1] == 1: + warnings.warn("single channel prediction, `mutually_exclusive=True` ignored, use threshold instead.") + return (cast(torch.Tensor, logits >= threshold)).int() + return logits.argmax(1, keepdim=True) def normalize_transform( @@ -145,8 +155,7 @@ def to_norm_affine( src_xform = normalize_transform(src_size, affine.device, affine.dtype, align_corners) dst_xform = normalize_transform(dst_size, affine.device, affine.dtype, align_corners) - new_affine = src_xform @ affine @ torch.inverse(dst_xform) - return new_affine + return src_xform @ affine @ torch.inverse(dst_xform) def normal_init( diff --git a/monai/optimizers/utils.py b/monai/optimizers/utils.py index 57c7528ba4..4cafa45749 100644 --- a/monai/optimizers/utils.py +++ b/monai/optimizers/utils.py @@ -15,6 +15,8 @@ from monai.utils import ensure_tuple, ensure_tuple_rep +__all__ = ["generate_param_groups"] + def generate_param_groups( network: torch.nn.Module, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index d5cae18a53..20e72f1df0 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -22,6 +22,8 @@ from monai.transforms.utils import apply_transform from monai.utils import MAX_SEED, ensure_tuple, get_seed +__all__ = ["Transform", "Randomizable", "Compose", "MapTransform"] + class Transform(ABC): """ diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 4a4b79cdf5..1523ce1e22 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -60,10 +60,7 @@ def zero_margins(img: np.ndarray, margin: int) -> bool: if np.any(img[:, :, :margin]) or np.any(img[:, :, -margin:]): return False - if np.any(img[:, :margin, :]) or np.any(img[:, -margin:, :]): - return False - - return True + return not np.any(img[:, :margin, :]) and not np.any(img[:, -margin:, :]) def rescale_array( @@ -262,8 +259,7 @@ def weighted_patch_samples( idx = v.searchsorted(r_state.random(n_samples) * v[-1], side="right") # compensate 'valid' mode diff = np.minimum(win_size, img_size) // 2 - centers = [np.unravel_index(i, v_size) + diff for i in np.asarray(idx, dtype=np.int)] - return centers + return [np.unravel_index(i, v_size) + diff for i in np.asarray(idx, dtype=np.int)] def generate_pos_neg_label_crop_centers( @@ -427,7 +423,7 @@ def create_rotate(spatial_dims: int, radians: Union[Sequence[float], float]) -> return np.array([[cos_, -sin_, 0.0], [sin_, cos_, 0.0], [0.0, 0.0, 1.0]]) raise ValueError("radians must be non empty.") - if spatial_dims == 3: + elif spatial_dims == 3: affine = None if len(radians) >= 1: sin_, cos_ = np.sin(radians[0]), np.cos(radians[0]) @@ -466,7 +462,7 @@ def create_shear(spatial_dims: int, coefs: Union[Sequence[float], float]) -> np. if spatial_dims == 2: coefs = ensure_tuple_size(coefs, dim=2, pad_val=0.0) return np.array([[1, coefs[0], 0.0], [coefs[1], 1.0, 0.0], [0.0, 0.0, 1.0]]) - if spatial_dims == 3: + elif spatial_dims == 3: coefs = ensure_tuple_size(coefs, dim=6, pad_val=0.0) return np.array( [ diff --git a/monai/utils/module.py b/monai/utils/module.py index 4bc9a6d63b..dfd5fb7d7b 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -244,10 +244,7 @@ def has_option(obj, keywords: Union[str, Sequence[str]]) -> bool: if not callable(obj): return False sig = inspect.signature(obj) - for key in ensure_tuple(keywords): - if key not in sig.parameters: - return False - return True + return all(key in sig.parameters for key in ensure_tuple(keywords)) def get_package_version(dep_name, default="NOT INSTALLED or UNKNOWN VERSION."): diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index 9b22cbdba1..c11bfcfc99 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -28,6 +28,9 @@ SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter") +__all__ = ["make_animated_gif_summary", "add_animated_gif", "add_animated_gif_no_channels", "plot_2d_or_3d_image"] + + def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale_factor: float = 1.0) -> Summary: """Function to actually create the animated gif. @@ -76,10 +79,7 @@ def make_animated_gif_summary( if the image data is between 0 and 1, using 255 for this value will scale it to displayable range """ - if max_out == 1: - suffix = "/image" - else: - suffix = "/image/{}" + suffix = "/image" if max_out == 1 else "/image/{}" if other_indices is None: other_indices = {} axis_order = [0] + list(animation_axes) + list(image_axes) @@ -194,9 +194,9 @@ def plot_2d_or_3d_image( dataformats = "CHW" writer.add_image(f"{tag}_{dataformats}", d, step, dataformats=dataformats) return + dataformats = "HW" for j, d2 in enumerate(d[:max_channels]): d2 = rescale_array(d2, 0, 1) - dataformats = "HW" writer.add_image(f"{tag}_{dataformats}_{j}", d2, step, dataformats=dataformats) return From 876365fb388e660a98ff5063790031daa4721e5e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 14 Dec 2020 19:25:40 +0000 Subject: [PATCH 7/7] fixes lgtm typos Signed-off-by: Wenqi Li --- monai/config/deviceconfig.py | 4 ++-- monai/metrics/confusion_matrix.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index 355069f941..c70d495555 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -102,7 +102,7 @@ def set_visible_devices(*dev_inds): def _dict_append(in_dict, key, fn): try: - in_dict[key] = fn() + in_dict[key] = fn() if callable(fn) else fn except BaseException: in_dict[key] = "UNKNOWN for given OS" @@ -197,7 +197,7 @@ def get_gpu_info() -> OrderedDict: _dict_append(output, "Current device", lambda: torch.cuda.current_device()) _dict_append(output, "Library compiled for CUDA architectures", lambda: torch.cuda.get_arch_list()) for gpu in range(num_gpus): - _dict_append(output, "Info for GPU", lambda: gpu) + _dict_append(output, "Info for GPU", gpu) gpu_info = torch.cuda.get_device_properties(gpu) _dict_append(output, "\tName", lambda: gpu_info.name) _dict_append(output, "\tIs integrated", lambda: bool(gpu_info.is_integrated)) diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index abe077f85e..916a07439f 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -257,7 +257,6 @@ def compute_confusion_matrix_metric(metric_name: str, confusion_matrix: torch.Te elif metric == "mk": ppv = torch.where((tp + fp) > 0, tp / (tp + fp), nan_tensor) npv = torch.where((tn + fn) > 0, tn / (tn + fn), nan_tensor) - npv = tn / (tn + fn) numerator = ppv + npv - 1.0 denominator = 1.0 else: