Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 4 additions & 17 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
import numpy as np
import torch

from monai.networks.layers import GaussianFilter
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.utils import get_extreme_points, map_binary_to_indices
from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices
from monai.utils import ensure_tuple

# Generic type which can represent either a numpy.ndarray or a torch.Tensor
Expand Down Expand Up @@ -593,20 +592,8 @@ def __call__(
# Generate extreme points
self.randomize(label[0, :])

# points to image
points_image = torch.zeros(label.shape[1:], dtype=torch.float)
for p in self._points:
points_image[p] = 1.0

# add channel and add batch
points_image = points_image.unsqueeze(0).unsqueeze(0)
gaussian_filter = GaussianFilter(img.ndim - 1, sigma=sigma)
points_image = gaussian_filter(points_image).squeeze(0).detach().numpy()

# rescale the points image to [rescale_min, rescale_max]
min_intensity = np.min(points_image)
max_intensity = np.max(points_image)
points_image = (points_image - min_intensity) / (max_intensity - min_intensity)
points_image = points_image * (rescale_max - rescale_min) + rescale_min
points_image = extreme_points_to_image(
points=self._points, label=label, sigma=sigma, rescale_min=rescale_min, rescale_max=rescale_max
)

return np.concatenate([img, points_image], axis=0)
29 changes: 22 additions & 7 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@

import copy
import logging
from typing import Callable, Dict, Hashable, Mapping, Optional, Sequence, Union
from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch

from monai.config import KeysCollection
from monai.transforms.compose import MapTransform
from monai.transforms import extreme_points_to_image, get_extreme_points
from monai.transforms.compose import MapTransform, Randomizable
from monai.transforms.utility.array import (
AddChannel,
AddExtremePointsChannel,
AsChannelFirst,
AsChannelLast,
CastToType,
Expand Down Expand Up @@ -661,7 +661,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
return d


class AddExtremePointsChanneld(MapTransform):
class AddExtremePointsChanneld(Randomizable, MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.AddExtremePointsChannel`.

Expand Down Expand Up @@ -690,22 +690,37 @@ def __init__(
rescale_max: float = 1.0,
):
super().__init__(keys)
self.background = background
self.pert = pert
self.points: List[Tuple[int, ...]] = []
self.label_key = label_key
self.add_extreme_points_channel = AddExtremePointsChannel(background=background, pert=pert)
self.sigma = sigma
self.rescale_min = rescale_min
self.rescale_max = rescale_max

def randomize(self, label: np.ndarray) -> None:
self.points = get_extreme_points(label, rand_state=self.R, background=self.background, pert=self.pert)

def __call__(self, data):
d = dict(data)
label = d[self.label_key]
if label.shape[0] != 1:
raise ValueError("Only supports single channel labels!")

# Generate extreme points
self.randomize(label[0, :])

for key in data.keys():
if key in self.keys:
img = d[key]
d[key] = self.add_extreme_points_channel(
img, label=label, sigma=self.sigma, rescale_min=self.rescale_min, rescale_max=self.rescale_max
points_image = extreme_points_to_image(
points=self.points,
label=label,
sigma=self.sigma,
rescale_min=self.rescale_min,
rescale_max=self.rescale_max,
)
d[key] = np.concatenate([img, points_image], axis=0)
return d


Expand Down
42 changes: 42 additions & 0 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch

from monai.config import IndexSelection
from monai.networks.layers import GaussianFilter
from monai.utils import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, min_version, optional_import

measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
Expand Down Expand Up @@ -620,3 +621,44 @@ def _get_point(val, dim):
points.append(tuple(_get_point(np.max(indices[i][...]), i)))

return points


def extreme_points_to_image(
points: List[Tuple[int, ...]],
label: np.ndarray,
sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0,
rescale_min: float = -1.0,
rescale_max: float = 1.0,
):
"""
Please refer to :py:class:`monai.transforms.AddExtremePointsChannel` for the usage.

Applies a gaussian filter to the extreme points image. Then the pixel values in points image are rescaled
to range [rescale_min, rescale_max].

Args:
points: Extreme points of the object/organ.
label: label image to get extreme points from. Shape must be
(1, spatial_dim1, [, spatial_dim2, ...]). Doesn't support one-hot labels.
sigma: if a list of values, must match the count of spatial dimensions of input data,
and apply every value in the list to 1 spatial dimension. if only 1 value provided,
use it for all spatial dimensions.
rescale_min: minimum value of output data.
rescale_max: maximum value of output data.
"""
# points to image
points_image = torch.zeros(label.shape[1:], dtype=torch.float)
for p in points:
points_image[p] = 1.0

# add channel and add batch
points_image = points_image.unsqueeze(0).unsqueeze(0)
gaussian_filter = GaussianFilter(label.ndim - 1, sigma=sigma)
points_image = gaussian_filter(points_image).squeeze(0).detach().numpy()

# rescale the points image to [rescale_min, rescale_max]
min_intensity = np.min(points_image)
max_intensity = np.max(points_image)
points_image = (points_image - min_intensity) / (max_intensity - min_intensity)
points_image = points_image * (rescale_max - rescale_min) + rescale_min
return points_image
2 changes: 1 addition & 1 deletion tests/test_add_extreme_points_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

class TestAddExtremePointsChannel(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_type_shape(self, input_data, expected):
def test_correct_results(self, input_data, expected):
add_extreme_points_channel = AddExtremePointsChannel()
result = add_extreme_points_channel(**input_data)
np.testing.assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_add_extreme_points_channeld.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

class TestAddExtremePointsChanneld(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_type_shape(self, input_data, expected):
def test_correct_results(self, input_data, expected):
add_extreme_points_channel = AddExtremePointsChanneld(
keys="img", label_key="label", sigma=1.0, rescale_min=0.0, rescale_max=1.0
)
Expand Down