diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 0b0f9f0792..4c3bff363d 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -1418,3 +1418,6 @@ Utilities
 ---------
 .. automodule:: monai.transforms.utils
     :members:
+
+.. automodule:: monai.transforms.utils_pytorch_numpy_unification
+    :members:
diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index 272c9963a8..8c15d0c148 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -31,7 +31,7 @@
     map_binary_to_indices,
     map_classes_to_indices,
 )
-from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis
+from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis, unravel_indices
 from monai.utils import (
     convert_data_type,
     convert_to_cupy,
@@ -789,16 +789,18 @@ class FgBgToIndices(Transform):
     """

+    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]
+
     def __init__(self, image_threshold: float = 0.0, output_shape: Optional[Sequence[int]] = None) -> None:
         self.image_threshold = image_threshold
         self.output_shape = output_shape

     def __call__(
         self,
-        label: np.ndarray,
-        image: Optional[np.ndarray] = None,
+        label: NdarrayOrTensor,
+        image: Optional[NdarrayOrTensor] = None,
         output_shape: Optional[Sequence[int]] = None,
-    ) -> Tuple[np.ndarray, np.ndarray]:
+    ) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]:
         """
         Args:
             label: input data to compute foreground and background indices.
@@ -807,18 +809,12 @@ def __call__(
             output_shape: expected shape of output indices. if None, use `self.output_shape` instead.

         """
-        fg_indices: np.ndarray
-        bg_indices: np.ndarray
-        label, *_ = convert_data_type(label, np.ndarray)  # type: ignore
-        if image is not None:
-            image, *_ = convert_data_type(image, np.ndarray)  # type: ignore
         if output_shape is None:
             output_shape = self.output_shape
-        fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold)  # type: ignore
+        fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold)
         if output_shape is not None:
-            fg_indices = np.stack([np.unravel_index(i, output_shape) for i in fg_indices])
-            bg_indices = np.stack([np.unravel_index(i, output_shape) for i in bg_indices])
-
+            fg_indices = unravel_indices(fg_indices, output_shape)
+            bg_indices = unravel_indices(bg_indices, output_shape)
         return fg_indices, bg_indices
diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index 4fb07644a9..0be46fd02b 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -1119,6 +1119,8 @@ class FgBgToIndicesd(MapTransform):
     """

+    backend = FgBgToIndices.backend
+
     def __init__(
         self,
         keys: KeysCollection,
@@ -1135,7 +1137,7 @@ def __init__(
         self.image_key = image_key
         self.converter = FgBgToIndices(image_threshold, output_shape)

-    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
         image = d[self.image_key] if self.image_key else None
         for key in self.key_iterator(d):
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index 4927f9a478..01a62e36ff 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -286,9 +286,7 @@ def map_binary_to_indices(
     fg_indices = nonzero(label_flat)
     if image is not None:
         img_flat = ravel(any_np_pt(image > image_threshold, 0))
-        img_flat, *_ = convert_data_type(
-            img_flat, type(label), device=label.device if isinstance(label, torch.Tensor) else None
-        )
+        img_flat, *_ = convert_to_dst_type(img_flat, label, dtype=img_flat.dtype)
         bg_indices = nonzero(img_flat & ~label_flat)
     else:
         bg_indices = nonzero(~label_flat)
diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py
index 4283c4a81f..3fe2402504 100644
--- a/monai/transforms/utils_pytorch_numpy_unification.py
+++ b/monai/transforms/utils_pytorch_numpy_unification.py
@@ -26,6 +26,7 @@
     "nonzero",
     "floor_divide",
     "unravel_index",
+    "unravel_indices",
     "ravel",
     "any_np_pt",
     "maximum",
@@ -91,9 +92,8 @@ def percentile(x: NdarrayOrTensor, q) -> Union[NdarrayOrTensor, float, int]:
     if np.isscalar(q):
         if not 0 <= q <= 100:
             raise ValueError
-    else:
-        if any(q < 0) or any(q > 100):
-            raise ValueError
+    elif any(q < 0) or any(q > 100):
+        raise ValueError
     result: Union[NdarrayOrTensor, float, int]
     if isinstance(x, np.ndarray):
         result = np.percentile(x, q)
@@ -167,7 +167,7 @@ def unravel_index(idx, shape):

     Args:
         idx: index to unravel
-        b: shape of array/tensor
+        shape: shape of array/tensor

     Returns:
         Index unravelled for given shape
@@ -175,12 +175,26 @@
     if isinstance(idx, torch.Tensor):
         coord = []
         for dim in reversed(shape):
-            coord.insert(0, idx % dim)
+            coord.append(idx % dim)
             idx = floor_divide(idx, dim)
-        return torch.stack(coord)
+        return torch.stack(coord[::-1])
     return np.unravel_index(np.asarray(idx, dtype=int), shape)


+def unravel_indices(idx, shape):
+    """Computes unravel coordinates from a sequence of indices.
+
+    Args:
+        idx: a sequence of indices to unravel
+        shape: shape of array/tensor
+
+    Returns:
+        Stacked indices unravelled for given shape
+    """
+    lib_stack = torch.stack if isinstance(idx[0], torch.Tensor) else np.stack
+    return lib_stack([unravel_index(i, shape) for i in idx])
+
+
 def ravel(x: NdarrayOrTensor):
     """`np.ravel` with equivalent implementation for torch.
diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py
index 87095fef99..648e68440e 100644
--- a/monai/utils/type_conversion.py
+++ b/monai/utils/type_conversion.py
@@ -270,10 +270,7 @@ def convert_to_dst_type(
     See Also:
         :func:`convert_data_type`
     """
-    device = None
-    if isinstance(dst, torch.Tensor):
-        device = dst.device
-
+    device = dst.device if isinstance(dst, torch.Tensor) else None
     if dtype is None:
         dtype = dst.dtype
diff --git a/tests/test_fg_bg_to_indices.py b/tests/test_fg_bg_to_indices.py
index 98626c7028..0d35dd23f8 100644
--- a/tests/test_fg_bg_to_indices.py
+++ b/tests/test_fg_bg_to_indices.py
@@ -11,58 +11,70 @@

 import unittest

-import numpy as np
 from parameterized import parameterized

 from monai.transforms import FgBgToIndices
+from tests.utils import TEST_NDARRAYS, assert_allclose

-TEST_CASE_1 = [
-    {"image_threshold": 0.0, "output_shape": None},
-    np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),
-    None,
-    np.array([1, 2, 3, 5, 6, 7]),
-    np.array([0, 4, 8]),
-]
+TESTS_CASES = []
+for p in TEST_NDARRAYS:
+    TESTS_CASES.append(
+        [
+            {"image_threshold": 0.0, "output_shape": None},
+            p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),
+            None,
+            p([1, 2, 3, 5, 6, 7]),
+            p([0, 4, 8]),
+        ]
+    )

-TEST_CASE_2 = [
-    {"image_threshold": 0.0, "output_shape": None},
-    np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),
-    np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]),
-    np.array([1, 2, 3, 5, 6, 7]),
-    np.array([0, 8]),
-]
+    TESTS_CASES.append(
+        [
+            {"image_threshold": 0.0, "output_shape": None},
+            p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),
+            p([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]),
+            p([1, 2, 3, 5, 6, 7]),
+            p([0, 8]),
+        ]
+    )

-TEST_CASE_3 = [
-    {"image_threshold": 1.0, "output_shape": None},
-    np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),
-    np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]),
-    np.array([1, 2, 3, 5, 6, 7]),
-    np.array([0, 8]),
-]
+    TESTS_CASES.append(
+        [
+            {"image_threshold": 1.0, "output_shape": None},
+            p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),
+            p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]),
+            p([1, 2, 3, 5, 6, 7]),
+            p([0, 8]),
+        ]
+    )

-TEST_CASE_4 = [
-    {"image_threshold": 1.0, "output_shape": None},
-    np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]),
-    np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]),
-    np.array([1, 2, 3, 5, 6, 7]),
-    np.array([0, 8]),
-]
+    TESTS_CASES.append(
+        [
+            {"image_threshold": 1.0, "output_shape": None},
+            p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]),
+            p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]),
+            p([1, 2, 3, 5, 6, 7]),
+            p([0, 8]),
+        ]
+    )

-TEST_CASE_5 = [
-    {"image_threshold": 1.0, "output_shape": [3, 3]},
-    np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]),
-    np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]),
-    np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]),
-    np.array([[0, 0], [2, 2]]),
-]
+    TESTS_CASES.append(
+        [
+            {"image_threshold": 1.0, "output_shape": [3, 3]},
+            p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]),
+            p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]),
+            p([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]),
+            p([[0, 0], [2, 2]]),
+        ]
+    )


 class TestFgBgToIndices(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
+    @parameterized.expand(TESTS_CASES)
     def test_type_shape(self, input_data, label, image, expected_fg, expected_bg):
         fg_indices, bg_indices = FgBgToIndices(**input_data)(label, image)
-        np.testing.assert_allclose(fg_indices, expected_fg)
-        np.testing.assert_allclose(bg_indices, expected_bg)
+        assert_allclose(fg_indices, expected_fg)
+        assert_allclose(bg_indices, expected_bg)


 if __name__ == "__main__":
"__main__": diff --git a/tests/test_fg_bg_to_indicesd.py b/tests/test_fg_bg_to_indicesd.py index ce6ca30f1b..4691526d94 100644 --- a/tests/test_fg_bg_to_indicesd.py +++ b/tests/test_fg_bg_to_indicesd.py @@ -11,53 +11,66 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import FgBgToIndicesd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"keys": "label", "image_key": None, "image_threshold": 0.0, "output_shape": None}, - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 4, 8]), -] +TEST_CASES = [] +for p in TEST_NDARRAYS: -TEST_CASE_2 = [ - {"keys": "label", "image_key": "image", "image_threshold": 0.0, "output_shape": None}, - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TEST_CASES.append( + [ + {"keys": "label", "image_key": None, "image_threshold": 0.0, "output_shape": None}, + {"label": p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])}, + p([1, 2, 3, 5, 6, 7]), + p([0, 4, 8]), + ] + ) -TEST_CASE_3 = [ - {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TEST_CASES.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 0.0, "output_shape": None}, + {"label": p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": p([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])}, + p([1, 2, 3, 5, 6, 7]), + p([0, 8]), + ] + ) -TEST_CASE_4 = [ - {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, - {"label": np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TEST_CASES.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, + {"label": p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, + p([1, 2, 3, 5, 6, 7]), + p([0, 8]), + ] + ) -TEST_CASE_5 = [ - {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": [3, 3]}, - {"label": np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, - np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), - np.array([[0, 0], [2, 2]]), -] + TEST_CASES.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, + {"label": p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, + p([1, 2, 3, 5, 6, 7]), + p([0, 8]), + ] + ) + + TEST_CASES.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": [3, 3]}, + {"label": p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, + p([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), + p([[0, 0], [2, 2]]), + ] + ) class TestFgBgToIndicesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TEST_CASES) def test_type_shape(self, input_data, data, expected_fg, expected_bg): result = FgBgToIndicesd(**input_data)(data) - np.testing.assert_allclose(result["label_fg_indices"], expected_fg) - np.testing.assert_allclose(result["label_bg_indices"], expected_bg) + 
+        assert_allclose(result["label_fg_indices"], expected_fg)
+        assert_allclose(result["label_bg_indices"], expected_bg)


 if __name__ == "__main__":
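
A minimal usage sketch (illustration only, not part of the patch) of the behaviour this change enables, reusing a sample label from the tests above. `FgBgToIndices` and `unravel_indices` are the objects modified in this diff; everything else is plain torch.

import torch

from monai.transforms import FgBgToIndices
from monai.transforms.utils_pytorch_numpy_unification import unravel_indices

label = torch.tensor([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])  # same sample as the tests

# Flat foreground/background indices; a torch input now stays a torch.Tensor throughout.
fg, bg = FgBgToIndices()(label)

# With an output_shape, the flat indices are unravelled through the new helper.
fg_coords, bg_coords = FgBgToIndices(output_shape=[3, 3])(label)

# unravel_indices stacks unravel_index over a sequence of indices.
assert torch.equal(fg_coords, unravel_indices(fg, [3, 3]))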