diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 06b98cdd2e..d9c10cf9c0 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -14,6 +14,7 @@ """ import warnings +from math import ceil from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np @@ -340,6 +341,11 @@ class Resize(Transform): if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. + size_mode: should be "all" or "longest", if "all", will use `spatial_size` for all the spatial dims, + if "longest", rescale the image so that only the longest side is equal to specified `spatial_size`, + which must be an int number in this case, keeping the aspect ratio of the initial image, refer to: + https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/ + #albumentations.augmentations.geometric.resize.LongestMaxSize. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate @@ -351,10 +357,12 @@ class Resize(Transform): def __init__( self, spatial_size: Union[Sequence[int], int], + size_mode: str = "all", mode: Union[InterpolateMode, str] = InterpolateMode.AREA, align_corners: Optional[bool] = None, ) -> None: - self.spatial_size = ensure_tuple(spatial_size) + self.size_mode = look_up_option(size_mode, ["all", "longest"]) + self.spatial_size = spatial_size self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) self.align_corners = align_corners @@ -378,20 +386,27 @@ def __call__( ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. """ - input_ndim = img.ndim - 1 # spatial ndim - output_ndim = len(self.spatial_size) - if output_ndim > input_ndim: - input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) - img = img.reshape(input_shape) - elif output_ndim < input_ndim: - raise ValueError( - "len(spatial_size) must be greater or equal to img spatial dimensions, " - f"got spatial_size={output_ndim} img={input_ndim}." - ) - spatial_size = fall_back_tuple(self.spatial_size, img.shape[1:]) + if self.size_mode == "all": + input_ndim = img.ndim - 1 # spatial ndim + output_ndim = len(ensure_tuple(self.spatial_size)) + if output_ndim > input_ndim: + input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) + img = img.reshape(input_shape) + elif output_ndim < input_ndim: + raise ValueError( + "len(spatial_size) must be greater or equal to img spatial dimensions, " + f"got spatial_size={output_ndim} img={input_ndim}." + ) + spatial_size_ = fall_back_tuple(self.spatial_size, img.shape[1:]) + else: # for the "longest" mode + img_size = img.shape[1:] + if not isinstance(self.spatial_size, int): + raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") + scale = self.spatial_size / max(img_size) + spatial_size_ = tuple(ceil(s * scale) for s in img_size) resized = torch.nn.functional.interpolate( # type: ignore input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0), - size=spatial_size, + size=spatial_size_, mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, align_corners=self.align_corners if align_corners is None else align_corners, ) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 2c9cac8438..0d65fdfa29 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -503,6 +503,11 @@ class Resized(MapTransform, InvertibleTransform): if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. + size_mode: should be "all" or "longest", if "all", will use `spatial_size` for all the spatial dims, + if "longest", rescale the image so that only the longest side is equal to specified `spatial_size`, + which must be an int number in this case, keeping the aspect ratio of the initial image, refer to: + https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/ + #albumentations.augmentations.geometric.resize.LongestMaxSize. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate @@ -518,6 +523,7 @@ def __init__( self, keys: KeysCollection, spatial_size: Union[Sequence[int], int], + size_mode: str = "all", mode: InterpolateModeSequence = InterpolateMode.AREA, align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, allow_missing_keys: bool = False, @@ -525,7 +531,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) - self.resizer = Resize(spatial_size=spatial_size) + self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) @@ -549,7 +555,11 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar mode = transform[InverseKeys.EXTRA_INFO]["mode"] align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] # Create inverse transform - inverse_transform = Resize(orig_size, mode, None if align_corners == "none" else align_corners) + inverse_transform = Resize( + spatial_size=orig_size, + mode=mode, + align_corners=None if align_corners == "none" else align_corners, + ) # Apply inverse transform d[key] = inverse_transform(d[key]) # Remove the applied transform diff --git a/tests/test_inverse.py b/tests/test_inverse.py index fd1afbd857..a1c171200f 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -249,15 +249,6 @@ ) ) -TESTS.append( - ( - "Flipd 3d", - "3D", - 0, - Flipd(KEYS, [1, 2]), - ) -) - TESTS.append( ( "RandFlipd 3d", @@ -319,6 +310,10 @@ TESTS.append(("Resized 3d", "3D", 5e-2, Resized(KEYS, [201, 150, 78]))) +TESTS.append(("Resized longest 2d", "2D", 2e-1, Resized(KEYS, 47, "longest", "area"))) + +TESTS.append(("Resized longest 3d", "3D", 5e-2, Resized(KEYS, 201, "longest", "trilinear", True))) + TESTS.append( ( diff --git a/tests/test_resize.py b/tests/test_resize.py index 22a68bcf85..2f54dcc04f 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -18,6 +18,12 @@ from monai.transforms import Resize from tests.utils import NumpyImageTestCase2D +TEST_CASE_0 = [{"spatial_size": 15}, (6, 11, 15)] + +TEST_CASE_1 = [{"spatial_size": 15, "mode": "area"}, (6, 11, 15)] + +TEST_CASE_2 = [{"spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)] + class TestResize(NumpyImageTestCase2D): def test_invalid_inputs(self): @@ -50,6 +56,13 @@ def test_correct_results(self, spatial_size, mode): out = resize(self.imt[0]) np.testing.assert_allclose(out, expected, atol=0.9) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) + def test_longest_shape(self, input_param, expected_shape): + input_data = np.random.randint(0, 2, size=[3, 4, 7, 10]) + input_param["size_mode"] = "longest" + result = Resize(**input_param)(input_data) + np.testing.assert_allclose(result.shape[1:], expected_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_resized.py b/tests/test_resized.py index d89c866af3..6c4f31c9c8 100644 --- a/tests/test_resized.py +++ b/tests/test_resized.py @@ -18,6 +18,17 @@ from monai.transforms import Resized from tests.utils import NumpyImageTestCase2D +TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 11, 15)] + +TEST_CASE_1 = [{"keys": "img", "spatial_size": 15, "mode": "area"}, (6, 11, 15)] + +TEST_CASE_2 = [{"keys": "img", "spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)] + +TEST_CASE_3 = [ + {"keys": ["img", "label"], "spatial_size": 6, "mode": ["trilinear", "nearest"], "align_corners": [True, None]}, + (3, 5, 6), +] + class TestResized(NumpyImageTestCase2D): def test_invalid_inputs(self): @@ -31,7 +42,7 @@ def test_invalid_inputs(self): @parameterized.expand([((32, -1), "area"), ((64, 64), "area"), ((32, 32, 32), "area"), ((256, 256), "bilinear")]) def test_correct_results(self, spatial_size, mode): - resize = Resized("img", spatial_size, mode) + resize = Resized("img", spatial_size, mode=mode) _order = 0 if mode.endswith("linear"): _order = 1 @@ -48,6 +59,18 @@ def test_correct_results(self, spatial_size, mode): out = resize({"img": self.imt[0]})["img"] np.testing.assert_allclose(out, expected, atol=0.9) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_longest_shape(self, input_param, expected_shape): + input_data = { + "img": np.random.randint(0, 2, size=[3, 4, 7, 10]), + "label": np.random.randint(0, 2, size=[3, 4, 7, 10]), + } + input_param["size_mode"] = "longest" + rescaler = Resized(**input_param) + result = rescaler(input_data) + for k in rescaler.keys: + np.testing.assert_allclose(result[k].shape[1:], expected_shape) + if __name__ == "__main__": unittest.main()