diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index f7e075f376..4fad271109 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -504,6 +504,12 @@ Utility
     :members:
     :special-members: __call__
 
+`TorchVision`
+"""""""""""""
+.. autoclass:: TorchVision
+    :members:
+    :special-members: __call__
+
 Dictionary Transforms
 ---------------------
 
@@ -969,6 +975,12 @@ Utility (Dict)
     :members:
     :special-members: __call__
 
+`TorchVisiond`
+""""""""""""""
+.. autoclass:: TorchVisiond
+    :members:
+    :special-members: __call__
+
 Transform Adaptors
 ------------------
 .. automodule:: monai.transforms.adaptors
diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index cc9be79abd..c7b4c67488 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -209,6 +209,7 @@
     SplitChannel,
     SqueezeDim,
     ToNumpy,
+    TorchVision,
     ToTensor,
     Transpose,
 )
@@ -233,6 +234,7 @@
     SplitChanneld,
     SqueezeDimd,
     ToNumpyd,
+    TorchVisiond,
     ToTensord,
 )
 from .utils import (
diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index 4268df1e25..7e7fe816a9 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -22,7 +22,7 @@
 
 from monai.transforms.compose import Randomizable, Transform
 from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices
-from monai.utils import ensure_tuple
+from monai.utils import ensure_tuple, min_version, optional_import
 
 __all__ = [
     "Identity",
@@ -42,6 +42,7 @@
     "LabelToMask",
     "FgBgToIndices",
     "AddExtremePointsChannel",
+    "TorchVision",
 ]
 
 # Generic type which can represent either a numpy.ndarray or a torch.Tensor
@@ -615,3 +616,32 @@ def __call__(
         )
 
         return np.concatenate([img, points_image], axis=0)
+
+
+class TorchVision:
+    """
+    Wrapper that builds a TorchVision transform from the specified transform name and args, then applies it.
+    As most of the TorchVision transforms only work for PIL images and PyTorch Tensors, this transform expects
+    the input data to be a PyTorch Tensor; use the `ToTensor` transform to convert a Numpy array to a Tensor.
+
+    """
+
+    def __init__(self, name: str, *args, **kwargs) -> None:
+        """
+        Args:
+            name: name of the transform to look up in `torchvision.transforms`.
+            args: positional arguments passed to the TorchVision transform constructor.
+            kwargs: keyword arguments passed to the TorchVision transform constructor.
+
+        """
+        super().__init__()
+        transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name)
+        self.trans = transform(*args, **kwargs)
+
+    def __call__(self, img: torch.Tensor):
+        """
+        Args:
+            img: PyTorch Tensor data to pass through the wrapped TorchVision transform.
+
+        """
+        return self.trans(img)
diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index e822d5a289..5c08f72c92 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -39,6 +39,7 @@
     SplitChannel,
     SqueezeDim,
     ToNumpy,
+    TorchVision,
     ToTensor,
 )
 from monai.transforms.utils import extreme_points_to_image, get_extreme_points
@@ -66,6 +67,7 @@
     "FgBgToIndicesd",
     "ConvertToMultiChannelBasedOnBratsClassesd",
     "AddExtremePointsChanneld",
+    "TorchVisiond",
 ]
 
 
@@ -732,6 +734,33 @@ def __call__(self, data):
         return d
 
 
+class TorchVisiond(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.TorchVision`.
+    As most of the TorchVision transforms only work for PIL images and PyTorch Tensors, this transform expects
+    the input data to be a dict of PyTorch Tensors; use the `ToTensord` transform to convert Numpy to Tensor.
+    """
+
+    def __init__(self, keys: KeysCollection, name: str, *args, **kwargs) -> None:
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+                See also: :py:class:`monai.transforms.compose.MapTransform`
+            name: name of the transform, forwarded to :py:class:`monai.transforms.TorchVision`.
+            args: positional arguments forwarded to :py:class:`monai.transforms.TorchVision`.
+            kwargs: keyword arguments forwarded to :py:class:`monai.transforms.TorchVision`.
+
+        """
+        super().__init__(keys)
+        self.trans = TorchVision(name, *args, **kwargs)
+
+    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
+        d = dict(data)
+        for key in self.keys:
+            d[key] = self.trans(d[key])
+        return d
+
+
 IdentityD = IdentityDict = Identityd
 AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd
 AsChannelLastD = AsChannelLastDict = AsChannelLastd
@@ -753,3 +782,4 @@ def __call__(self, data):
     ConvertToMultiChannelBasedOnBratsClassesDict
 ) = ConvertToMultiChannelBasedOnBratsClassesd
 AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
+TorchVisionD = TorchVisionDict = TorchVisiond
diff --git a/tests/min_tests.py b/tests/min_tests.py
index 32e61ed6f7..0e2e8d3917 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -102,6 +102,8 @@ def run_testsuit():
         "test_zoom_affine",
         "test_zoomd",
         "test_occlusion_sensitivity",
+        "test_torchvision",
+        "test_torchvisiond",
     ]
     assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
 
diff --git a/tests/test_torchvision.py b/tests/test_torchvision.py
new file mode 100644
index 0000000000..0846b7f6b6
--- /dev/null
+++ b/tests/test_torchvision.py
@@ -0,0 +1,86 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.transforms import TorchVision
+from monai.utils import set_determinism
+from tests.utils import SkipIfBeforePyTorchVersion
+
+TEST_CASE_1 = [
+    {"name": "ColorJitter"},
+    torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),
+    torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),
+]
+
+TEST_CASE_2 = [
+    {"name": "ColorJitter", "brightness": 0.5, "contrast": 0.5, "saturation": [0.1, 0.8], "hue": 0.5},
+    torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),
+    torch.tensor(
+        [
+            [
+                [0.1090, 0.6193],
+                [0.6193, 0.9164],
+            ],
+            [
+                [0.1090, 0.6193],
+                [0.6193, 0.9164],
+            ],
+            [
+                [0.1090, 0.6193],
+                [0.6193, 0.9164],
+            ],
+        ],
+    ),
+]
+
+TEST_CASE_3 = [
+    {"name": "Pad", "padding": [1, 1, 1, 1]},
+    torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),
+    torch.tensor(
+        [
+            [
+                [0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 1.0, 0.0],
+                [0.0, 1.0, 2.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0],
+            ],
+            [
+                [0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 1.0, 0.0],
+                [0.0, 1.0, 2.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0],
+            ],
+            [
+                [0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 1.0, 0.0],
+                [0.0, 1.0, 2.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0],
+            ],
+        ]
+    ),
+]
+
+
+@SkipIfBeforePyTorchVersion((1, 7))
+class TestTorchVision(unittest.TestCase):
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
+    def test_value(self, input_param, input_data, expected_value):
+        set_determinism(seed=0)
+        result = TorchVision(**input_param)(input_data)
+        torch.testing.assert_allclose(result, expected_value)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_torchvisiond.py b/tests/test_torchvisiond.py
new file mode 100644
index 0000000000..4f42bc95f7
--- /dev/null
+++ b/tests/test_torchvisiond.py
@@ -0,0 +1,86 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.transforms import TorchVisiond
+from monai.utils import set_determinism
+from tests.utils import SkipIfBeforePyTorchVersion
+
+TEST_CASE_1 = [
+    {"keys": "img", "name": "ColorJitter"},
+    {"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])},
+    torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),
+]
+
+TEST_CASE_2 = [
+    {"keys": "img", "name": "ColorJitter", "brightness": 0.5, "contrast": 0.5, "saturation": [0.1, 0.8], "hue": 0.5},
+    {"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])},
+    torch.tensor(
+        [
+            [
+                [0.1090, 0.6193],
+                [0.6193, 0.9164],
+            ],
+            [
+                [0.1090, 0.6193],
+                [0.6193, 0.9164],
+            ],
+            [
+                [0.1090, 0.6193],
+                [0.6193, 0.9164],
+            ],
+        ],
+    ),
+]
+
+TEST_CASE_3 = [
+    {"keys": "img", "name": "Pad", "padding": [1, 1, 1, 1]},
+    {"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])},
+    torch.tensor(
+        [
+            [
+                [0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 1.0, 0.0],
+                [0.0, 1.0, 2.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0],
+            ],
+            [
+                [0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 1.0, 0.0],
+                [0.0, 1.0, 2.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0],
+            ],
+            [
+                [0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 1.0, 0.0],
+                [0.0, 1.0, 2.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0],
+            ],
+        ]
+    ),
+]
+
+
+@SkipIfBeforePyTorchVersion((1, 7))
+class TestTorchVisiond(unittest.TestCase):
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
+    def test_value(self, input_param, input_data, expected_value):
+        set_determinism(seed=0)
+        result = TorchVisiond(**input_param)(input_data)
+        torch.testing.assert_allclose(result["img"], expected_value)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/utils.py b/tests/utils.py
index efedcdc859..20de17bbff 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -121,7 +121,7 @@ def __call__(self, obj):
 
 class SkipIfAtLeastPyTorchVersion(object):
     """Decorator to be used if test should be skipped
-    with PyTorch versions older than that given."""
+    with PyTorch versions equal to or newer than that given."""
 
     def __init__(self, pytorch_version_tuple):
         self.max_version = pytorch_version_tuple