Skip to content
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,12 @@ Utility
:members:
:special-members: __call__

`TorchVision`
"""""""""""""
.. autoclass:: TorchVision
:members:
:special-members: __call__

Dictionary Transforms
---------------------

Expand Down Expand Up @@ -969,6 +975,12 @@ Utility (Dict)
:members:
:special-members: __call__

`TorchVisiond`
""""""""""""""
.. autoclass:: TorchVisiond
:members:
:special-members: __call__

Transform Adaptors
------------------
.. automodule:: monai.transforms.adaptors
Expand Down
2 changes: 2 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@
SplitChannel,
SqueezeDim,
ToNumpy,
TorchVision,
ToTensor,
Transpose,
)
Expand All @@ -233,6 +234,7 @@
SplitChanneld,
SqueezeDimd,
ToNumpyd,
TorchVisiond,
ToTensord,
)
from .utils import (
Expand Down
32 changes: 31 additions & 1 deletion monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from monai.transforms.compose import Randomizable, Transform
from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices
from monai.utils import ensure_tuple
from monai.utils import ensure_tuple, min_version, optional_import

__all__ = [
"Identity",
Expand All @@ -42,6 +42,7 @@
"LabelToMask",
"FgBgToIndices",
"AddExtremePointsChannel",
"TorchVision",
]

# Generic type which can represent either a numpy.ndarray or a torch.Tensor
Expand Down Expand Up @@ -615,3 +616,32 @@ def __call__(
)

return np.concatenate([img, points_image], axis=0)


class TorchVision:
    """
    Wrapper that exposes a TorchVision transform, selected by its name, as a callable transform.

    The named transform is fetched from ``torchvision.transforms`` and instantiated once, when this
    wrapper is constructed, using the given positional and keyword arguments. Since most TorchVision
    transforms only handle PIL images and PyTorch Tensors, the input data is expected to be a PyTorch
    Tensor; a Numpy array can be converted to a Tensor beforehand with the `ToTensor` transform.

    """

    def __init__(self, name: str, *args, **kwargs) -> None:
        """
        Args:
            name: name of the transform to fetch from the ``torchvision.transforms`` module.
            args: positional arguments forwarded to the TorchVision transform constructor.
            kwargs: keyword arguments forwarded to the TorchVision transform constructor.

        """
        super().__init__()
        # look the transform up by name, passing "0.8.0" with `min_version` as the version requirement;
        # the second item returned by `optional_import` is not used here
        transform_cls, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name)
        self.trans = transform_cls(*args, **kwargs)

    def __call__(self, img: torch.Tensor):
        """
        Run the wrapped TorchVision transform on the input and return its result.

        Args:
            img: PyTorch Tensor data handed to the TorchVision transform.

        """
        transformed = self.trans(img)
        return transformed
30 changes: 30 additions & 0 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
SplitChannel,
SqueezeDim,
ToNumpy,
TorchVision,
ToTensor,
)
from monai.transforms.utils import extreme_points_to_image, get_extreme_points
Expand Down Expand Up @@ -66,6 +67,7 @@
"FgBgToIndicesd",
"ConvertToMultiChannelBasedOnBratsClassesd",
"AddExtremePointsChanneld",
"TorchVisiond",
]


Expand Down Expand Up @@ -732,6 +734,33 @@ def __call__(self, data):
return d


class TorchVisiond(MapTransform):
    """
    Dictionary-based version of :py:class:`monai.transforms.TorchVision`.

    Since most TorchVision transforms only handle PIL images and PyTorch Tensors, the values stored under
    the selected keys are expected to be PyTorch Tensors; Numpy data can be converted beforehand with the
    `ToTensord` transform.
    """

    def __init__(self, keys: KeysCollection, name: str, *args, **kwargs) -> None:
        """
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            name: name of the transform in the TorchVision package.
            args: positional arguments forwarded to the TorchVision transform.
            kwargs: keyword arguments forwarded to the TorchVision transform.

        """
        super().__init__(keys)
        # a single array-based wrapper instance is shared by all keys
        self.trans = TorchVision(name, *args, **kwargs)

    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
        """
        Apply the wrapped transform to every configured key and return the updated dictionary.

        Args:
            data: mapping that holds the items to transform under ``self.keys``.

        """
        # work on a new dict so the entries of the incoming mapping are not reassigned
        result = dict(data)
        apply = self.trans
        for k in self.keys:
            result[k] = apply(result[k])
        return result


IdentityD = IdentityDict = Identityd
AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd
AsChannelLastD = AsChannelLastDict = AsChannelLastd
Expand All @@ -753,3 +782,4 @@ def __call__(self, data):
ConvertToMultiChannelBasedOnBratsClassesDict
) = ConvertToMultiChannelBasedOnBratsClassesd
AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
# `D` / `Dict` suffixed aliases of `TorchVisiond`, matching the alias pattern of the transforms above
TorchVisionD = TorchVisionDict = TorchVisiond
2 changes: 2 additions & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def run_testsuit():
"test_zoom_affine",
"test_zoomd",
"test_occlusion_sensitivity",
"test_torchvision",
"test_torchvisiond",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

Expand Down
86 changes: 86 additions & 0 deletions tests/test_torchvision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.transforms import TorchVision
from monai.utils import set_determinism
from tests.utils import SkipIfBeforePyTorchVersion

# Each test case is a list of [constructor kwargs for TorchVision, input tensor, expected output tensor].

# ColorJitter constructed with no extra arguments: the expected output is identical to the input.
TEST_CASE_1 = [
{"name": "ColorJitter"},
torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),
torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),
]

# ColorJitter with explicit brightness/contrast/saturation/hue settings; the expected values
# go together with the fixed seed (seed=0) that the test method sets before running the transform.
TEST_CASE_2 = [
{"name": "ColorJitter", "brightness": 0.5, "contrast": 0.5, "saturation": [0.1, 0.8], "hue": 0.5},
torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),
torch.tensor(
[
[
[0.1090, 0.6193],
[0.6193, 0.9164],
],
[
[0.1090, 0.6193],
[0.6193, 0.9164],
],
[
[0.1090, 0.6193],
[0.6193, 0.9164],
],
],
),
]

# Pad with padding [1, 1, 1, 1]: in the expected output each 2x2 channel of the input sits in the
# centre of a 4x4 channel whose border is all zeros.
TEST_CASE_3 = [
{"name": "Pad", "padding": [1, 1, 1, 1]},
torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),
torch.tensor(
[
[
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 1.0, 2.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
],
[
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 1.0, 2.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
],
[
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 1.0, 2.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
],
]
),
]


@SkipIfBeforePyTorchVersion((1, 7))
class TestTorchVision(unittest.TestCase):
    """Value checks for the array-based ``TorchVision`` wrapper transform."""

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
    def test_value(self, input_param, input_data, expected_value):
        """
        Build ``TorchVision`` from ``input_param``, apply it to ``input_data`` and compare the
        result with ``expected_value``.

        Args:
            input_param: keyword arguments for the ``TorchVision`` constructor.
            input_data: tensor passed to the transform.
            expected_value: tensor the output is compared against.

        """
        # the expected values of the test cases go together with this fixed seed
        set_determinism(seed=0)
        try:
            result = TorchVision(**input_param)(input_data)
            torch.testing.assert_allclose(result, expected_value)
        finally:
            # undo the global seeding even when the assertion fails, so that tests running
            # afterwards in the same process are not affected by the seed set above
            set_determinism(seed=None)


if __name__ == "__main__":
    unittest.main()
86 changes: 86 additions & 0 deletions tests/test_torchvisiond.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.transforms import TorchVisiond
from monai.utils import set_determinism
from tests.utils import SkipIfBeforePyTorchVersion

# Each test case is a list of [constructor kwargs for TorchVisiond, input dictionary,
# expected tensor for the "img" entry of the output].

# ColorJitter constructed with no extra arguments: the expected "img" output is identical to the input.
TEST_CASE_1 = [
{"keys": "img", "name": "ColorJitter"},
{"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])},
torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),
]

# ColorJitter with explicit brightness/contrast/saturation/hue settings; the expected values
# go together with the fixed seed (seed=0) that the test method sets before running the transform.
TEST_CASE_2 = [
{"keys": "img", "name": "ColorJitter", "brightness": 0.5, "contrast": 0.5, "saturation": [0.1, 0.8], "hue": 0.5},
{"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])},
torch.tensor(
[
[
[0.1090, 0.6193],
[0.6193, 0.9164],
],
[
[0.1090, 0.6193],
[0.6193, 0.9164],
],
[
[0.1090, 0.6193],
[0.6193, 0.9164],
],
],
),
]

# Pad with padding [1, 1, 1, 1]: in the expected output each 2x2 channel of the input sits in the
# centre of a 4x4 channel whose border is all zeros.
TEST_CASE_3 = [
{"keys": "img", "name": "Pad", "padding": [1, 1, 1, 1]},
{"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])},
torch.tensor(
[
[
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 1.0, 2.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
],
[
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 1.0, 2.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
],
[
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 1.0, 2.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
],
]
),
]


@SkipIfBeforePyTorchVersion((1, 7))
class TestTorchVisiond(unittest.TestCase):
    """Value checks for the dictionary-based ``TorchVisiond`` wrapper transform."""

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
    def test_value(self, input_param, input_data, expected_value):
        """
        Build ``TorchVisiond`` from ``input_param``, apply it to ``input_data`` and compare the
        ``"img"`` entry of the result with ``expected_value``.

        Args:
            input_param: keyword arguments for the ``TorchVisiond`` constructor.
            input_data: dictionary passed to the transform.
            expected_value: tensor the ``"img"`` output is compared against.

        """
        # the expected values of the test cases go together with this fixed seed
        set_determinism(seed=0)
        try:
            result = TorchVisiond(**input_param)(input_data)
            torch.testing.assert_allclose(result["img"], expected_value)
        finally:
            # undo the global seeding even when the assertion fails, so that tests running
            # afterwards in the same process are not affected by the seed set above
            set_determinism(seed=None)


if __name__ == "__main__":
    unittest.main()
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __call__(self, obj):

class SkipIfAtLeastPyTorchVersion(object):
"""Decorator to be used if test should be skipped
with PyTorch versions older than that given."""
with PyTorch versions newer than that given."""

def __init__(self, pytorch_version_tuple):
self.max_version = pytorch_version_tuple
Expand Down