diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index f417fabffa..3ab10652e2 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -20,8 +20,9 @@
 import numpy as np
 import torch
 
-from monai.transforms.compose import Transform
-from monai.transforms.utils import map_binary_to_indices
+from monai.networks.layers import GaussianFilter
+from monai.transforms.compose import Randomizable, Transform
+from monai.transforms.utils import get_extreme_points, map_binary_to_indices
 from monai.utils import ensure_tuple
 
 # Generic type which can represent either a numpy.ndarray or a torch.Tensor
@@ -535,3 +536,77 @@ def __call__(
             bg_indices = np.stack([np.unravel_index(i, output_shape) for i in bg_indices])
 
         return fg_indices, bg_indices
+
+
+class AddExtremePointsChannel(Transform, Randomizable):
+    """
+    Add extreme points of the label to the image as a new channel. This transform generates extreme
+    points from the label, applies a Gaussian filter, and rescales the pixel values of the points
+    image to the range [rescale_min, rescale_max] before adding it to the input image as a new
+    channel. The algorithm is described in Roth et al., Going to Extremes: Weakly Supervised
+    Medical Image Segmentation, https://arxiv.org/abs/2009.11988.
+
+    This transform only supports single channel labels (1, spatial_dim1, [spatial_dim2, ...]). The
+    ``background`` index is ignored when calculating the extreme points.
+
+    Args:
+        background: Class index of the background label, defaults to 0.
+        pert: Random perturbation amount to add to the points, defaults to 0.0.
+
+    Raises:
+        ValueError: When no label image is provided.
+        ValueError: When the label image is not single channel.
+    """
+
+    def __init__(self, background: int = 0, pert: float = 0.0) -> None:
+        self._background = background
+        self._pert = pert
+        self._points: List[Tuple[int, ...]] = []
+
+    def randomize(self, label: np.ndarray) -> None:
+        self._points = get_extreme_points(label, rand_state=self.R, background=self._background, pert=self._pert)
+
+    def __call__(
+        self,
+        img: np.ndarray,
+        label: Optional[np.ndarray] = None,
+        sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 3.0,
+        rescale_min: float = -1.0,
+        rescale_max: float = 1.0,
+    ) -> np.ndarray:
+        """
+        Args:
+            img: the image to add the new channel to.
+            label: label image to get the extreme points from. Shape must be
+                (1, spatial_dim1, [spatial_dim2, ...]). One-hot labels are not supported.
+            sigma: if a sequence of values, its length must match the number of spatial dimensions
+                of the input data and each value is applied to one spatial dimension. If a single
+                value is provided, it is used for all spatial dimensions.
+            rescale_min: minimum value of output data.
+            rescale_max: maximum value of output data.
+        """
+        if label is None:
+            raise ValueError("This transform requires a label array!")
+        if label.shape[0] != 1:
+            raise ValueError("Only supports single channel labels!")
+
+        # generate extreme points from the label
+        self.randomize(label[0, :])
+
+        # draw the extreme points into a points image
+        points_image = torch.zeros(label.shape[1:], dtype=torch.float)
+        for p in self._points:
+            points_image[p] = 1.0
+
+        # add batch and channel dimensions before Gaussian smoothing
+        points_image = points_image.unsqueeze(0).unsqueeze(0)
+        gaussian_filter = GaussianFilter(img.ndim - 1, sigma=sigma)
+        points_image = gaussian_filter(points_image).squeeze(0).detach().numpy()
+
+        # rescale the points image to [rescale_min, rescale_max]
+        min_intensity = np.min(points_image)
+        max_intensity = np.max(points_image)
+        points_image = (points_image - min_intensity) / (max_intensity - min_intensity)
+        points_image = points_image * (rescale_max - rescale_min) + rescale_min
+
+        return np.concatenate([img, points_image], axis=0)
diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index 5ae33626a5..fad98fdb62 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -26,6 +26,7 @@
 from monai.transforms.compose import MapTransform
 from monai.transforms.utility.array import (
     AddChannel,
+    AddExtremePointsChannel,
     AsChannelFirst,
     AsChannelLast,
     CastToType,
@@ -660,6 +661,54 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
         return d
 
+
+class AddExtremePointsChanneld(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.AddExtremePointsChannel`.
+
+    Args:
+        keys: keys of the corresponding items to be transformed.
+            See also: :py:class:`monai.transforms.compose.MapTransform`
+        label_key: key to the label source to get the extreme points from.
+        background: Class index of the background label, defaults to 0.
+        pert: Random perturbation amount to add to the points, defaults to 0.0.
+        sigma: if a sequence of values, its length must match the number of spatial dimensions
+            of the input data and each value is applied to one spatial dimension. If a single
+            value is provided, it is used for all spatial dimensions.
+        rescale_min: minimum value of output data.
+        rescale_max: maximum value of output data.
+
+    """
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        label_key: str,
+        background: int = 0,
+        pert: float = 0.0,
+        sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 3.0,
+        rescale_min: float = -1.0,
+        rescale_max: float = 1.0,
+    ):
+        super().__init__(keys)
+        self.label_key = label_key
+        self.add_extreme_points_channel = AddExtremePointsChannel(background=background, pert=pert)
+        self.sigma = sigma
+        self.rescale_min = rescale_min
+        self.rescale_max = rescale_max
+
+    def __call__(self, data):
+        d = dict(data)
+        label = d[self.label_key]
+
+        for key in data.keys():
+            if key in self.keys:
+                img = d[key]
+                d[key] = self.add_extreme_points_channel(
+                    img, label=label, sigma=self.sigma, rescale_min=self.rescale_min, rescale_max=self.rescale_max
+                )
+        return d
+
 
 IdentityD = IdentityDict = Identityd
 AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd
 AsChannelLastD = AsChannelLastDict = AsChannelLastd
@@ -680,3 +729,4 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
 ConvertToMultiChannelBasedOnBratsClassesD = (
     ConvertToMultiChannelBasedOnBratsClassesDict
 ) = ConvertToMultiChannelBasedOnBratsClassesd
+AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index 44205e4e09..4a4b79cdf5 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -571,3 +571,56 @@ def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Option
         if item.max() != 0:
             largest_cc[i, ...] = item == (np.argmax(np.bincount(item.flat)[1:]) + 1)
     return torch.as_tensor(largest_cc, device=img.device)
+
+
+def get_extreme_points(
+    img: np.ndarray, rand_state: np.random.RandomState = np.random, background: int = 0, pert: float = 0.0
+) -> List[Tuple[int, ...]]:
+    """
+    Generate extreme points from an image. These are used to generate the initial segmentation
+    for annotation models. An optional perturbation can be passed to simulate user clicks.
+
+    Args:
+        img:
+            Image to generate extreme points from. Expected shape is ``(spatial_dim1, [spatial_dim2, ...])``.
+        rand_state: `np.random.RandomState` object used to select random indices.
+        background: Value to be considered as background, defaults to 0.
+        pert: Random perturbation amount to add to the points, defaults to 0.0.
+
+    Returns:
+        A list of extreme points, its length is 2 * the number of spatial dimensions of the image.
+        The output format of the coordinates is:
+
+        [1st_spatial_dim_min, 1st_spatial_dim_max, 2nd_spatial_dim_min, ..., Nth_spatial_dim_max]
+
+    Raises:
+        ValueError: When the input image does not have any foreground pixel.
+    """
+    indices = np.where(img != background)
+    if np.size(indices[0]) == 0:
+        raise ValueError("get_extreme_points: no foreground object in mask!")
+
+    def _get_point(val, dim):
+        """
+        Select one of the indices within the slice containing val.
+
+        Args:
+            val : value for comparison
+            dim : dimension in which to look for value
+        """
+        idx = rand_state.choice(np.where(indices[dim] == val)[0])
+        pt = []
+        for j in range(img.ndim):
+            # add +- pert to each dimension
+            val = int(indices[j][idx] + 2.0 * pert * (rand_state.rand() - 0.5))
+            val = max(val, 0)
+            val = min(val, img.shape[j] - 1)
+            pt.append(val)
+        return pt
+
+    points = []
+    for i in range(img.ndim):
+        points.append(tuple(_get_point(np.min(indices[i][...]), i)))
+        points.append(tuple(_get_point(np.max(indices[i][...]), i)))
+
+    return points
diff --git a/tests/test_add_extreme_points_channel.py b/tests/test_add_extreme_points_channel.py
new file mode 100644
index 0000000000..01277bfa19
--- /dev/null
+++ b/tests/test_add_extreme_points_channel.py
@@ -0,0 +1,67 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.transforms import AddExtremePointsChannel
+
+IMG_CHANNEL = 3
+
+TEST_CASE_1 = [
+    {
+        "img": np.zeros((IMG_CHANNEL, 4, 3)),
+        "label": np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]]),
+        "sigma": 1.0,
+        "rescale_min": 0.0,
+        "rescale_max": 1.0,
+    },
+    np.array(
+        [
+            [0.38318458, 0.98615628, 0.85551184],
+            [0.35422316, 0.94430935, 1.0],
+            [0.46000731, 0.57319659, 0.46000722],
+            [0.64577687, 0.38318464, 0.0],
+        ]
+    ),
+]
+
+TEST_CASE_2 = [
+    {
+        "img": np.zeros((IMG_CHANNEL, 4, 3)),
+        "label": np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]]),
+        "sigma": 1.0,
+        "rescale_min": 0.0,
+        "rescale_max": 1.0,
+    },
+    np.array(
+        [
+            [0.44628328, 0.80495411, 0.44628328],
+            [0.6779086, 1.0, 0.67790854],
+            [0.33002687, 0.62079221, 0.33002687],
+            [0.0, 0.31848389, 0.0],
+        ]
+    ),
+]
+
+
+class TestAddExtremePointsChannel(unittest.TestCase):
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+    def test_type_shape(self, input_data, expected):
+        add_extreme_points_channel = AddExtremePointsChannel()
+        result = add_extreme_points_channel(**input_data)
+        np.testing.assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_add_extreme_points_channeld.py b/tests/test_add_extreme_points_channeld.py
new file mode 100644
index 0000000000..6cf4e7be87
--- /dev/null
+++ b/tests/test_add_extreme_points_channeld.py
@@ -0,0 +1,57 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.transforms import AddExtremePointsChanneld
+
+IMG_CHANNEL = 3
+
+TEST_CASE_1 = [
+    {"img": np.zeros((IMG_CHANNEL, 4, 3)), "label": np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])},
+    np.array(
+        [
+            [0.38318458, 0.98615628, 0.85551184],
+            [0.35422316, 0.94430935, 1.0],
+            [0.46000731, 0.57319659, 0.46000722],
+            [0.64577687, 0.38318464, 0.0],
+        ]
+    ),
+]
+
+TEST_CASE_2 = [
+    {"img": np.zeros((IMG_CHANNEL, 4, 3)), "label": np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])},
+    np.array(
+        [
+            [0.44628328, 0.80495411, 0.44628328],
+            [0.6779086, 1.0, 0.67790854],
+            [0.33002687, 0.62079221, 0.33002687],
+            [0.0, 0.31848389, 0.0],
+        ]
+    ),
+]
+
+
+class TestAddExtremePointsChanneld(unittest.TestCase):
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+    def test_type_shape(self, input_data, expected):
+        add_extreme_points_channel = AddExtremePointsChanneld(
+            keys="img", label_key="label", sigma=1.0, rescale_min=0.0, rescale_max=1.0
+        )
+        result = add_extreme_points_channel(input_data)
+        np.testing.assert_allclose(result["img"][IMG_CHANNEL], expected, rtol=1e-4)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_get_extreme_points.py b/tests/test_get_extreme_points.py
new file mode 100644
index 0000000000..dd38af573e
--- /dev/null
+++ b/tests/test_get_extreme_points.py
@@ -0,0 +1,48 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.transforms import get_extreme_points
+
+TEST_CASE_1 = [
+    {
+        "img": np.array([[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]),
+        "rand_state": np.random,
+        "background": 0,
+        "pert": 0.0,
+    },
+    [(0, 1), (3, 0), (3, 0), (1, 2)],
+]
+
+TEST_CASE_2 = [
+    {
+        "img": np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]),
+        "rand_state": np.random,
+        "background": 0,
+        "pert": 0.0,
+    },
+    [(0, 1), (3, 1), (1, 0), (1, 2)],
+]
+
+
+class TestGetExtremePoints(unittest.TestCase):
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+    def test_type_shape(self, input_data, expected):
+        result = get_extreme_points(**input_data)
+        self.assertEqual(result, expected)
+
+
+if __name__ == "__main__":
+    unittest.main()
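
For reference, a minimal usage sketch of the transforms introduced by this patch (not part of the diff; it assumes a MONAI build that includes this change and uses a small synthetic 2D example purely for illustration):

import numpy as np

from monai.transforms import AddExtremePointsChannel, AddExtremePointsChanneld, get_extreme_points

# synthetic single-channel image and single-channel label: shapes (C, H, W) and (1, H, W)
img = np.zeros((1, 4, 3), dtype=np.float32)
label = np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])

# extreme points of the foreground: one (min, max) coordinate pair per spatial dimension
points = get_extreme_points(label[0], background=0, pert=0.0)
print(points)  # [(0, 1), (3, 1), (1, 0), (1, 2)] for this label (deterministic because pert=0.0)

# array transform: appends a Gaussian-smoothed, rescaled points channel to the image
result = AddExtremePointsChannel()(img, label=label, sigma=1.0, rescale_min=0.0, rescale_max=1.0)
print(result.shape)  # (2, 4, 3): the original channel plus the new points channel

# dictionary-based equivalent, driven by keys
data = {"img": img, "label": label}
result_d = AddExtremePointsChanneld(keys="img", label_key="label", sigma=1.0)(data)
print(result_d["img"].shape)  # (2, 4, 3)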