3 changes: 3 additions & 0 deletions docs/source/transforms.rst
@@ -1418,3 +1418,6 @@ Utilities
---------
.. automodule:: monai.transforms.utils
:members:

.. automodule:: monai.transforms.utils_pytorch_numpy_unification
:members:
22 changes: 9 additions & 13 deletions monai/transforms/utility/array.py
@@ -31,7 +31,7 @@
map_binary_to_indices,
map_classes_to_indices,
)
from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis
from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis, unravel_indices
from monai.utils import (
convert_data_type,
convert_to_cupy,
@@ -789,16 +789,18 @@ class FgBgToIndices(Transform):

"""

backend = [TransformBackends.NUMPY, TransformBackends.TORCH]

def __init__(self, image_threshold: float = 0.0, output_shape: Optional[Sequence[int]] = None) -> None:
self.image_threshold = image_threshold
self.output_shape = output_shape

def __call__(
self,
label: np.ndarray,
image: Optional[np.ndarray] = None,
label: NdarrayOrTensor,
image: Optional[NdarrayOrTensor] = None,
output_shape: Optional[Sequence[int]] = None,
) -> Tuple[np.ndarray, np.ndarray]:
) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]:
"""
Args:
label: input data to compute foreground and background indices.
@@ -807,18 +809,12 @@ def __call__(
output_shape: expected shape of output indices. if None, use `self.output_shape` instead.

"""
fg_indices: np.ndarray
bg_indices: np.ndarray
label, *_ = convert_data_type(label, np.ndarray) # type: ignore
if image is not None:
image, *_ = convert_data_type(image, np.ndarray) # type: ignore
if output_shape is None:
output_shape = self.output_shape
fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) # type: ignore
fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold)
if output_shape is not None:
fg_indices = np.stack([np.unravel_index(i, output_shape) for i in fg_indices])
bg_indices = np.stack([np.unravel_index(i, output_shape) for i in bg_indices])

fg_indices = unravel_indices(fg_indices, output_shape)
bg_indices = unravel_indices(bg_indices, output_shape)
return fg_indices, bg_indices


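A minimal usage sketch (not part of the diff) of what this change enables: `FgBgToIndices` now declares both NumPy and PyTorch backends, so the same transform instance can be fed either array type. The expected index values below mirror the test cases further down; treat them as expected behaviour rather than guaranteed API.

```python
import numpy as np
import torch

from monai.transforms import FgBgToIndices

label_np = np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])
label_pt = torch.as_tensor(label_np)

converter = FgBgToIndices(image_threshold=0.0, output_shape=None)

# NumPy in -> NumPy indices out; torch in -> torch indices out (per the new backend list)
fg_np, bg_np = converter(label_np)  # fg: [1, 2, 3, 5, 6, 7], bg: [0, 4, 8]
fg_pt, bg_pt = converter(label_pt)  # same values, returned as torch tensors
```
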
4 changes: 3 additions & 1 deletion monai/transforms/utility/dictionary.py
@@ -1119,6 +1119,8 @@ class FgBgToIndicesd(MapTransform):

"""

backend = FgBgToIndices.backend

def __init__(
self,
keys: KeysCollection,
@@ -1135,7 +1137,7 @@ def __init__(
self.image_key = image_key
self.converter = FgBgToIndices(image_threshold, output_shape)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
image = d[self.image_key] if self.image_key else None
for key in self.key_iterator(d):
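For the dictionary-based version, a hedged usage sketch (again not part of the diff): `FgBgToIndicesd` stores the computed indices under `<key>_fg_indices` and `<key>_bg_indices`, as exercised by the updated tests below.

```python
import torch

from monai.transforms import FgBgToIndicesd

data = {"label": torch.tensor([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])}
out = FgBgToIndicesd(keys="label", image_threshold=0.0)(data)

print(out["label_fg_indices"])  # tensor([1, 2, 3, 5, 6, 7])
print(out["label_bg_indices"])  # tensor([0, 4, 8])
```
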
4 changes: 1 addition & 3 deletions monai/transforms/utils.py
@@ -286,9 +286,7 @@ def map_binary_to_indices(
fg_indices = nonzero(label_flat)
if image is not None:
img_flat = ravel(any_np_pt(image > image_threshold, 0))
img_flat, *_ = convert_data_type(
img_flat, type(label), device=label.device if isinstance(label, torch.Tensor) else None
)
img_flat, *_ = convert_to_dst_type(img_flat, label, dtype=img_flat.dtype)
bg_indices = nonzero(img_flat & ~label_flat)
else:
bg_indices = nonzero(~label_flat)
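The refactor above swaps a manual `convert_data_type` call for `convert_to_dst_type`, so `img_flat` follows `label`'s array type and device before it is combined with `~label_flat`. An illustrative sketch of the function's behaviour, with expected values taken from the test cases in this PR:

```python
import numpy as np

from monai.transforms.utils import map_binary_to_indices

label = np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])
image = np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])

# foreground: flattened positions where label > 0;
# background: positions where image > threshold but label == 0
fg, bg = map_binary_to_indices(label, image, image_threshold=0.0)
# expected per the tests below: fg == [1, 2, 3, 5, 6, 7], bg == [0, 8]
```
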
26 changes: 20 additions & 6 deletions monai/transforms/utils_pytorch_numpy_unification.py
@@ -26,6 +26,7 @@
"nonzero",
"floor_divide",
"unravel_index",
"unravel_indices",
"ravel",
"any_np_pt",
"maximum",
@@ -91,9 +92,8 @@ def percentile(x: NdarrayOrTensor, q) -> Union[NdarrayOrTensor, float, int]:
if np.isscalar(q):
if not 0 <= q <= 100:
raise ValueError
else:
if any(q < 0) or any(q > 100):
raise ValueError
elif any(q < 0) or any(q > 100):
raise ValueError
result: Union[NdarrayOrTensor, float, int]
if isinstance(x, np.ndarray):
result = np.percentile(x, q)
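The `percentile` change above only flattens the validation into an `elif`; behaviour is unchanged. A brief sketch of the NumPy path, assuming `q` is either a scalar or an array of values in `[0, 100]`:

```python
import numpy as np

from monai.transforms.utils_pytorch_numpy_unification import percentile

print(percentile(np.arange(10), 50))                  # 4.5
print(percentile(np.arange(10), np.array([25, 75])))  # [2.25 6.75]
```
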
@@ -167,20 +167,34 @@ def unravel_index(idx, shape):

Args:
idx: index to unravel
b: shape of array/tensor
shape: shape of array/tensor

Returns:
Index unravelled for given shape
"""
if isinstance(idx, torch.Tensor):
coord = []
for dim in reversed(shape):
coord.insert(0, idx % dim)
coord.append(idx % dim)
idx = floor_divide(idx, dim)
return torch.stack(coord)
return torch.stack(coord[::-1])
return np.unravel_index(np.asarray(idx, dtype=int), shape)


def unravel_indices(idx, shape):
"""Computing unravel cooridnates from indices.

Args:
idx: a sequence of indices to unravel
shape: shape of array/tensor

Returns:
Stacked indices unravelled for given shape
"""
lib_stack = torch.stack if isinstance(idx[0], torch.Tensor) else np.stack
return lib_stack([unravel_index(i, shape) for i in idx])


def ravel(x: NdarrayOrTensor):
"""`np.ravel` with equivalent implementation for torch.

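A worked sketch of the two helpers touched in this hunk (assuming the torch branch mirrors `np.unravel_index`, which is what the reversed-stack fix above provides):

```python
import numpy as np
import torch

from monai.transforms.utils_pytorch_numpy_unification import unravel_index, unravel_indices

# flat index 5 in a 3x3 grid lives at row 1, column 2
print(unravel_index(torch.tensor(5), (3, 3)))  # tensor([1, 2])
print(unravel_index(np.array(5), (3, 3)))      # (array(1), array(2))

# unravel_indices stacks the per-index results, which is what
# FgBgToIndices now uses when `output_shape` is given
print(unravel_indices(torch.tensor([1, 5, 7]), (3, 3)))
# tensor([[0, 1],
#         [1, 2],
#         [2, 1]])
```
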
5 changes: 1 addition & 4 deletions monai/utils/type_conversion.py
@@ -270,10 +270,7 @@ def convert_to_dst_type(
See Also:
:func:`convert_data_type`
"""
device = None
if isinstance(dst, torch.Tensor):
device = dst.device

device = dst.device if isinstance(dst, torch.Tensor) else None
if dtype is None:
dtype = dst.dtype

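For context, a hedged sketch of `convert_to_dst_type` as used in `map_binary_to_indices` above: the source is cast to the destination's array type (and moved to its device when the destination is a torch tensor), with the dtype defaulting to `dst.dtype` unless overridden, as in that hunk.

```python
import numpy as np
import torch

from monai.utils.type_conversion import convert_to_dst_type

src = np.array([0, 1, 1])
dst = torch.zeros(3, dtype=torch.float32)

out, *_ = convert_to_dst_type(src, dst)                     # float32 tensor on dst.device
same, *_ = convert_to_dst_type(src, dst, dtype=src.dtype)   # keep src's dtype, as map_binary_to_indices does
```
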
90 changes: 51 additions & 39 deletions tests/test_fg_bg_to_indices.py
@@ -11,58 +11,70 @@

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import FgBgToIndices
from tests.utils import TEST_NDARRAYS, assert_allclose

TEST_CASE_1 = [
{"image_threshold": 0.0, "output_shape": None},
np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),
None,
np.array([1, 2, 3, 5, 6, 7]),
np.array([0, 4, 8]),
]
TESTS_CASES = []
for p in TEST_NDARRAYS:
TESTS_CASES.append(
[
{"image_threshold": 0.0, "output_shape": None},
p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),
None,
p([1, 2, 3, 5, 6, 7]),
p([0, 4, 8]),
]
)

TEST_CASE_2 = [
{"image_threshold": 0.0, "output_shape": None},
np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),
np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]),
np.array([1, 2, 3, 5, 6, 7]),
np.array([0, 8]),
]
TESTS_CASES.append(
[
{"image_threshold": 0.0, "output_shape": None},
p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),
p([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]),
p([1, 2, 3, 5, 6, 7]),
p([0, 8]),
]
)

TEST_CASE_3 = [
{"image_threshold": 1.0, "output_shape": None},
np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),
np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]),
np.array([1, 2, 3, 5, 6, 7]),
np.array([0, 8]),
]
TESTS_CASES.append(
[
{"image_threshold": 1.0, "output_shape": None},
p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),
p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]),
p([1, 2, 3, 5, 6, 7]),
p([0, 8]),
]
)

TEST_CASE_4 = [
{"image_threshold": 1.0, "output_shape": None},
np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]),
np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]),
np.array([1, 2, 3, 5, 6, 7]),
np.array([0, 8]),
]
TESTS_CASES.append(
[
{"image_threshold": 1.0, "output_shape": None},
p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]),
p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]),
p([1, 2, 3, 5, 6, 7]),
p([0, 8]),
]
)

TEST_CASE_5 = [
{"image_threshold": 1.0, "output_shape": [3, 3]},
np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]),
np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]),
np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]),
np.array([[0, 0], [2, 2]]),
]
TESTS_CASES.append(
[
{"image_threshold": 1.0, "output_shape": [3, 3]},
p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]),
p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]),
p([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]),
p([[0, 0], [2, 2]]),
]
)


class TestFgBgToIndices(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
@parameterized.expand(TESTS_CASES)
def test_type_shape(self, input_data, label, image, expected_fg, expected_bg):
fg_indices, bg_indices = FgBgToIndices(**input_data)(label, image)
np.testing.assert_allclose(fg_indices, expected_fg)
np.testing.assert_allclose(bg_indices, expected_bg)
assert_allclose(fg_indices, expected_fg)
assert_allclose(bg_indices, expected_bg)


if __name__ == "__main__":
81 changes: 47 additions & 34 deletions tests/test_fg_bg_to_indicesd.py
@@ -11,53 +11,66 @@

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import FgBgToIndicesd
from tests.utils import TEST_NDARRAYS, assert_allclose

TEST_CASE_1 = [
{"keys": "label", "image_key": None, "image_threshold": 0.0, "output_shape": None},
{"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])},
np.array([1, 2, 3, 5, 6, 7]),
np.array([0, 4, 8]),
]
TEST_CASES = []
for p in TEST_NDARRAYS:

TEST_CASE_2 = [
{"keys": "label", "image_key": "image", "image_threshold": 0.0, "output_shape": None},
{"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])},
np.array([1, 2, 3, 5, 6, 7]),
np.array([0, 8]),
]
TEST_CASES.append(
[
{"keys": "label", "image_key": None, "image_threshold": 0.0, "output_shape": None},
{"label": p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])},
p([1, 2, 3, 5, 6, 7]),
p([0, 4, 8]),
]
)

TEST_CASE_3 = [
{"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None},
{"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])},
np.array([1, 2, 3, 5, 6, 7]),
np.array([0, 8]),
]
TEST_CASES.append(
[
{"keys": "label", "image_key": "image", "image_threshold": 0.0, "output_shape": None},
{"label": p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": p([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])},
p([1, 2, 3, 5, 6, 7]),
p([0, 8]),
]
)

TEST_CASE_4 = [
{"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None},
{"label": np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])},
np.array([1, 2, 3, 5, 6, 7]),
np.array([0, 8]),
]
TEST_CASES.append(
[
{"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None},
{"label": p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])},
p([1, 2, 3, 5, 6, 7]),
p([0, 8]),
]
)

TEST_CASE_5 = [
{"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": [3, 3]},
{"label": np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])},
np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]),
np.array([[0, 0], [2, 2]]),
]
TEST_CASES.append(
[
{"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None},
{"label": p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])},
p([1, 2, 3, 5, 6, 7]),
p([0, 8]),
]
)

TEST_CASES.append(
[
{"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": [3, 3]},
{"label": p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])},
p([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]),
p([[0, 0], [2, 2]]),
]
)


class TestFgBgToIndicesd(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
@parameterized.expand(TEST_CASES)
def test_type_shape(self, input_data, data, expected_fg, expected_bg):
result = FgBgToIndicesd(**input_data)(data)
np.testing.assert_allclose(result["label_fg_indices"], expected_fg)
np.testing.assert_allclose(result["label_bg_indices"], expected_bg)
assert_allclose(result["label_fg_indices"], expected_fg)
assert_allclose(result["label_bg_indices"], expected_bg)


if __name__ == "__main__":