From aff3502711ecc3205ba82bc04e7faa5c0457980d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 27 Jul 2021 21:00:09 +0800 Subject: [PATCH 01/12] [DLMED] init the transform Signed-off-by: Nic Ma --- monai/transforms/spatial/array.py | 69 +++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 37dd9b47c6..b0f0fd5baa 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1803,3 +1803,72 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]): # but user input is 1-based (because channel dim is 0) coord_channels = coord_channels[[s - 1 for s in self.spatial_channels]] return np.concatenate((img, coord_channels), axis=0) + + +class LongestRescale(Transform): + """ + Rescale an image so that maximum side is equal to specified spatial size, keeping the aspect ratio + of the initial image. Implemented using :py:class:`torch.nn.functional.interpolate`. + + Args: + spatial_size: expected shape of spatial dimensions after resize operation. + if the components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. Defaults to ``"area"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + align_corners: This only has an effect when mode is + 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + """ + + def __init__( + self, + spatial_size: Union[Sequence[int], int], + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + ) -> None: + self.spatial_size = ensure_tuple(spatial_size) + self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) + self.align_corners = align_corners + + def __call__( + self, + img: np.ndarray, + mode: Optional[Union[InterpolateMode, str]] = None, + align_corners: Optional[bool] = None, + ) -> np.ndarray: + """ + Args: + img: channel first array, must have shape: (num_channels, H[, W, ..., ]). + mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. Defaults to ``self.mode``. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + align_corners: This only has an effect when mode is + 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + + Raises: + ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. + + """ + input_ndim = img.ndim - 1 # spatial ndim + output_ndim = len(self.spatial_size) + if output_ndim > input_ndim: + input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) + img = img.reshape(input_shape) + elif output_ndim < input_ndim: + raise ValueError( + "len(spatial_size) must be greater or equal to img spatial dimensions, " + f"got spatial_size={output_ndim} img={input_ndim}." + ) + spatial_size = fall_back_tuple(self.spatial_size, img.shape[1:]) + resized = torch.nn.functional.interpolate( # type: ignore + input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0), + size=spatial_size, + mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, + align_corners=self.align_corners if align_corners is None else align_corners, + ) + resized = resized.squeeze(0).detach().cpu().numpy() + return np.asarray(resized) From b987d4e1015f3fcaeebe611798bc8275786cfa08 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 27 Jul 2021 23:23:41 +0800 Subject: [PATCH 02/12] [DLMED] update doc-string Signed-off-by: Nic Ma --- monai/transforms/spatial/array.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index b0f0fd5baa..9d356c9187 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1811,25 +1811,22 @@ class LongestRescale(Transform): of the initial image. Implemented using :py:class:`torch.nn.functional.interpolate`. Args: - spatial_size: expected shape of spatial dimensions after resize operation. - if the components of the `spatial_size` are non-positive values, the transform will use the - corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted - to `(32, 64)` if the second spatial dimension size of img is `64`. + spatial_size: expected spatial size of the longest side after rescale operation. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - """ + """ def __init__( self, - spatial_size: Union[Sequence[int], int], + spatial_size: int, mode: Union[InterpolateMode, str] = InterpolateMode.AREA, align_corners: Optional[bool] = None, ) -> None: - self.spatial_size = ensure_tuple(spatial_size) + self.spatial_size = spatial_size self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) self.align_corners = align_corners @@ -1849,9 +1846,6 @@ def __call__( 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - Raises: - ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. - """ input_ndim = img.ndim - 1 # spatial ndim output_ndim = len(self.spatial_size) From d35ea5dbb01296f8ce73a50413ed89d74af3fa97 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 28 Jul 2021 20:10:59 +0800 Subject: [PATCH 03/12] [DLMED] complete array transform Signed-off-by: Nic Ma --- monai/transforms/spatial/array.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index b05770e830..7f9e4d22fc 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -14,6 +14,7 @@ """ import warnings +from math import ceil from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np @@ -1847,22 +1848,13 @@ def __call__( See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate """ - input_ndim = img.ndim - 1 # spatial ndim - output_ndim = len(self.spatial_size) - if output_ndim > input_ndim: - input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) - img = img.reshape(input_shape) - elif output_ndim < input_ndim: - raise ValueError( - "len(spatial_size) must be greater or equal to img spatial dimensions, " - f"got spatial_size={output_ndim} img={input_ndim}." - ) - spatial_size = fall_back_tuple(self.spatial_size, img.shape[1:]) + img_size = img.shape[1:] + scale = self.spatial_size / max(img_size) + new_size = [ceil(s * scale) for s in img_size] resized = torch.nn.functional.interpolate( # type: ignore input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0), - size=spatial_size, + size=new_size, mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, align_corners=self.align_corners if align_corners is None else align_corners, ) - resized = resized.squeeze(0).detach().cpu().numpy() - return np.asarray(resized) + return resized.squeeze(0).detach().cpu().numpy() From 0c97bc8781f292806c695bad261aecd5f6ab7834 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 28 Jul 2021 20:29:40 +0800 Subject: [PATCH 04/12] [DLMED] add unit tests Signed-off-by: Nic Ma --- docs/source/transforms.rst | 6 ++++++ monai/transforms/__init__.py | 1 + monai/transforms/spatial/array.py | 3 +++ tests/test_longest_rescale.py | 36 +++++++++++++++++++++++++++++++ 4 files changed, 46 insertions(+) create mode 100644 tests/test_longest_rescale.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 962e1f3769..cbcd81d00b 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -504,6 +504,12 @@ Spatial :members: :special-members: __call__ +`LongestRescale` +"""""""""""""""" +.. autoclass:: LongestRescale + :members: + :special-members: __call__ + Utility ^^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 45eecd266c..6acaf520dd 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -235,6 +235,7 @@ Affine, AffineGrid, Flip, + LongestRescale, Orientation, Rand2DElastic, Rand3DElastic, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 7f9e4d22fc..41da9a0daf 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -72,6 +72,7 @@ "Rand2DElastic", "Rand3DElastic", "AddCoordinateChannels", + "LongestRescale", ] RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] @@ -1810,6 +1811,8 @@ class LongestRescale(Transform): """ Rescale an image so that maximum side is equal to specified spatial size, keeping the aspect ratio of the initial image. Implemented using :py:class:`torch.nn.functional.interpolate`. + Refer to: https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/ + #albumentations.augmentations.geometric.resize.LongestMaxSize. Args: spatial_size: expected spatial size of the longest side after rescale operation. diff --git a/tests/test_longest_rescale.py b/tests/test_longest_rescale.py new file mode 100644 index 0000000000..84878f8872 --- /dev/null +++ b/tests/test_longest_rescale.py @@ -0,0 +1,36 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import skimage.transform +from parameterized import parameterized + +from monai.transforms import LongestRescale + +TEST_CASE_0 = [{"spatial_size": 15}, (6, 11, 15)] + +TEST_CASE_1 = [{"spatial_size": 15, "mode": "area"}, (6, 11, 15)] + +TEST_CASE_2 = [{"spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)] + + +class TestLongestRescale(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) + def test_shape(self, input_param, expected_shape): + input_data = np.random.randint(0, 2, size=[3, 4, 7, 10]) + result = LongestRescale(**input_param)(input_data) + np.testing.assert_allclose(result.shape[1:], expected_shape) + + +if __name__ == "__main__": + unittest.main() From 853a2d1bafb6498ec5cd5108542faaf8a745b563 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 28 Jul 2021 20:56:19 +0800 Subject: [PATCH 05/12] [DLMED] add dict transform and inverse tests Signed-off-by: Nic Ma --- docs/source/transforms.rst | 6 +++ monai/transforms/__init__.py | 3 ++ monai/transforms/spatial/array.py | 1 + monai/transforms/spatial/dictionary.py | 68 ++++++++++++++++++++++++++ tests/test_inverse.py | 14 ++---- tests/test_longest_rescale.py | 1 - tests/test_longest_rescaled.py | 45 +++++++++++++++++ 7 files changed, 128 insertions(+), 10 deletions(-) create mode 100644 tests/test_longest_rescaled.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index cbcd81d00b..94bb992767 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1084,6 +1084,12 @@ Spatial (Dict) :members: :special-members: __call__ +`LongestRescaled` +""""""""""""""""" +.. autoclass:: LongestRescaled + :members: + :special-members: __call__ + Utility (Dict) ^^^^^^^^^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 6acaf520dd..5a4f0690be 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -264,6 +264,9 @@ Flipd, FlipD, FlipDict, + LongestRescaled, + LongestRescaleD, + LongestRescaleDict, Orientationd, OrientationD, OrientationDict, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 41da9a0daf..47b96f50df 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1824,6 +1824,7 @@ class LongestRescale(Transform): See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate """ + def __init__( self, spatial_size: int, diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 2c9cac8438..53975521ad 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -32,6 +32,7 @@ Affine, AffineGrid, Flip, + LongestRescale, Orientation, Rand2DElastic, Rand3DElastic, @@ -75,6 +76,7 @@ "RandRotated", "Zoomd", "RandZoomd", + "LongestRescaled", "SpacingD", "SpacingDict", "OrientationD", @@ -109,6 +111,8 @@ "RandZoomDict", "AddCoordinateChannelsD", "AddCoordinateChannelsDict", + "LongestRescaleD", + "LongestRescaleDict", ] GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] @@ -1699,6 +1703,69 @@ def __call__( return d +class LongestRescaled(MapTransform, InvertibleTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.LongestRescale`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + spatial_size: expected spatial size of the longest side after rescale operation. + mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. Defaults to ``"area"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + It also can be a sequence of string, each element corresponds to a key in ``keys``. + align_corners: This only has an effect when mode is + 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. + """ + + def __init__( + self, + keys: KeysCollection, + spatial_size: int, + mode: InterpolateModeSequence = InterpolateMode.AREA, + align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) + self.rescaler = LongestRescale(spatial_size=spatial_size) + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): + self.push_transform( + d, + key, + extra_info={ + "mode": mode.value if isinstance(mode, Enum) else mode, + "align_corners": align_corners if align_corners is not None else "none", + }, + ) + d[key] = self.rescaler(d[key], mode=mode, align_corners=align_corners) + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + orig_size = transform[InverseKeys.ORIG_SIZE] + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] + # Create inverse transform + inverse_transform = Resize(orig_size, mode, None if align_corners == "none" else align_corners) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + + return d + + SpacingD = SpacingDict = Spacingd OrientationD = OrientationDict = Orientationd Rotate90D = Rotate90Dict = Rotate90d @@ -1716,3 +1783,4 @@ def __call__( ZoomD = ZoomDict = Zoomd RandZoomD = RandZoomDict = RandZoomd AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd +LongestRescaleD = LongestRescaleDict = LongestRescaled diff --git a/tests/test_inverse.py b/tests/test_inverse.py index fd1afbd857..66105745d1 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -36,6 +36,7 @@ Flipd, InvertibleTransform, LoadImaged, + LongestRescaled, Orientationd, RandAffined, RandAxisFlipd, @@ -249,15 +250,6 @@ ) ) -TESTS.append( - ( - "Flipd 3d", - "3D", - 0, - Flipd(KEYS, [1, 2]), - ) -) - TESTS.append( ( "RandFlipd 3d", @@ -319,6 +311,10 @@ TESTS.append(("Resized 3d", "3D", 5e-2, Resized(KEYS, [201, 150, 78]))) +TESTS.append(("LongestRescaled 2d", "2D", 2e-1, LongestRescaled(KEYS, 47, "area"))) + +TESTS.append(("LongestRescaled 3d", "3D", 5e-2, LongestRescaled(KEYS, 201, "trilinear", True))) + TESTS.append( ( diff --git a/tests/test_longest_rescale.py b/tests/test_longest_rescale.py index 84878f8872..25aaf52f2f 100644 --- a/tests/test_longest_rescale.py +++ b/tests/test_longest_rescale.py @@ -12,7 +12,6 @@ import unittest import numpy as np -import skimage.transform from parameterized import parameterized from monai.transforms import LongestRescale diff --git a/tests/test_longest_rescaled.py b/tests/test_longest_rescaled.py new file mode 100644 index 0000000000..4f5673a22b --- /dev/null +++ b/tests/test_longest_rescaled.py @@ -0,0 +1,45 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import LongestRescaled + +TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 11, 15)] + +TEST_CASE_1 = [{"keys": "img", "spatial_size": 15, "mode": "area"}, (6, 11, 15)] + +TEST_CASE_2 = [{"keys": "img", "spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)] + +TEST_CASE_3 = [ + {"keys": ["img", "label"], "spatial_size": 6, "mode": ["trilinear", "nearest"], "align_corners": [True, None]}, + (3, 5, 6), +] + + +class TestLongestRescaled(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, input_param, expected_shape): + input_data = { + "img": np.random.randint(0, 2, size=[3, 4, 7, 10]), + "label": np.random.randint(0, 2, size=[3, 4, 7, 10]), + } + rescaler = LongestRescaled(**input_param) + result = rescaler(input_data) + for k in rescaler.keys: + np.testing.assert_allclose(result[k].shape[1:], expected_shape) + + +if __name__ == "__main__": + unittest.main() From 2b3ba8b966f6845423e1f34223c15cf3ae600b45 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 28 Jul 2021 20:58:32 +0800 Subject: [PATCH 06/12] [DLMED] fix mypy type Signed-off-by: Nic Ma --- monai/transforms/spatial/array.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 47b96f50df..bb61b666d5 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1861,4 +1861,5 @@ def __call__( mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, align_corners=self.align_corners if align_corners is None else align_corners, ) - return resized.squeeze(0).detach().cpu().numpy() + resized = resized.squeeze(0).detach().cpu().numpy() + return np.asarray(resized) From 7e608a3b36eb74373a53e934a9667ccaaf072d74 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 28 Jul 2021 23:16:09 +0800 Subject: [PATCH 07/12] [DLMED] change to enhance Resize transform Signed-off-by: Nic Ma --- docs/source/transforms.rst | 12 --- monai/transforms/__init__.py | 4 - monai/transforms/spatial/array.py | 101 +++++++------------------ monai/transforms/spatial/dictionary.py | 75 ++---------------- tests/test_resize.py | 13 ++++ tests/test_resized.py | 25 +++++- 6 files changed, 73 insertions(+), 157 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 94bb992767..962e1f3769 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -504,12 +504,6 @@ Spatial :members: :special-members: __call__ -`LongestRescale` -"""""""""""""""" -.. autoclass:: LongestRescale - :members: - :special-members: __call__ - Utility ^^^^^^^ @@ -1084,12 +1078,6 @@ Spatial (Dict) :members: :special-members: __call__ -`LongestRescaled` -""""""""""""""""" -.. autoclass:: LongestRescaled - :members: - :special-members: __call__ - Utility (Dict) ^^^^^^^^^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 5a4f0690be..45eecd266c 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -235,7 +235,6 @@ Affine, AffineGrid, Flip, - LongestRescale, Orientation, Rand2DElastic, Rand3DElastic, @@ -264,9 +263,6 @@ Flipd, FlipD, FlipDict, - LongestRescaled, - LongestRescaleD, - LongestRescaleDict, Orientationd, OrientationD, OrientationDict, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index bb61b666d5..4ef76fa1bf 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -72,7 +72,6 @@ "Rand2DElastic", "Rand3DElastic", "AddCoordinateChannels", - "LongestRescale", ] RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] @@ -342,6 +341,11 @@ class Resize(Transform): if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. + size_mode: should be "all" or "longest", if "all", will use `spatial_size` for all the spatial dims, + if "longest", rescale the image so that only the longest side is equal to specified `spatial_size`, + which must be an int number in this case, keeping the aspect ratio of the initial image, refer to: + https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/ + #albumentations.augmentations.geometric.resize.LongestMaxSize. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate @@ -353,10 +357,16 @@ class Resize(Transform): def __init__( self, spatial_size: Union[Sequence[int], int], + size_mode: str = "all", mode: Union[InterpolateMode, str] = InterpolateMode.AREA, align_corners: Optional[bool] = None, ) -> None: - self.spatial_size = ensure_tuple(spatial_size) + if size_mode not in ["all", "longest"]: + raise ValueError(f"size_mode must be 'all' or 'longest', but got: {size_mode}.") + if size_mode == "longest" and not isinstance(spatial_size, int): + raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") + self.spatial_size = spatial_size + self.size_mode = size_mode self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) self.align_corners = align_corners @@ -380,20 +390,25 @@ def __call__( ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. """ - input_ndim = img.ndim - 1 # spatial ndim - output_ndim = len(self.spatial_size) - if output_ndim > input_ndim: - input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) - img = img.reshape(input_shape) - elif output_ndim < input_ndim: - raise ValueError( - "len(spatial_size) must be greater or equal to img spatial dimensions, " - f"got spatial_size={output_ndim} img={input_ndim}." - ) - spatial_size = fall_back_tuple(self.spatial_size, img.shape[1:]) + if self.size_mode == "all": + input_ndim = img.ndim - 1 # spatial ndim + output_ndim = len(self.spatial_size) + if output_ndim > input_ndim: + input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) + img = img.reshape(input_shape) + elif output_ndim < input_ndim: + raise ValueError( + "len(spatial_size) must be greater or equal to img spatial dimensions, " + f"got spatial_size={output_ndim} img={input_ndim}." + ) + spatial_size_ = fall_back_tuple(self.spatial_size, img.shape[1:]) + else: # for the "longest" mode + img_size = img.shape[1:] + scale = self.spatial_size / max(img_size) + spatial_size_ = [ceil(s * scale) for s in img_size] resized = torch.nn.functional.interpolate( # type: ignore input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0), - size=spatial_size, + size=spatial_size_, mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, align_corners=self.align_corners if align_corners is None else align_corners, ) @@ -1805,61 +1820,3 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]): # but user input is 1-based (because channel dim is 0) coord_channels = coord_channels[[s - 1 for s in self.spatial_channels]] return np.concatenate((img, coord_channels), axis=0) - - -class LongestRescale(Transform): - """ - Rescale an image so that maximum side is equal to specified spatial size, keeping the aspect ratio - of the initial image. Implemented using :py:class:`torch.nn.functional.interpolate`. - Refer to: https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/ - #albumentations.augmentations.geometric.resize.LongestMaxSize. - - Args: - spatial_size: expected spatial size of the longest side after rescale operation. - mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - align_corners: This only has an effect when mode is - 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - - """ - - def __init__( - self, - spatial_size: int, - mode: Union[InterpolateMode, str] = InterpolateMode.AREA, - align_corners: Optional[bool] = None, - ) -> None: - self.spatial_size = spatial_size - self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) - self.align_corners = align_corners - - def __call__( - self, - img: np.ndarray, - mode: Optional[Union[InterpolateMode, str]] = None, - align_corners: Optional[bool] = None, - ) -> np.ndarray: - """ - Args: - img: channel first array, must have shape: (num_channels, H[, W, ..., ]). - mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - align_corners: This only has an effect when mode is - 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - - """ - img_size = img.shape[1:] - scale = self.spatial_size / max(img_size) - new_size = [ceil(s * scale) for s in img_size] - resized = torch.nn.functional.interpolate( # type: ignore - input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0), - size=new_size, - mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, - align_corners=self.align_corners if align_corners is None else align_corners, - ) - resized = resized.squeeze(0).detach().cpu().numpy() - return np.asarray(resized) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 53975521ad..05a04da6e3 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -76,7 +76,6 @@ "RandRotated", "Zoomd", "RandZoomd", - "LongestRescaled", "SpacingD", "SpacingDict", "OrientationD", @@ -111,8 +110,6 @@ "RandZoomDict", "AddCoordinateChannelsD", "AddCoordinateChannelsDict", - "LongestRescaleD", - "LongestRescaleDict", ] GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] @@ -507,6 +504,11 @@ class Resized(MapTransform, InvertibleTransform): if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. + size_mode: should be "all" or "longest", if "all", will use `spatial_size` for all the spatial dims, + if "longest", rescale the image so that only the longest side is equal to specified `spatial_size`, + which must be an int number in this case, keeping the aspect ratio of the initial image, refer to: + https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/ + #albumentations.augmentations.geometric.resize.LongestMaxSize. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate @@ -522,6 +524,7 @@ def __init__( self, keys: KeysCollection, spatial_size: Union[Sequence[int], int], + size_mode: str = "all", mode: InterpolateModeSequence = InterpolateMode.AREA, align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, allow_missing_keys: bool = False, @@ -529,7 +532,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) - self.resizer = Resize(spatial_size=spatial_size) + self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) @@ -1703,69 +1706,6 @@ def __call__( return d -class LongestRescaled(MapTransform, InvertibleTransform): - """ - Dictionary-based wrapper of :py:class:`monai.transforms.LongestRescale`. - - Args: - keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` - spatial_size: expected spatial size of the longest side after rescale operation. - mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - It also can be a sequence of string, each element corresponds to a key in ``keys``. - align_corners: This only has an effect when mode is - 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. - allow_missing_keys: don't raise exception if key is missing. - """ - - def __init__( - self, - keys: KeysCollection, - spatial_size: int, - mode: InterpolateModeSequence = InterpolateMode.AREA, - align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, - allow_missing_keys: bool = False, - ) -> None: - super().__init__(keys, allow_missing_keys) - self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) - self.rescaler = LongestRescale(spatial_size=spatial_size) - - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d = dict(data) - for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): - self.push_transform( - d, - key, - extra_info={ - "mode": mode.value if isinstance(mode, Enum) else mode, - "align_corners": align_corners if align_corners is not None else "none", - }, - ) - d[key] = self.rescaler(d[key], mode=mode, align_corners=align_corners) - return d - - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - orig_size = transform[InverseKeys.ORIG_SIZE] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] - # Create inverse transform - inverse_transform = Resize(orig_size, mode, None if align_corners == "none" else align_corners) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d - - SpacingD = SpacingDict = Spacingd OrientationD = OrientationDict = Orientationd Rotate90D = Rotate90Dict = Rotate90d @@ -1783,4 +1723,3 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ZoomD = ZoomDict = Zoomd RandZoomD = RandZoomDict = RandZoomd AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd -LongestRescaleD = LongestRescaleDict = LongestRescaled diff --git a/tests/test_resize.py b/tests/test_resize.py index 22a68bcf85..2f54dcc04f 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -18,6 +18,12 @@ from monai.transforms import Resize from tests.utils import NumpyImageTestCase2D +TEST_CASE_0 = [{"spatial_size": 15}, (6, 11, 15)] + +TEST_CASE_1 = [{"spatial_size": 15, "mode": "area"}, (6, 11, 15)] + +TEST_CASE_2 = [{"spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)] + class TestResize(NumpyImageTestCase2D): def test_invalid_inputs(self): @@ -50,6 +56,13 @@ def test_correct_results(self, spatial_size, mode): out = resize(self.imt[0]) np.testing.assert_allclose(out, expected, atol=0.9) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) + def test_longest_shape(self, input_param, expected_shape): + input_data = np.random.randint(0, 2, size=[3, 4, 7, 10]) + input_param["size_mode"] = "longest" + result = Resize(**input_param)(input_data) + np.testing.assert_allclose(result.shape[1:], expected_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_resized.py b/tests/test_resized.py index d89c866af3..6c4f31c9c8 100644 --- a/tests/test_resized.py +++ b/tests/test_resized.py @@ -18,6 +18,17 @@ from monai.transforms import Resized from tests.utils import NumpyImageTestCase2D +TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 11, 15)] + +TEST_CASE_1 = [{"keys": "img", "spatial_size": 15, "mode": "area"}, (6, 11, 15)] + +TEST_CASE_2 = [{"keys": "img", "spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)] + +TEST_CASE_3 = [ + {"keys": ["img", "label"], "spatial_size": 6, "mode": ["trilinear", "nearest"], "align_corners": [True, None]}, + (3, 5, 6), +] + class TestResized(NumpyImageTestCase2D): def test_invalid_inputs(self): @@ -31,7 +42,7 @@ def test_invalid_inputs(self): @parameterized.expand([((32, -1), "area"), ((64, 64), "area"), ((32, 32, 32), "area"), ((256, 256), "bilinear")]) def test_correct_results(self, spatial_size, mode): - resize = Resized("img", spatial_size, mode) + resize = Resized("img", spatial_size, mode=mode) _order = 0 if mode.endswith("linear"): _order = 1 @@ -48,6 +59,18 @@ def test_correct_results(self, spatial_size, mode): out = resize({"img": self.imt[0]})["img"] np.testing.assert_allclose(out, expected, atol=0.9) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_longest_shape(self, input_param, expected_shape): + input_data = { + "img": np.random.randint(0, 2, size=[3, 4, 7, 10]), + "label": np.random.randint(0, 2, size=[3, 4, 7, 10]), + } + input_param["size_mode"] = "longest" + rescaler = Resized(**input_param) + result = rescaler(input_data) + for k in rescaler.keys: + np.testing.assert_allclose(result[k].shape[1:], expected_shape) + if __name__ == "__main__": unittest.main() From 0d2aed0b2a169739ae389afd78691d7781cf08c4 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 28 Jul 2021 23:19:43 +0800 Subject: [PATCH 08/12] [DLMED] fix CI tests Signed-off-by: Nic Ma --- monai/transforms/spatial/dictionary.py | 1 - tests/test_inverse.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 05a04da6e3..c034859cbc 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -32,7 +32,6 @@ Affine, AffineGrid, Flip, - LongestRescale, Orientation, Rand2DElastic, Rand3DElastic, diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 66105745d1..a1c171200f 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -36,7 +36,6 @@ Flipd, InvertibleTransform, LoadImaged, - LongestRescaled, Orientationd, RandAffined, RandAxisFlipd, @@ -311,9 +310,9 @@ TESTS.append(("Resized 3d", "3D", 5e-2, Resized(KEYS, [201, 150, 78]))) -TESTS.append(("LongestRescaled 2d", "2D", 2e-1, LongestRescaled(KEYS, 47, "area"))) +TESTS.append(("Resized longest 2d", "2D", 2e-1, Resized(KEYS, 47, "longest", "area"))) -TESTS.append(("LongestRescaled 3d", "3D", 5e-2, LongestRescaled(KEYS, 201, "trilinear", True))) +TESTS.append(("Resized longest 3d", "3D", 5e-2, Resized(KEYS, 201, "longest", "trilinear", True))) TESTS.append( From 63873cc1afba55b96d360d2b1f09cdfe12270cfc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 28 Jul 2021 23:25:12 +0800 Subject: [PATCH 09/12] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/spatial/array.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 4ef76fa1bf..bfe05ac667 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -361,12 +361,10 @@ def __init__( mode: Union[InterpolateMode, str] = InterpolateMode.AREA, align_corners: Optional[bool] = None, ) -> None: - if size_mode not in ["all", "longest"]: - raise ValueError(f"size_mode must be 'all' or 'longest', but got: {size_mode}.") + self.size_mode = look_up_option(size_mode, ["all", "longest"]) if size_mode == "longest" and not isinstance(spatial_size, int): raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") self.spatial_size = spatial_size - self.size_mode = size_mode self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) self.align_corners = align_corners @@ -392,7 +390,7 @@ def __call__( """ if self.size_mode == "all": input_ndim = img.ndim - 1 # spatial ndim - output_ndim = len(self.spatial_size) + output_ndim = len(ensure_tuple(self.spatial_size)) if output_ndim > input_ndim: input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) img = img.reshape(input_shape) From c06a8f2dbcde9efdb999b2c9a63f39ac9253cac7 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 28 Jul 2021 23:33:19 +0800 Subject: [PATCH 10/12] [DLMED] fix TTA Signed-off-by: Nic Ma --- monai/transforms/spatial/dictionary.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index c034859cbc..0d65fdfa29 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -555,7 +555,11 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar mode = transform[InverseKeys.EXTRA_INFO]["mode"] align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] # Create inverse transform - inverse_transform = Resize(orig_size, mode, None if align_corners == "none" else align_corners) + inverse_transform = Resize( + spatial_size=orig_size, + mode=mode, + align_corners=None if align_corners == "none" else align_corners, + ) # Apply inverse transform d[key] = inverse_transform(d[key]) # Remove the applied transform From ee5d4620ddc97fa10e46bad01837c7dc0e1b5470 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 28 Jul 2021 23:34:55 +0800 Subject: [PATCH 11/12] [DLMED] remove tests Signed-off-by: Nic Ma --- tests/test_longest_rescale.py | 35 -------------------------- tests/test_longest_rescaled.py | 45 ---------------------------------- 2 files changed, 80 deletions(-) delete mode 100644 tests/test_longest_rescale.py delete mode 100644 tests/test_longest_rescaled.py diff --git a/tests/test_longest_rescale.py b/tests/test_longest_rescale.py deleted file mode 100644 index 25aaf52f2f..0000000000 --- a/tests/test_longest_rescale.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from parameterized import parameterized - -from monai.transforms import LongestRescale - -TEST_CASE_0 = [{"spatial_size": 15}, (6, 11, 15)] - -TEST_CASE_1 = [{"spatial_size": 15, "mode": "area"}, (6, 11, 15)] - -TEST_CASE_2 = [{"spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)] - - -class TestLongestRescale(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_param, expected_shape): - input_data = np.random.randint(0, 2, size=[3, 4, 7, 10]) - result = LongestRescale(**input_param)(input_data) - np.testing.assert_allclose(result.shape[1:], expected_shape) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_longest_rescaled.py b/tests/test_longest_rescaled.py deleted file mode 100644 index 4f5673a22b..0000000000 --- a/tests/test_longest_rescaled.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from parameterized import parameterized - -from monai.transforms import LongestRescaled - -TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 11, 15)] - -TEST_CASE_1 = [{"keys": "img", "spatial_size": 15, "mode": "area"}, (6, 11, 15)] - -TEST_CASE_2 = [{"keys": "img", "spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)] - -TEST_CASE_3 = [ - {"keys": ["img", "label"], "spatial_size": 6, "mode": ["trilinear", "nearest"], "align_corners": [True, None]}, - (3, 5, 6), -] - - -class TestLongestRescaled(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): - input_data = { - "img": np.random.randint(0, 2, size=[3, 4, 7, 10]), - "label": np.random.randint(0, 2, size=[3, 4, 7, 10]), - } - rescaler = LongestRescaled(**input_param) - result = rescaler(input_data) - for k in rescaler.keys: - np.testing.assert_allclose(result[k].shape[1:], expected_shape) - - -if __name__ == "__main__": - unittest.main() From ae39fe6557cae252ca898a7e377f7ff94adf886c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 28 Jul 2021 23:41:56 +0800 Subject: [PATCH 12/12] [DLMED] fix mypy error Signed-off-by: Nic Ma --- monai/transforms/spatial/array.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index bfe05ac667..d9c10cf9c0 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -362,8 +362,6 @@ def __init__( align_corners: Optional[bool] = None, ) -> None: self.size_mode = look_up_option(size_mode, ["all", "longest"]) - if size_mode == "longest" and not isinstance(spatial_size, int): - raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") self.spatial_size = spatial_size self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) self.align_corners = align_corners @@ -402,8 +400,10 @@ def __call__( spatial_size_ = fall_back_tuple(self.spatial_size, img.shape[1:]) else: # for the "longest" mode img_size = img.shape[1:] + if not isinstance(self.spatial_size, int): + raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") scale = self.spatial_size / max(img_size) - spatial_size_ = [ceil(s * scale) for s in img_size] + spatial_size_ = tuple(ceil(s * scale) for s in img_size) resized = torch.nn.functional.interpolate( # type: ignore input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0), size=spatial_size_,