diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index c3bd4a3433..5979b996ee 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -22,7 +22,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull -from monai.transforms.croppad.array import CenterSpatialCrop +from monai.transforms.croppad.array import CenterSpatialCrop, Pad from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform from monai.transforms.utils import ( create_control_grid, @@ -47,6 +47,7 @@ ) from monai.utils.enums import TransformBackends from monai.utils.module import look_up_option +from monai.utils.type_conversion import convert_data_type nib, _ = optional_import("nibabel") @@ -555,6 +556,8 @@ class Zoom(Transform): """ + backend = [TransformBackends.TORCH] + def __init__( self, zoom: Union[Sequence[float], float], @@ -573,11 +576,11 @@ def __init__( def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[InterpolateMode, str]] = None, padding_mode: Optional[Union[NumpyPadMode, str]] = None, align_corners: Optional[bool] = None, - ): + ) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]). @@ -593,31 +596,37 @@ def __call__( See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate """ + img_t: torch.Tensor + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32) # type: ignore + _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim - zoomed = torch.nn.functional.interpolate( # type: ignore + zoomed: torch.Tensor = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, - input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0), + input=img_t.unsqueeze(0), scale_factor=list(_zoom), mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, align_corners=self.align_corners if align_corners is None else align_corners, ) - zoomed = zoomed.squeeze(0).detach().cpu().numpy() - if not self.keep_size or np.allclose(img.shape, zoomed.shape): - return zoomed + zoomed = zoomed.squeeze(0) + + if self.keep_size and not np.allclose(img_t.shape, zoomed.shape): - pad_vec = [[0, 0]] * len(img.shape) - slice_vec = [slice(None)] * len(img.shape) - for idx, (od, zd) in enumerate(zip(img.shape, zoomed.shape)): - diff = od - zd - half = abs(diff) // 2 - if diff > 0: # need padding - pad_vec[idx] = [half, diff - half] - elif diff < 0: # need slicing - slice_vec[idx] = slice(half, half + od) + pad_vec = [(0, 0)] * len(img_t.shape) + slice_vec = [slice(None)] * len(img_t.shape) + for idx, (od, zd) in enumerate(zip(img_t.shape, zoomed.shape)): + diff = od - zd + half = abs(diff) // 2 + if diff > 0: # need padding + pad_vec[idx] = (half, diff - half) + elif diff < 0: # need slicing + slice_vec[idx] = slice(half, half + od) - padding_mode = look_up_option(self.padding_mode if padding_mode is None else padding_mode, NumpyPadMode) - zoomed = np.pad(zoomed, pad_vec, mode=padding_mode.value, **self.np_kwargs) # type: ignore - return zoomed[tuple(slice_vec)] + padding_mode = look_up_option(padding_mode or self.padding_mode, NumpyPadMode) + padder = Pad(pad_vec, padding_mode) + zoomed = padder(zoomed) + zoomed = zoomed[tuple(slice_vec)] + + return zoomed class Rotate90(Transform): @@ -886,6 +895,8 @@ class RandZoom(RandomizableTransform): """ + backend = Zoom.backend + def __init__( self, prob: float = 0.1, @@ -916,11 +927,11 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[InterpolateMode, str]] = None, padding_mode: Optional[Union[NumpyPadMode, str]] = None, align_corners: Optional[bool] = None, - ) -> np.ndarray: + ) -> torch.Tensor: """ Args: img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D). @@ -937,25 +948,25 @@ def __call__( """ # match the spatial image dim self.randomize() - _dtype = np.float32 if not self._do_transform: - return img.astype(_dtype) + img_t: torch.Tensor + img_t, *_ = convert_data_type(img, dtype=torch.float32) # type: ignore + return img_t if len(self._zoom) == 1: # to keep the spatial shape ratio, use same random zoom factor for all dims self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1) elif len(self._zoom) == 2 and img.ndim > 3: # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1]) - zoomer = Zoom(self._zoom, keep_size=self.keep_size, **self.np_kwargs) - return np.asarray( - zoomer( - img, - mode=look_up_option(mode or self.mode, InterpolateMode), - padding_mode=look_up_option(padding_mode or self.padding_mode, NumpyPadMode), - align_corners=self.align_corners if align_corners is None else align_corners, - ), - dtype=_dtype, + zoomer = Zoom( + self._zoom, + keep_size=self.keep_size, + mode=look_up_option(mode or self.mode, InterpolateMode), + padding_mode=look_up_option(padding_mode or self.padding_mode, NumpyPadMode), + align_corners=align_corners or self.align_corners, + **self.np_kwargs, ) + return zoomer(img) class AffineGrid(Transform): diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index b0558a6556..22dafb3f80 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1537,6 +1537,8 @@ class Zoomd(MapTransform, InvertibleTransform): """ + backend = Zoom.backend + def __init__( self, keys: KeysCollection, @@ -1554,7 +1556,7 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **np_kwargs) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners @@ -1576,7 +1578,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -1594,7 +1596,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar align_corners=None if align_corners == "none" else align_corners, ) # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE], mode="edge")(d[key]) + d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE], mode="edge")(d[key]) # type: ignore # Remove the applied transform self.pop_transform(d, key) @@ -1637,6 +1639,8 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform): """ + backend = Zoom.backend + def __init__( self, keys: KeysCollection, @@ -1669,7 +1673,7 @@ def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)] - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: # match the spatial dim of first item self.randomize() d = dict(data) @@ -1704,7 +1708,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -1724,7 +1728,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar align_corners=None if align_corners == "none" else align_corners, ) # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE], mode="edge")(d[key]) + d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE], mode="edge")(d[key]) # type: ignore # Remove the applied transform self.pop_transform(d, key) diff --git a/tests/test_rand_zoom.py b/tests/test_rand_zoom.py index c21bc8b9e9..0ac1b92c39 100644 --- a/tests/test_rand_zoom.py +++ b/tests/test_rand_zoom.py @@ -17,7 +17,7 @@ from monai.transforms import RandZoom from monai.utils import GridSampleMode, InterpolateMode -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose VALID_CASES = [(0.8, 1.2, "nearest", False), (0.8, 1.2, InterpolateMode.NEAREST, False)] @@ -25,36 +25,32 @@ class TestRandZoom(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, min_zoom, max_zoom, mode, keep_size): - random_zoom = RandZoom( - prob=1.0, - min_zoom=min_zoom, - max_zoom=max_zoom, - mode=mode, - keep_size=keep_size, - ) - random_zoom.set_random_state(1234) - zoomed = random_zoom(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False)) - expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(zoomed, expected, atol=1.0) + for p in TEST_NDARRAYS: + random_zoom = RandZoom( + prob=1.0, + min_zoom=min_zoom, + max_zoom=max_zoom, + mode=mode, + keep_size=keep_size, + ) + random_zoom.set_random_state(1234) + zoomed = random_zoom(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False)) + expected = np.stack(expected).astype(np.float32) + assert_allclose(zoomed, expected, atol=1.0) def test_keep_size(self): - random_zoom = RandZoom( - prob=1.0, - min_zoom=0.6, - max_zoom=0.7, - keep_size=True, - padding_mode="constant", - constant_values=2, - ) - zoomed = random_zoom(self.imt[0]) - self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) - zoomed = random_zoom(self.imt[0]) - self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) - zoomed = random_zoom(self.imt[0]) - self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + for p in TEST_NDARRAYS: + im = p(self.imt[0]) + random_zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True) + zoomed = random_zoom(im) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + zoomed = random_zoom(im) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + zoomed = random_zoom(im) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) @parameterized.expand( [ @@ -64,23 +60,25 @@ def test_keep_size(self): ] ) def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises): - with self.assertRaises(raises): - random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) - random_zoom(self.imt[0]) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) + random_zoom(p(self.imt[0])) def test_auto_expand_3d(self): - random_zoom = RandZoom( - prob=1.0, - min_zoom=[0.8, 0.7], - max_zoom=[1.2, 1.3], - mode="nearest", - keep_size=False, - ) - random_zoom.set_random_state(1234) - test_data = np.random.randint(0, 2, size=[2, 2, 3, 4]) - zoomed = random_zoom(test_data) - np.testing.assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) - np.testing.assert_allclose(zoomed.shape, (2, 2, 3, 3)) + for p in TEST_NDARRAYS: + random_zoom = RandZoom( + prob=1.0, + min_zoom=[0.8, 0.7], + max_zoom=[1.2, 1.3], + mode="nearest", + keep_size=False, + ) + random_zoom.set_random_state(1234) + test_data = p(np.random.randint(0, 2, size=[2, 2, 3, 4])) + zoomed = random_zoom(test_data) + assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) + assert_allclose(zoomed.shape, (2, 2, 3, 3)) if __name__ == "__main__": diff --git a/tests/test_rand_zoomd.py b/tests/test_rand_zoomd.py index 4ccb1aad64..fafaf748bd 100644 --- a/tests/test_rand_zoomd.py +++ b/tests/test_rand_zoomd.py @@ -16,7 +16,7 @@ from scipy.ndimage import zoom as zoom_scipy from monai.transforms import RandZoomd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose VALID_CASES = [(0.8, 1.2, "nearest", None, False)] @@ -34,14 +34,15 @@ def test_correct_results(self, min_zoom, max_zoom, mode, align_corners, keep_siz align_corners=align_corners, keep_size=keep_size, ) - random_zoom.set_random_state(1234) + for p in TEST_NDARRAYS: + random_zoom.set_random_state(1234) - zoomed = random_zoom({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False)) - expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(expected, zoomed[key], atol=1.0) + zoomed = random_zoom({key: p(self.imt[0])}) + expected = [] + for channel in self.imt[0]: + expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False)) + expected = np.stack(expected).astype(np.float32) + assert_allclose(expected, zoomed[key], atol=1.0) def test_keep_size(self): key = "img" @@ -54,17 +55,19 @@ def test_keep_size(self): padding_mode="constant", constant_values=2, ) - zoomed = random_zoom({key: self.imt[0]}) - self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:])) + for p in TEST_NDARRAYS: + zoomed = random_zoom({key: p(self.imt[0])}) + np.testing.assert_array_equal(zoomed[key].shape, self.imt.shape[1:]) @parameterized.expand( [("no_min_zoom", None, 1.1, "bilinear", TypeError), ("invalid_order", 0.9, 1.1, "s", ValueError)] ) def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises): key = "img" - with self.assertRaises(raises): - random_zoom = RandZoomd(key, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) - random_zoom({key: self.imt[0]}) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + random_zoom = RandZoomd(key, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) + random_zoom({key: p(self.imt[0])}) def test_auto_expand_3d(self): random_zoom = RandZoomd( @@ -75,11 +78,12 @@ def test_auto_expand_3d(self): mode="nearest", keep_size=False, ) - random_zoom.set_random_state(1234) - test_data = {"img": np.random.randint(0, 2, size=[2, 2, 3, 4])} - zoomed = random_zoom(test_data) - np.testing.assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) - np.testing.assert_allclose(zoomed["img"].shape, (2, 2, 3, 3)) + for p in TEST_NDARRAYS: + random_zoom.set_random_state(1234) + test_data = {"img": p(np.random.randint(0, 2, size=[2, 2, 3, 4]))} + zoomed = random_zoom(test_data) + assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) + assert_allclose(zoomed["img"].shape, (2, 2, 3, 3)) if __name__ == "__main__": diff --git a/tests/test_zoom.py b/tests/test_zoom.py index e6710ede29..a99e110052 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -12,11 +12,12 @@ import unittest import numpy as np +import torch from parameterized import parameterized from scipy.ndimage import zoom as zoom_scipy from monai.transforms import Zoom -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose VALID_CASES = [(1.5, "nearest"), (1.5, "nearest"), (0.8, "bilinear"), (0.8, "area")] @@ -26,38 +27,42 @@ class TestZoom(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, zoom, mode): - zoom_fn = Zoom(zoom=zoom, mode=mode, keep_size=False) - zoomed = zoom_fn(self.imt[0]) - _order = 0 - if mode.endswith("linear"): - _order = 1 - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) - expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(zoomed, expected, atol=1.0) + for p in TEST_NDARRAYS: + zoom_fn = Zoom(zoom=zoom, mode=mode, keep_size=False) + zoomed = zoom_fn(p(self.imt[0])) + _order = 0 + if mode.endswith("linear"): + _order = 1 + expected = [] + for channel in self.imt[0]: + expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) + expected = np.stack(expected).astype(np.float32) + assert_allclose(zoomed, expected, atol=1.0) def test_keep_size(self): - zoom_fn = Zoom(zoom=[0.6, 0.6], keep_size=True, align_corners=True, padding_mode="constant", constant_values=2) - zoomed = zoom_fn(self.imt[0], mode="bilinear") - np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:]) + for p in TEST_NDARRAYS: + zoom_fn = Zoom(zoom=[0.6, 0.6], keep_size=True, align_corners=True) + zoomed = zoom_fn(p(self.imt[0]), mode="bilinear") + assert_allclose(zoomed.shape, self.imt.shape[1:]) - zoom_fn = Zoom(zoom=[1.3, 1.3], keep_size=True) - zoomed = zoom_fn(self.imt[0]) - np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:]) + zoom_fn = Zoom(zoom=[1.3, 1.3], keep_size=True) + zoomed = zoom_fn(p(self.imt[0])) + assert_allclose(zoomed.shape, self.imt.shape[1:]) @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, zoom, mode, raises): - with self.assertRaises(raises): - zoom_fn = Zoom(zoom=zoom, mode=mode) - zoom_fn(self.imt[0]) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + zoom_fn = Zoom(zoom=zoom, mode=mode) + zoom_fn(p(self.imt[0])) def test_padding_mode(self): - zoom_fn = Zoom(zoom=0.5, mode="nearest", padding_mode="constant", keep_size=True) - test_data = np.array([[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]]) - zoomed = zoom_fn(test_data) - expected = np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]) - np.testing.assert_allclose(zoomed, expected) + for p in TEST_NDARRAYS: + zoom_fn = Zoom(zoom=0.5, mode="nearest", padding_mode="constant", keep_size=True) + test_data = p([[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]]) + zoomed = zoom_fn(test_data) + expected = p([[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]) + torch.testing.assert_allclose(zoomed, expected) if __name__ == "__main__": diff --git a/tests/test_zoomd.py b/tests/test_zoomd.py index 1a1a905d80..1ebd7d2d08 100644 --- a/tests/test_zoomd.py +++ b/tests/test_zoomd.py @@ -16,7 +16,7 @@ from scipy.ndimage import zoom as zoom_scipy from monai.transforms import Zoomd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose VALID_CASES = [(1.5, "nearest", False), (0.3, "bilinear", False), (0.8, "bilinear", False)] @@ -33,32 +33,35 @@ def test_correct_results(self, zoom, mode, keep_size): mode=mode, keep_size=keep_size, ) - zoomed = zoom_fn({key: self.imt[0]}) - _order = 0 - if mode.endswith("linear"): - _order = 1 - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) - expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(expected, zoomed[key], atol=1.0) + for p in TEST_NDARRAYS: + zoomed = zoom_fn({key: p(self.imt[0])}) + _order = 0 + if mode.endswith("linear"): + _order = 1 + expected = [] + for channel in self.imt[0]: + expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) + expected = np.stack(expected).astype(np.float32) + assert_allclose(expected, zoomed[key], atol=1.0) def test_keep_size(self): key = "img" zoom_fn = Zoomd(key, zoom=0.6, keep_size=True, padding_mode="constant", constant_values=2) - zoomed = zoom_fn({key: self.imt[0]}) - self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:])) + for p in TEST_NDARRAYS: + zoomed = zoom_fn({key: p(self.imt[0])}) + np.testing.assert_array_equal(zoomed[key].shape, self.imt.shape[1:]) - zoom_fn = Zoomd(key, zoom=1.3, keep_size=True) - zoomed = zoom_fn({key: self.imt[0]}) - self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:])) + zoom_fn = Zoomd(key, zoom=1.3, keep_size=True) + zoomed = zoom_fn({key: self.imt[0]}) + self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:])) @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, _, zoom, mode, raises): key = "img" - with self.assertRaises(raises): - zoom_fn = Zoomd(key, zoom=zoom, mode=mode) - zoom_fn({key: self.imt[0]}) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + zoom_fn = Zoomd(key, zoom=zoom, mode=mode) + zoom_fn({key: p(self.imt[0])}) if __name__ == "__main__":