Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 77 additions & 2 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
import numpy as np
import torch

from monai.transforms.compose import Transform
from monai.transforms.utils import map_binary_to_indices
from monai.networks.layers import GaussianFilter
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.utils import get_extreme_points, map_binary_to_indices
from monai.utils import ensure_tuple

# Generic type which can represent either a numpy.ndarray or a torch.Tensor
Expand Down Expand Up @@ -535,3 +536,77 @@ def __call__(
bg_indices = np.stack([np.unravel_index(i, output_shape) for i in bg_indices])

return fg_indices, bg_indices


class AddExtremePointsChannel(Transform, Randomizable):
"""
Add extreme points of label to the image as a new channel. This transform generates extreme
point from label and applies a gaussian filter. The pixel values in points image are rescaled
to range [rescale_min, rescale_max] and added as a new channel to input image. The algorithm is
described in Roth et al., Going to Extremes: Weakly Supervised Medical Image Segmentation
https://arxiv.org/abs/2009.11988.

This transform only supports single channel labels (1, spatial_dim1, [spatial_dim2, ...]). The
background ``index`` is ignored when calculating extreme points.

Args:
background: Class index of background label, defaults to 0.
pert: Random perturbation amount to add to the points, defaults to 0.0.

Raises:
ValueError: When no label image provided.
ValueError: When label image is not single channel.
"""

def __init__(self, background: int = 0, pert: float = 0.0) -> None:
self._background = background
self._pert = pert
self._points: List[Tuple[int, ...]] = []

def randomize(self, label: np.ndarray) -> None:
self._points = get_extreme_points(label, rand_state=self.R, background=self._background, pert=self._pert)

def __call__(
self,
img: np.ndarray,
label: Optional[np.ndarray] = None,
sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 3.0,
rescale_min: float = -1.0,
rescale_max: float = 1.0,
) -> np.ndarray:
"""
Args:
img: the image that we want to add new channel to.
label: label image to get extreme points from. Shape must be
(1, spatial_dim1, [, spatial_dim2, ...]). Doesn't support one-hot labels.
sigma: if a list of values, must match the count of spatial dimensions of input data,
and apply every value in the list to 1 spatial dimension. if only 1 value provided,
use it for all spatial dimensions.
rescale_min: minimum value of output data.
rescale_max: maximum value of output data.
"""
if label is None:
raise ValueError("This transform requires a label array!")
if label.shape[0] != 1:
raise ValueError("Only supports single channel labels!")

# Generate extreme points
self.randomize(label[0, :])

# points to image
points_image = torch.zeros(label.shape[1:], dtype=torch.float)
for p in self._points:
points_image[p] = 1.0

# add channel and add batch
points_image = points_image.unsqueeze(0).unsqueeze(0)
gaussian_filter = GaussianFilter(img.ndim - 1, sigma=sigma)
points_image = gaussian_filter(points_image).squeeze(0).detach().numpy()

# rescale the points image to [rescale_min, rescale_max]
min_intensity = np.min(points_image)
max_intensity = np.max(points_image)
points_image = (points_image - min_intensity) / (max_intensity - min_intensity)
points_image = points_image * (rescale_max - rescale_min) + rescale_min

return np.concatenate([img, points_image], axis=0)
50 changes: 50 additions & 0 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from monai.transforms.compose import MapTransform
from monai.transforms.utility.array import (
AddChannel,
AddExtremePointsChannel,
AsChannelFirst,
AsChannelLast,
CastToType,
Expand Down Expand Up @@ -660,6 +661,54 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
return d


class AddExtremePointsChanneld(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.AddExtremePointsChannel`.

Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
label_key: key to label source to get the extreme points.
background: Class index of background label, defaults to 0.
pert: Random perturbation amount to add to the points, defaults to 0.0.
sigma: if a list of values, must match the count of spatial dimensions of input data,
and apply every value in the list to 1 spatial dimension. if only 1 value provided,
use it for all spatial dimensions.
rescale_min: minimum value of output data.
rescale_max: maximum value of output data.

"""

def __init__(
self,
keys: KeysCollection,
label_key: str,
background: int = 0,
pert: float = 0.0,
sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 3.0,
rescale_min: float = -1.0,
rescale_max: float = 1.0,
):
super().__init__(keys)
self.label_key = label_key
self.add_extreme_points_channel = AddExtremePointsChannel(background=background, pert=pert)
self.sigma = sigma
self.rescale_min = rescale_min
self.rescale_max = rescale_max

def __call__(self, data):
d = dict(data)
label = d[self.label_key]

for key in data.keys():
if key in self.keys:
img = d[key]
d[key] = self.add_extreme_points_channel(
img, label=label, sigma=self.sigma, rescale_min=self.rescale_min, rescale_max=self.rescale_max
)
return d


IdentityD = IdentityDict = Identityd
AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd
AsChannelLastD = AsChannelLastDict = AsChannelLastd
Expand All @@ -680,3 +729,4 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
ConvertToMultiChannelBasedOnBratsClassesD = (
ConvertToMultiChannelBasedOnBratsClassesDict
) = ConvertToMultiChannelBasedOnBratsClassesd
AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
53 changes: 53 additions & 0 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,3 +571,56 @@ def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Option
if item.max() != 0:
largest_cc[i, ...] = item == (np.argmax(np.bincount(item.flat)[1:]) + 1)
return torch.as_tensor(largest_cc, device=img.device)


def get_extreme_points(
img: np.ndarray, rand_state: np.random.RandomState = np.random, background: int = 0, pert: float = 0.0
) -> List[Tuple[int, ...]]:
"""
Generate extreme points from an image. These are used to generate initial segmentation
for annotation models. An optional perturbation can be passed to simulate user clicks.

Args:
img:
Image to generate extreme points from. Expected Shape is ``(spatial_dim1, [, spatial_dim2, ...])``.
rand_state: `np.random.RandomState` object used to select random indices.
background: Value to be consider as background, defaults to 0.
pert: Random perturbation amount to add to the points, defaults to 0.0.

Returns:
A list of extreme points, its length is equal to 2 * spatial dimension of input image.
The output format of the coordinates is:

[1st_spatial_dim_min, 1st_spatial_dim_max, 2nd_spatial_dim_min, ..., Nth_spatial_dim_max]

Raises:
ValueError: When the input image does not have any foreground pixel.
"""
indices = np.where(img != background)
if np.size(indices[0]) == 0:
raise ValueError("get_extreme_points: no foreground object in mask!")

def _get_point(val, dim):
"""
Select one of the indices within slice containing val.

Args:
val : value for comparison
dim : dimension in which to look for value
"""
idx = rand_state.choice(np.where(indices[dim] == val)[0])
pt = []
for j in range(img.ndim):
# add +- pert to each dimension
val = int(indices[j][idx] + 2.0 * pert * (rand_state.rand() - 0.5))
val = max(val, 0)
val = min(val, img.shape[j] - 1)
pt.append(val)
return pt

points = []
for i in range(img.ndim):
points.append(tuple(_get_point(np.min(indices[i][...]), i)))
points.append(tuple(_get_point(np.max(indices[i][...]), i)))

return points
67 changes: 67 additions & 0 deletions tests/test_add_extreme_points_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import AddExtremePointsChannel

IMG_CHANNEL = 3

TEST_CASE_1 = [
{
"img": np.zeros((IMG_CHANNEL, 4, 3)),
"label": np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]]),
"sigma": 1.0,
"rescale_min": 0.0,
"rescale_max": 1.0,
},
np.array(
[
[0.38318458, 0.98615628, 0.85551184],
[0.35422316, 0.94430935, 1.0],
[0.46000731, 0.57319659, 0.46000722],
[0.64577687, 0.38318464, 0.0],
]
),
]

TEST_CASE_2 = [
{
"img": np.zeros((IMG_CHANNEL, 4, 3)),
"label": np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]]),
"sigma": 1.0,
"rescale_min": 0.0,
"rescale_max": 1.0,
},
np.array(
[
[0.44628328, 0.80495411, 0.44628328],
[0.6779086, 1.0, 0.67790854],
[0.33002687, 0.62079221, 0.33002687],
[0.0, 0.31848389, 0.0],
]
),
]


class TestAddExtremePointsChannel(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_type_shape(self, input_data, expected):
add_extreme_points_channel = AddExtremePointsChannel()
result = add_extreme_points_channel(**input_data)
np.testing.assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4)


if __name__ == "__main__":
unittest.main()
57 changes: 57 additions & 0 deletions tests/test_add_extreme_points_channeld.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import AddExtremePointsChanneld

IMG_CHANNEL = 3

TEST_CASE_1 = [
{"img": np.zeros((IMG_CHANNEL, 4, 3)), "label": np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])},
np.array(
[
[0.38318458, 0.98615628, 0.85551184],
[0.35422316, 0.94430935, 1.0],
[0.46000731, 0.57319659, 0.46000722],
[0.64577687, 0.38318464, 0.0],
]
),
]

TEST_CASE_2 = [
{"img": np.zeros((IMG_CHANNEL, 4, 3)), "label": np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])},
np.array(
[
[0.44628328, 0.80495411, 0.44628328],
[0.6779086, 1.0, 0.67790854],
[0.33002687, 0.62079221, 0.33002687],
[0.0, 0.31848389, 0.0],
]
),
]


class TestAddExtremePointsChanneld(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_type_shape(self, input_data, expected):
add_extreme_points_channel = AddExtremePointsChanneld(
keys="img", label_key="label", sigma=1.0, rescale_min=0.0, rescale_max=1.0
)
result = add_extreme_points_channel(input_data)
np.testing.assert_allclose(result["img"][IMG_CHANNEL], expected, rtol=1e-4)


if __name__ == "__main__":
unittest.main()
48 changes: 48 additions & 0 deletions tests/test_get_extreme_points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import get_extreme_points

TEST_CASE_1 = [
{
"img": np.array([[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]),
"rand_state": np.random,
"background": 0,
"pert": 0.0,
},
[(0, 1), (3, 0), (3, 0), (1, 2)],
]

TEST_CASE_2 = [
{
"img": np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]),
"rand_state": np.random,
"background": 0,
"pert": 0.0,
},
[(0, 1), (3, 1), (1, 0), (1, 2)],
]


class TestGetExtremePoints(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_type_shape(self, input_data, expected):
result = get_extreme_points(**input_data)
self.assertEqual(result, expected)


if __name__ == "__main__":
unittest.main()