diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 3499afcf95..f5c7c826e9 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -256,6 +256,7 @@ SplitChannel, SqueezeDim, ToNumpy, + ToPIL, TorchVision, ToTensor, Transpose, @@ -323,7 +324,14 @@ SqueezeDimD, SqueezeDimDict, ToNumpyd, + ToNumpyD, + ToNumpyDict, + ToPILd, + ToPILD, + ToPILDict, TorchVisiond, + TorchVisionD, + TorchVisionDict, ToTensord, ToTensorD, ToTensorDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index fb9ae3c089..0ee88e1a6c 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -15,7 +15,7 @@ import logging import time -from typing import Callable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -25,6 +25,15 @@ from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices from monai.utils import ensure_tuple, min_version, optional_import +if TYPE_CHECKING: + from PIL.Image import Image as PILImageImage + from PIL.Image import fromarray as pil_image_fromarray + + has_pil = True +else: + PILImageImage, has_pil = optional_import("PIL.Image", name="Image") + pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") + __all__ = [ "Identity", "AsChannelFirst", @@ -265,7 +274,7 @@ class ToTensor(Transform): Converts the input image to a tensor without applying any other transformations. """ - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + def __call__(self, img: Union[np.ndarray, torch.Tensor, PILImageImage]) -> torch.Tensor: """ Apply the transform to `img` and make it contiguous. """ @@ -279,7 +288,7 @@ class ToNumpy(Transform): Converts the input data to numpy array, can support list or tuple of numbers and PyTorch Tensor. """ - def __call__(self, img: Union[List, Tuple, np.ndarray, torch.Tensor]) -> np.ndarray: + def __call__(self, img: Union[List, Tuple, np.ndarray, torch.Tensor, PILImageImage]) -> np.ndarray: """ Apply the transform to `img` and make it contiguous. """ @@ -288,6 +297,22 @@ def __call__(self, img: Union[List, Tuple, np.ndarray, torch.Tensor]) -> np.ndar return np.ascontiguousarray(img) +class ToPIL(Transform): + """ + Converts the input image (in the form of NumPy array or PyTorch Tensor) to PIL image + """ + + def __call__(self, img: Union[np.ndarray, torch.Tensor, PILImageImage]) -> PILImageImage: + """ + Apply the transform to `img` and make it contiguous. + """ + if isinstance(img, PILImageImage): + return img + if isinstance(img, torch.Tensor): + img = img.detach().cpu().numpy() + return pil_image_fromarray(img) + + class Transpose(Transform): """ Transposes the input image based on the given `indices` dimension ordering. diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 83426734eb..f9612c2408 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -17,7 +17,7 @@ import copy import logging -from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -41,11 +41,19 @@ SplitChannel, SqueezeDim, ToNumpy, + ToPIL, TorchVision, ToTensor, ) from monai.transforms.utils import extreme_points_to_image, get_extreme_points -from monai.utils import ensure_tuple, ensure_tuple_rep +from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import + +if TYPE_CHECKING: + from PIL.Image import Image as PILImageImage + + has_pil = True +else: + PILImageImage, has_pil = optional_import("PIL.Image", name="Image") __all__ = [ "Identityd", @@ -58,6 +66,7 @@ "CastToTyped", "ToTensord", "ToNumpyd", + "ToPILd", "DeleteItemsd", "SelectItemsd", "SqueezeDimd", @@ -348,8 +357,8 @@ def __init__(self, keys: KeysCollection) -> None: self.converter = ToTensor() def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]] + ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]]: d = dict(data) for key in self.keys: d[key] = self.converter(d[key]) @@ -371,8 +380,31 @@ def __init__(self, keys: KeysCollection) -> None: self.converter = ToNumpy() def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]] + ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]]: + d = dict(data) + for key in self.keys: + d[key] = self.converter(d[key]) + return d + + +class ToPILd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ToNumpy`. + """ + + def __init__(self, keys: KeysCollection) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + """ + super().__init__(keys) + self.converter = ToPIL() + + def __call__( + self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]] + ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]]: d = dict(data) for key in self.keys: d[key] = self.converter(d[key]) @@ -867,6 +899,8 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc SplitChannelD = SplitChannelDict = SplitChanneld CastToTypeD = CastToTypeDict = CastToTyped ToTensorD = ToTensorDict = ToTensord +ToNumpyD = ToNumpyDict = ToNumpyd +ToPILD = ToPILDict = ToPILd DeleteItemsD = DeleteItemsDict = DeleteItemsd SqueezeDimD = SqueezeDimDict = SqueezeDimd DataStatsD = DataStatsDict = DataStatsd diff --git a/tests/test_to_pil.py b/tests/test_to_pil.py new file mode 100644 index 0000000000..ec63750ce4 --- /dev/null +++ b/tests/test_to_pil.py @@ -0,0 +1,64 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import TYPE_CHECKING +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import ToPIL +from monai.utils import optional_import + +if TYPE_CHECKING: + from PIL.Image import Image as PILImageImage + from PIL.Image import fromarray as pil_image_fromarray + + has_pil = True +else: + pil_image_fromarray, has_pil = optional_import("PIL.Image", name="fromarray") + PILImageImage, _ = optional_import("PIL.Image", name="Image") + +TEST_CASE_ARRAY_1 = [np.array([[1.0, 2.0], [3.0, 4.0]])] +TEST_CASE_TENSOR_1 = [torch.tensor([[1.0, 2.0], [3.0, 4.0]])] + + +class TestToPIL(unittest.TestCase): + @parameterized.expand([TEST_CASE_ARRAY_1]) + @skipUnless(has_pil, "Requires `pillow` package.") + def test_numpy_input(self, test_data): + self.assertTrue(isinstance(test_data, np.ndarray)) + result = ToPIL()(test_data) + self.assertTrue(isinstance(result, PILImageImage)) + np.testing.assert_allclose(np.array(result), test_data) + + @parameterized.expand([TEST_CASE_TENSOR_1]) + @skipUnless(has_pil, "Requires `pillow` package.") + def test_tensor_input(self, test_data): + self.assertTrue(isinstance(test_data, torch.Tensor)) + result = ToPIL()(test_data) + self.assertTrue(isinstance(result, PILImageImage)) + np.testing.assert_allclose(np.array(result), test_data.numpy()) + + @parameterized.expand([TEST_CASE_ARRAY_1]) + @skipUnless(has_pil, "Requires `pillow` package.") + def test_pil_input(self, test_data): + test_data_pil = pil_image_fromarray(test_data) + self.assertTrue(isinstance(test_data_pil, PILImageImage)) + result = ToPIL()(test_data_pil) + self.assertTrue(isinstance(result, PILImageImage)) + np.testing.assert_allclose(np.array(result), test_data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_to_pild.py b/tests/test_to_pild.py new file mode 100644 index 0000000000..43778022ee --- /dev/null +++ b/tests/test_to_pild.py @@ -0,0 +1,65 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import TYPE_CHECKING +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import ToPILd +from monai.utils import optional_import + +if TYPE_CHECKING: + from PIL.Image import Image as PILImageImage + from PIL.Image import fromarray as pil_image_fromarray + + has_pil = True +else: + pil_image_fromarray, has_pil = optional_import("PIL.Image", name="fromarray") + PILImageImage, _ = optional_import("PIL.Image", name="Image") + +TEST_CASE_ARRAY_1 = [{"keys": "image"}, {"image": np.array([[1.0, 2.0], [3.0, 4.0]])}] +TEST_CASE__TENSOR_1 = [{"keys": "image"}, {"image": torch.tensor([[1.0, 2.0], [3.0, 4.0]])}] + + +class TestToPIL(unittest.TestCase): + @parameterized.expand([TEST_CASE_ARRAY_1]) + @skipUnless(has_pil, "Requires `pillow` package.") + def test_numpy_input(self, input_param, test_data): + self.assertTrue(isinstance(test_data[input_param["keys"]], np.ndarray)) + result = ToPILd(**input_param)(test_data)[input_param["keys"]] + self.assertTrue(isinstance(result, PILImageImage)) + np.testing.assert_allclose(np.array(result), test_data[input_param["keys"]]) + + @parameterized.expand([TEST_CASE__TENSOR_1]) + @skipUnless(has_pil, "Requires `pillow` package.") + def test_tensor_input(self, input_param, test_data): + self.assertTrue(isinstance(test_data[input_param["keys"]], torch.Tensor)) + result = ToPILd(**input_param)(test_data)[input_param["keys"]] + self.assertTrue(isinstance(result, PILImageImage)) + np.testing.assert_allclose(np.array(result), test_data[input_param["keys"]].numpy()) + + @parameterized.expand([TEST_CASE_ARRAY_1]) + @skipUnless(has_pil, "Requires `pillow` package.") + def test_pil_input(self, input_param, test_data): + input_array = test_data[input_param["keys"]] + test_data[input_param["keys"]] = pil_image_fromarray(input_array) + self.assertTrue(isinstance(test_data[input_param["keys"]], PILImageImage)) + result = ToPILd(**input_param)(test_data)[input_param["keys"]] + self.assertTrue(isinstance(result, PILImageImage)) + np.testing.assert_allclose(np.array(result), test_data[input_param["keys"]]) + + +if __name__ == "__main__": + unittest.main()