diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index 1223254db5..59712681ef 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -553,6 +553,7 @@
     nonzero,
     percentile,
     ravel,
+    repeat,
     unravel_index,
     where,
 )
diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py
index be7724f1d9..15ae8fb111 100644
--- a/monai/transforms/utils_pytorch_numpy_unification.py
+++ b/monai/transforms/utils_pytorch_numpy_unification.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Sequence, Union
+from typing import Optional, Sequence, Union
 
 import numpy as np
 import torch
@@ -35,6 +35,7 @@
     "cumsum",
     "isfinite",
     "searchsorted",
+    "repeat",
 ]
 
 
@@ -301,3 +302,10 @@ def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=Non
     ret = np.searchsorted(a.cpu().numpy(), v.cpu().numpy(), side, sorter)  # type: ignore
     ret, *_ = convert_to_dst_type(ret, a)
     return ret
+
+
+def repeat(a: NdarrayOrTensor, repeats: int, axis: Optional[int] = None):
+    """`np.repeat` with equivalent implementation for torch (`repeat_interleave`)."""
+    if isinstance(a, np.ndarray):
+        return np.repeat(a, repeats, axis)
+    return torch.repeat_interleave(a, repeats, dim=axis)
diff --git a/monai/visualize/__init__.py b/monai/visualize/__init__.py
index 19697253ed..4356b5f8d9 100644
--- a/monai/visualize/__init__.py
+++ b/monai/visualize/__init__.py
@@ -17,5 +17,5 @@
     plot_2d_or_3d_image,
 )
 from .occlusion_sensitivity import OcclusionSensitivity
-from .utils import matshow3d
+from .utils import blend_images, matshow3d
 from .visualizer import default_upsampler
diff --git a/monai/visualize/utils.py b/monai/visualize/utils.py
index daf9a1a3e5..c0304e9e68 100644
--- a/monai/visualize/utils.py
+++ b/monai/visualize/utils.py
@@ -13,13 +13,17 @@
 
 import numpy as np
 
+from monai.config.type_definitions import NdarrayOrTensor
 from monai.transforms.croppad.array import SpatialPad
+from monai.transforms.utils import rescale_array
+from monai.transforms.utils_pytorch_numpy_unification import repeat, where
 from monai.utils.module import optional_import
-from monai.utils.type_conversion import convert_data_type
+from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
 
 plt, _ = optional_import("matplotlib", name="pyplot")
+cm, _ = optional_import("matplotlib", name="cm")
 
-__all__ = ["matshow3d"]
+__all__ = ["matshow3d", "blend_images"]
 
 
 def matshow3d(
@@ -122,3 +126,36 @@ def matshow3d(
     if show:
         plt.show()
     return fig, im
+
+
+def blend_images(
+    image: NdarrayOrTensor, label: NdarrayOrTensor, alpha: float = 0.5, cmap: str = "hsv", rescale_arrays: bool = True
+):
+    """Blend two images. Both should have the shape CHW[D].
+    The image may have C==1 or 3 channels (greyscale or RGB).
+    The label is expected to have C==1."""
+    if label.shape[0] != 1:
+        raise ValueError("Label should have 1 channel")
+    if image.shape[0] not in (1, 3):
+        raise ValueError("Image should have 1 or 3 channels")
+    # rescale arrays to [0, 1] if desired
+    if rescale_arrays:
+        image = rescale_array(image)
+        label = rescale_array(label)
+    # convert image to rgb (if necessary) and then rgba
+    if image.shape[0] == 1:
+        image = repeat(image, 3, axis=0)
+
+    def get_label_rgb(cmap: str, label: NdarrayOrTensor):
+        _cmap = cm.get_cmap(cmap)
+        label_np: np.ndarray
+        label_np, *_ = convert_data_type(label, np.ndarray)  # type: ignore
+        label_rgb_np = _cmap(label_np[0])
+        label_rgb_np = np.moveaxis(label_rgb_np, -1, 0)[:3]
+        label_rgb, *_ = convert_to_dst_type(label_rgb_np, label)
+        return label_rgb
+
+    label_rgb = get_label_rgb(cmap, label)
+    w_image = where(label == 0, 1.0, alpha)
+    w_label = where(label == 0, 0.0, 1 - alpha)
+    return w_image * image + w_label * label_rgb
diff --git a/tests/test_blend_images.py b/tests/test_blend_images.py
new file mode 100644
index 0000000000..67e967b89d
--- /dev/null
+++ b/tests/test_blend_images.py
@@ -0,0 +1,53 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.case import skipUnless
+
+import torch
+from parameterized import parameterized
+
+from monai.data.synthetic import create_test_image_2d, create_test_image_3d
+from monai.transforms.utils_pytorch_numpy_unification import moveaxis
+from monai.utils.module import optional_import
+from monai.visualize.utils import blend_images
+from tests.utils import TEST_NDARRAYS
+
+plt, has_matplotlib = optional_import("matplotlib.pyplot")
+
+TESTS = []
+for p in TEST_NDARRAYS:
+    image, label = create_test_image_2d(100, 101)
+    TESTS.append((p(image), p(label)))
+
+    image, label = create_test_image_3d(100, 101, 102)
+    TESTS.append((p(image), p(label)))
+
+
+@skipUnless(has_matplotlib, "Matplotlib required")
+class TestBlendImages(unittest.TestCase):
+    @parameterized.expand(TESTS)
+    def test_blend(self, image, label):
+        blended = blend_images(image[None], label[None])
+        self.assertEqual(type(image), type(blended))
+        if isinstance(blended, torch.Tensor):
+            self.assertEqual(blended.device, image.device)
+            blended = blended.cpu().numpy()
+        self.assertEqual((3,) + image.shape, blended.shape)
+
+        blended = moveaxis(blended, 0, -1)  # move RGB component to end
+        if blended.ndim > 3:
+            blended = blended[blended.shape[0] // 2]
+        plt.imshow(blended)
+
+
+if __name__ == "__main__":
+    unittest.main()
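
A minimal usage sketch of the new `blend_images` helper (not part of the patch): it assumes the patch is applied and matplotlib is installed, and the array sizes and parameter values below are only illustrative.

```python
import matplotlib.pyplot as plt
import torch

from monai.data.synthetic import create_test_image_3d
from monai.visualize import blend_images

# synthetic image/label pair; add a channel dimension so both have shape CHWD
image, label = create_test_image_3d(64, 64, 64)
image, label = torch.as_tensor(image[None]), torch.as_tensor(label[None])

# overlay the label on the image; labelled voxels are coloured via the "hsv" colormap
blended = blend_images(image, label, alpha=0.5, cmap="hsv", rescale_arrays=True)
print(blended.shape)  # torch.Size([3, 64, 64, 64]): RGB channels first, spatial size preserved

# view the middle slice along the last spatial axis (RGB moved to the end for imshow)
mid = blended.shape[-1] // 2
plt.imshow(blended[..., mid].permute(1, 2, 0).numpy())
plt.show()
```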