Skip to content
8 changes: 8 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@
SplitChannel,
SqueezeDim,
ToNumpy,
ToPIL,
TorchVision,
ToTensor,
Transpose,
Expand Down Expand Up @@ -323,7 +324,14 @@
SqueezeDimD,
SqueezeDimDict,
ToNumpyd,
ToNumpyD,
ToNumpyDict,
ToPILd,
ToPILD,
ToPILDict,
TorchVisiond,
TorchVisionD,
TorchVisionDict,
ToTensord,
ToTensorD,
ToTensorDict,
Expand Down
31 changes: 28 additions & 3 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import logging
import time
from typing import Callable, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand All @@ -25,6 +25,15 @@
from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices
from monai.utils import ensure_tuple, min_version, optional_import

if TYPE_CHECKING:
from PIL.Image import Image as PILImageImage
from PIL.Image import fromarray as pil_image_fromarray

has_pil = True
else:
PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")

__all__ = [
"Identity",
"AsChannelFirst",
Expand Down Expand Up @@ -265,7 +274,7 @@ class ToTensor(Transform):
Converts the input image to a tensor without applying any other transformations.
"""

def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
def __call__(self, img: Union[np.ndarray, torch.Tensor, PILImageImage]) -> torch.Tensor:
"""
Apply the transform to `img` and make it contiguous.
"""
Expand All @@ -279,7 +288,7 @@ class ToNumpy(Transform):
Converts the input data to numpy array, can support list or tuple of numbers and PyTorch Tensor.
"""

def __call__(self, img: Union[List, Tuple, np.ndarray, torch.Tensor]) -> np.ndarray:
def __call__(self, img: Union[List, Tuple, np.ndarray, torch.Tensor, PILImageImage]) -> np.ndarray:
"""
Apply the transform to `img` and make it contiguous.
"""
Expand All @@ -288,6 +297,22 @@ def __call__(self, img: Union[List, Tuple, np.ndarray, torch.Tensor]) -> np.ndar
return np.ascontiguousarray(img)


class ToPIL(Transform):
"""
Converts the input image (in the form of NumPy array or PyTorch Tensor) to PIL image
"""

def __call__(self, img: Union[np.ndarray, torch.Tensor, PILImageImage]) -> PILImageImage:
"""
Apply the transform to `img` and make it contiguous.
"""
if isinstance(img, PILImageImage):
return img
if isinstance(img, torch.Tensor):
img = img.detach().cpu().numpy()
return pil_image_fromarray(img)


class Transpose(Transform):
"""
Transposes the input image based on the given `indices` dimension ordering.
Expand Down
46 changes: 40 additions & 6 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import copy
import logging
from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand All @@ -41,11 +41,19 @@
SplitChannel,
SqueezeDim,
ToNumpy,
ToPIL,
TorchVision,
ToTensor,
)
from monai.transforms.utils import extreme_points_to_image, get_extreme_points
from monai.utils import ensure_tuple, ensure_tuple_rep
from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import

if TYPE_CHECKING:
from PIL.Image import Image as PILImageImage

has_pil = True
else:
PILImageImage, has_pil = optional_import("PIL.Image", name="Image")

__all__ = [
"Identityd",
Expand All @@ -58,6 +66,7 @@
"CastToTyped",
"ToTensord",
"ToNumpyd",
"ToPILd",
"DeleteItemsd",
"SelectItemsd",
"SqueezeDimd",
Expand Down Expand Up @@ -348,8 +357,8 @@ def __init__(self, keys: KeysCollection) -> None:
self.converter = ToTensor()

def __call__(
self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]
) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]]
) -> Dict[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]]:
d = dict(data)
for key in self.keys:
d[key] = self.converter(d[key])
Expand All @@ -371,8 +380,31 @@ def __init__(self, keys: KeysCollection) -> None:
self.converter = ToNumpy()

def __call__(
self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]
) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]]
) -> Dict[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]]:
d = dict(data)
for key in self.keys:
d[key] = self.converter(d[key])
return d


class ToPILd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ToNumpy`.
"""

def __init__(self, keys: KeysCollection) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
"""
super().__init__(keys)
self.converter = ToPIL()

def __call__(
self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]]
) -> Dict[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]]:
d = dict(data)
for key in self.keys:
d[key] = self.converter(d[key])
Expand Down Expand Up @@ -867,6 +899,8 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
SplitChannelD = SplitChannelDict = SplitChanneld
CastToTypeD = CastToTypeDict = CastToTyped
ToTensorD = ToTensorDict = ToTensord
ToNumpyD = ToNumpyDict = ToNumpyd
ToPILD = ToPILDict = ToPILd
DeleteItemsD = DeleteItemsDict = DeleteItemsd
SqueezeDimD = SqueezeDimDict = SqueezeDimd
DataStatsD = DataStatsDict = DataStatsd
Expand Down
64 changes: 64 additions & 0 deletions tests/test_to_pil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from typing import TYPE_CHECKING
from unittest import skipUnless

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import ToPIL
from monai.utils import optional_import

if TYPE_CHECKING:
from PIL.Image import Image as PILImageImage
from PIL.Image import fromarray as pil_image_fromarray

has_pil = True
else:
pil_image_fromarray, has_pil = optional_import("PIL.Image", name="fromarray")
PILImageImage, _ = optional_import("PIL.Image", name="Image")

TEST_CASE_ARRAY_1 = [np.array([[1.0, 2.0], [3.0, 4.0]])]
TEST_CASE_TENSOR_1 = [torch.tensor([[1.0, 2.0], [3.0, 4.0]])]


class TestToPIL(unittest.TestCase):
@parameterized.expand([TEST_CASE_ARRAY_1])
@skipUnless(has_pil, "Requires `pillow` package.")
def test_numpy_input(self, test_data):
self.assertTrue(isinstance(test_data, np.ndarray))
result = ToPIL()(test_data)
self.assertTrue(isinstance(result, PILImageImage))
np.testing.assert_allclose(np.array(result), test_data)

@parameterized.expand([TEST_CASE_TENSOR_1])
@skipUnless(has_pil, "Requires `pillow` package.")
def test_tensor_input(self, test_data):
self.assertTrue(isinstance(test_data, torch.Tensor))
result = ToPIL()(test_data)
self.assertTrue(isinstance(result, PILImageImage))
np.testing.assert_allclose(np.array(result), test_data.numpy())

@parameterized.expand([TEST_CASE_ARRAY_1])
@skipUnless(has_pil, "Requires `pillow` package.")
def test_pil_input(self, test_data):
test_data_pil = pil_image_fromarray(test_data)
self.assertTrue(isinstance(test_data_pil, PILImageImage))
result = ToPIL()(test_data_pil)
self.assertTrue(isinstance(result, PILImageImage))
np.testing.assert_allclose(np.array(result), test_data)


if __name__ == "__main__":
unittest.main()
65 changes: 65 additions & 0 deletions tests/test_to_pild.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from typing import TYPE_CHECKING
from unittest import skipUnless

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import ToPILd
from monai.utils import optional_import

if TYPE_CHECKING:
from PIL.Image import Image as PILImageImage
from PIL.Image import fromarray as pil_image_fromarray

has_pil = True
else:
pil_image_fromarray, has_pil = optional_import("PIL.Image", name="fromarray")
PILImageImage, _ = optional_import("PIL.Image", name="Image")

TEST_CASE_ARRAY_1 = [{"keys": "image"}, {"image": np.array([[1.0, 2.0], [3.0, 4.0]])}]
TEST_CASE__TENSOR_1 = [{"keys": "image"}, {"image": torch.tensor([[1.0, 2.0], [3.0, 4.0]])}]


class TestToPIL(unittest.TestCase):
@parameterized.expand([TEST_CASE_ARRAY_1])
@skipUnless(has_pil, "Requires `pillow` package.")
def test_numpy_input(self, input_param, test_data):
self.assertTrue(isinstance(test_data[input_param["keys"]], np.ndarray))
result = ToPILd(**input_param)(test_data)[input_param["keys"]]
self.assertTrue(isinstance(result, PILImageImage))
np.testing.assert_allclose(np.array(result), test_data[input_param["keys"]])

@parameterized.expand([TEST_CASE__TENSOR_1])
@skipUnless(has_pil, "Requires `pillow` package.")
def test_tensor_input(self, input_param, test_data):
self.assertTrue(isinstance(test_data[input_param["keys"]], torch.Tensor))
result = ToPILd(**input_param)(test_data)[input_param["keys"]]
self.assertTrue(isinstance(result, PILImageImage))
np.testing.assert_allclose(np.array(result), test_data[input_param["keys"]].numpy())

@parameterized.expand([TEST_CASE_ARRAY_1])
@skipUnless(has_pil, "Requires `pillow` package.")
def test_pil_input(self, input_param, test_data):
input_array = test_data[input_param["keys"]]
test_data[input_param["keys"]] = pil_image_fromarray(input_array)
self.assertTrue(isinstance(test_data[input_param["keys"]], PILImageImage))
result = ToPILd(**input_param)(test_data)[input_param["keys"]]
self.assertTrue(isinstance(result, PILImageImage))
np.testing.assert_allclose(np.array(result), test_data[input_param["keys"]])


if __name__ == "__main__":
unittest.main()