diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index 3ab10652e2..8daad86dd2 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -20,9 +20,8 @@
 import numpy as np
 import torch
 
-from monai.networks.layers import GaussianFilter
 from monai.transforms.compose import Randomizable, Transform
-from monai.transforms.utils import get_extreme_points, map_binary_to_indices
+from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices
 from monai.utils import ensure_tuple
 
 # Generic type which can represent either a numpy.ndarray or a torch.Tensor
@@ -593,20 +592,8 @@ def __call__(
         # Generate extreme points
         self.randomize(label[0, :])
 
-        # points to image
-        points_image = torch.zeros(label.shape[1:], dtype=torch.float)
-        for p in self._points:
-            points_image[p] = 1.0
-
-        # add channel and add batch
-        points_image = points_image.unsqueeze(0).unsqueeze(0)
-        gaussian_filter = GaussianFilter(img.ndim - 1, sigma=sigma)
-        points_image = gaussian_filter(points_image).squeeze(0).detach().numpy()
-
-        # rescale the points image to [rescale_min, rescale_max]
-        min_intensity = np.min(points_image)
-        max_intensity = np.max(points_image)
-        points_image = (points_image - min_intensity) / (max_intensity - min_intensity)
-        points_image = points_image * (rescale_max - rescale_min) + rescale_min
+        points_image = extreme_points_to_image(
+            points=self._points, label=label, sigma=sigma, rescale_min=rescale_min, rescale_max=rescale_max
+        )
 
         return np.concatenate([img, points_image], axis=0)
diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index fad98fdb62..28d7452e77 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -17,16 +17,16 @@
 
 import copy
 import logging
-from typing import Callable, Dict, Hashable, Mapping, Optional, Sequence, Union
+from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
 
 from monai.config import KeysCollection
-from monai.transforms.compose import MapTransform
+from monai.transforms import extreme_points_to_image, get_extreme_points
+from monai.transforms.compose import MapTransform, Randomizable
 from monai.transforms.utility.array import (
     AddChannel,
-    AddExtremePointsChannel,
     AsChannelFirst,
     AsChannelLast,
     CastToType,
@@ -661,7 +661,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
         return d
 
 
-class AddExtremePointsChanneld(MapTransform):
+class AddExtremePointsChanneld(Randomizable, MapTransform):
     """
     Dictionary-based wrapper of :py:class:`monai.transforms.AddExtremePointsChannel`.
 
@@ -690,22 +690,37 @@ def __init__(
         rescale_max: float = 1.0,
     ):
         super().__init__(keys)
+        self.background = background
+        self.pert = pert
+        self.points: List[Tuple[int, ...]] = []
         self.label_key = label_key
-        self.add_extreme_points_channel = AddExtremePointsChannel(background=background, pert=pert)
         self.sigma = sigma
         self.rescale_min = rescale_min
         self.rescale_max = rescale_max
 
+    def randomize(self, label: np.ndarray) -> None:
+        self.points = get_extreme_points(label, rand_state=self.R, background=self.background, pert=self.pert)
+
     def __call__(self, data):
         d = dict(data)
         label = d[self.label_key]
+        if label.shape[0] != 1:
+            raise ValueError("Only supports single channel labels!")
+
+        # Generate extreme points
+        self.randomize(label[0, :])
 
         for key in data.keys():
             if key in self.keys:
                 img = d[key]
-                d[key] = self.add_extreme_points_channel(
-                    img, label=label, sigma=self.sigma, rescale_min=self.rescale_min, rescale_max=self.rescale_max
+                points_image = extreme_points_to_image(
+                    points=self.points,
+                    label=label,
+                    sigma=self.sigma,
+                    rescale_min=self.rescale_min,
+                    rescale_max=self.rescale_max,
                 )
+                d[key] = np.concatenate([img, points_image], axis=0)
         return d
 
 
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index 1523ce1e22..3b552f543c 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -18,6 +18,7 @@
 import torch
 
 from monai.config import IndexSelection
+from monai.networks.layers import GaussianFilter
 from monai.utils import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, min_version, optional_import
 
 measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
@@ -620,3 +621,44 @@ def _get_point(val, dim):
         points.append(tuple(_get_point(np.max(indices[i][...]), i)))
 
     return points
+
+
+def extreme_points_to_image(
+    points: List[Tuple[int, ...]],
+    label: np.ndarray,
+    sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0,
+    rescale_min: float = -1.0,
+    rescale_max: float = 1.0,
+):
+    """
+    Please refer to :py:class:`monai.transforms.AddExtremePointsChannel` for the usage.
+
+    Applies a gaussian filter to the extreme points image. Then the pixel values in points image are rescaled
+    to range [rescale_min, rescale_max].
+
+    Args:
+        points: Extreme points of the object/organ.
+        label: label image to get extreme points from. Shape must be
+            (1, spatial_dim1, [, spatial_dim2, ...]). Doesn't support one-hot labels.
+        sigma: if a list of values, must match the count of spatial dimensions of input data,
+            and apply every value in the list to 1 spatial dimension. if only 1 value provided,
+            use it for all spatial dimensions.
+        rescale_min: minimum value of output data.
+        rescale_max: maximum value of output data.
+    """
+    # points to image
+    points_image = torch.zeros(label.shape[1:], dtype=torch.float)
+    for p in points:
+        points_image[p] = 1.0
+
+    # add channel and add batch
+    points_image = points_image.unsqueeze(0).unsqueeze(0)
+    gaussian_filter = GaussianFilter(label.ndim - 1, sigma=sigma)
+    points_image = gaussian_filter(points_image).squeeze(0).detach().numpy()
+
+    # rescale the points image to [rescale_min, rescale_max]
+    min_intensity = np.min(points_image)
+    max_intensity = np.max(points_image)
+    points_image = (points_image - min_intensity) / (max_intensity - min_intensity)
+    points_image = points_image * (rescale_max - rescale_min) + rescale_min
+    return points_image
diff --git a/tests/test_add_extreme_points_channel.py b/tests/test_add_extreme_points_channel.py
index 01277bfa19..f4f3fa6d02 100644
--- a/tests/test_add_extreme_points_channel.py
+++ b/tests/test_add_extreme_points_channel.py
@@ -57,7 +57,7 @@
 class TestAddExtremePointsChannel(unittest.TestCase):
     @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
-    def test_type_shape(self, input_data, expected):
+    def test_correct_results(self, input_data, expected):
         add_extreme_points_channel = AddExtremePointsChannel()
         result = add_extreme_points_channel(**input_data)
         np.testing.assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4)
 
diff --git a/tests/test_add_extreme_points_channeld.py b/tests/test_add_extreme_points_channeld.py
index 6cf4e7be87..4fee176b20 100644
--- a/tests/test_add_extreme_points_channeld.py
+++ b/tests/test_add_extreme_points_channeld.py
@@ -45,7 +45,7 @@
 class TestAddExtremePointsChanneld(unittest.TestCase):
     @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
-    def test_type_shape(self, input_data, expected):
+    def test_correct_results(self, input_data, expected):
         add_extreme_points_channel = AddExtremePointsChanneld(
             keys="img", label_key="label", sigma=1.0, rescale_min=0.0, rescale_max=1.0
         )