77 changes: 44 additions & 33 deletions monai/transforms/spatial/array.py
@@ -22,7 +22,7 @@
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
from monai.transforms.croppad.array import CenterSpatialCrop
from monai.transforms.croppad.array import CenterSpatialCrop, Pad
from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform
from monai.transforms.utils import (
create_control_grid,
@@ -47,6 +47,7 @@
)
from monai.utils.enums import TransformBackends
from monai.utils.module import look_up_option
from monai.utils.type_conversion import convert_data_type

nib, _ = optional_import("nibabel")

@@ -555,6 +556,8 @@ class Zoom(Transform):

"""

backend = [TransformBackends.TORCH]

def __init__(
self,
zoom: Union[Sequence[float], float],
@@ -573,11 +576,11 @@ def __init__(

def __call__(
self,
img: np.ndarray,
img: NdarrayOrTensor,
mode: Optional[Union[InterpolateMode, str]] = None,
padding_mode: Optional[Union[NumpyPadMode, str]] = None,
align_corners: Optional[bool] = None,
):
) -> torch.Tensor:
"""
Args:
img: channel first array, must have shape: (num_channels, H[, W, ..., ]).
@@ -593,31 +596,37 @@ def __call__(
See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate

"""
img_t: torch.Tensor
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32) # type: ignore

_zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim
zoomed = torch.nn.functional.interpolate( # type: ignore
zoomed: torch.Tensor = torch.nn.functional.interpolate( # type: ignore
recompute_scale_factor=True,
input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0),
input=img_t.unsqueeze(0),
scale_factor=list(_zoom),
mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value,
align_corners=self.align_corners if align_corners is None else align_corners,
)
zoomed = zoomed.squeeze(0).detach().cpu().numpy()
if not self.keep_size or np.allclose(img.shape, zoomed.shape):
return zoomed
zoomed = zoomed.squeeze(0)

if self.keep_size and not np.allclose(img_t.shape, zoomed.shape):

pad_vec = [[0, 0]] * len(img.shape)
slice_vec = [slice(None)] * len(img.shape)
for idx, (od, zd) in enumerate(zip(img.shape, zoomed.shape)):
diff = od - zd
half = abs(diff) // 2
if diff > 0: # need padding
pad_vec[idx] = [half, diff - half]
elif diff < 0: # need slicing
slice_vec[idx] = slice(half, half + od)
pad_vec = [(0, 0)] * len(img_t.shape)
slice_vec = [slice(None)] * len(img_t.shape)
for idx, (od, zd) in enumerate(zip(img_t.shape, zoomed.shape)):
diff = od - zd
half = abs(diff) // 2
if diff > 0: # need padding
pad_vec[idx] = (half, diff - half)
elif diff < 0: # need slicing
slice_vec[idx] = slice(half, half + od)

padding_mode = look_up_option(self.padding_mode if padding_mode is None else padding_mode, NumpyPadMode)
zoomed = np.pad(zoomed, pad_vec, mode=padding_mode.value, **self.np_kwargs) # type: ignore
return zoomed[tuple(slice_vec)]
padding_mode = look_up_option(padding_mode or self.padding_mode, NumpyPadMode)
padder = Pad(pad_vec, padding_mode)
zoomed = padder(zoomed)
zoomed = zoomed[tuple(slice_vec)]

return zoomed


class Rotate90(Transform):
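
In the updated `__call__` above, the input is first converted to a float32 tensor with `convert_data_type`, and the `keep_size` branch restores the original shape by padding (via the `Pad` transform) where the zoomed result is smaller and center-slicing where it is larger. Below is a minimal standalone sketch of that shape-matching step, written against plain `torch.nn.functional.pad` rather than MONAI's `Pad`; the toy image and zoom factor are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F


def match_shape(zoomed: torch.Tensor, original_shape) -> torch.Tensor:
    """Center-pad or center-crop ``zoomed`` back to ``original_shape``."""
    pad_flat = []  # F.pad expects pairs ordered from the last dimension backwards
    slices = []
    for od, zd in zip(original_shape, zoomed.shape):
        diff = od - zd
        half = abs(diff) // 2
        if diff > 0:  # zoomed is smaller -> pad this dim
            pad_flat = [half, diff - half] + pad_flat
            slices.append(slice(None))
        elif diff < 0:  # zoomed is larger -> center-crop this dim
            pad_flat = [0, 0] + pad_flat
            slices.append(slice(half, half + od))
        else:
            pad_flat = [0, 0] + pad_flat
            slices.append(slice(None))
    zoomed = F.pad(zoomed, pad_flat)
    return zoomed[tuple(slices)]


img = torch.zeros(1, 10, 10)  # channel-first toy image (assumed shape)
shrunk = F.interpolate(
    img.unsqueeze(0), scale_factor=0.7, mode="nearest", recompute_scale_factor=True
).squeeze(0)
print(match_shape(shrunk, img.shape).shape)  # torch.Size([1, 10, 10])
```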
@@ -886,6 +895,8 @@ class RandZoom(RandomizableTransform):

"""

backend = Zoom.backend

def __init__(
self,
prob: float = 0.1,
@@ -916,11 +927,11 @@ def randomize(self, data: Optional[Any] = None) -> None:

def __call__(
self,
img: np.ndarray,
img: NdarrayOrTensor,
mode: Optional[Union[InterpolateMode, str]] = None,
padding_mode: Optional[Union[NumpyPadMode, str]] = None,
align_corners: Optional[bool] = None,
) -> np.ndarray:
) -> torch.Tensor:
"""
Args:
img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D).
@@ -937,25 +948,25 @@ def __call__(
"""
# match the spatial image dim
self.randomize()
_dtype = np.float32
if not self._do_transform:
return img.astype(_dtype)
img_t: torch.Tensor
img_t, *_ = convert_data_type(img, dtype=torch.float32) # type: ignore
return img_t
if len(self._zoom) == 1:
# to keep the spatial shape ratio, use same random zoom factor for all dims
self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1)
elif len(self._zoom) == 2 and img.ndim > 3:
# if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim
self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1])
zoomer = Zoom(self._zoom, keep_size=self.keep_size, **self.np_kwargs)
return np.asarray(
zoomer(
img,
mode=look_up_option(mode or self.mode, InterpolateMode),
padding_mode=look_up_option(padding_mode or self.padding_mode, NumpyPadMode),
align_corners=self.align_corners if align_corners is None else align_corners,
),
dtype=_dtype,
zoomer = Zoom(
self._zoom,
keep_size=self.keep_size,
mode=look_up_option(mode or self.mode, InterpolateMode),
padding_mode=look_up_option(padding_mode or self.padding_mode, NumpyPadMode),
align_corners=align_corners or self.align_corners,
**self.np_kwargs,
)
return zoomer(img)


class AffineGrid(Transform):
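
With the changes above, `Zoom` and `RandZoom` accept either NumPy arrays or torch tensors and return a `torch.Tensor`; a single zoom factor is broadcast to all spatial dims, and two factors on 3D data apply to (H, W) and D respectively. A quick usage sketch under those assumptions (the image shape and zoom values are arbitrary):

```python
import numpy as np
import torch

from monai.transforms import RandZoom, Zoom

img_np = np.random.rand(1, 32, 32).astype(np.float32)  # channel-first 2D image (assumed)
img_t = torch.as_tensor(img_np)

zoom = Zoom(zoom=0.8, mode="nearest", keep_size=True)
out = zoom(img_np)            # NumPy input is converted internally via convert_data_type
print(type(out), out.shape)   # <class 'torch.Tensor'> torch.Size([1, 32, 32])

rand_zoom = RandZoom(prob=1.0, min_zoom=0.8, max_zoom=1.2, mode="nearest", keep_size=True)
rand_zoom.set_random_state(0)
out = rand_zoom(img_t)        # torch input stays a tensor
print(out.shape)              # torch.Size([1, 32, 32]) because keep_size=True
```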
16 changes: 10 additions & 6 deletions monai/transforms/spatial/dictionary.py
@@ -1537,6 +1537,8 @@ class Zoomd(MapTransform, InvertibleTransform):

"""

backend = Zoom.backend

def __init__(
self,
keys: KeysCollection,
@@ -1554,7 +1556,7 @@ def __init__(
self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **np_kwargs)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key, mode, padding_mode, align_corners in self.key_iterator(
d, self.mode, self.padding_mode, self.align_corners
@@ -1576,7 +1578,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
)
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
@@ -1594,7 +1596,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
align_corners=None if align_corners == "none" else align_corners,
)
# Size might be out by 1 voxel so pad
d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE], mode="edge")(d[key])
d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE], mode="edge")(d[key]) # type: ignore
# Remove the applied transform
self.pop_transform(d, key)

@@ -1637,6 +1639,8 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform):

"""

backend = Zoom.backend

def __init__(
self,
keys: KeysCollection,
@@ -1669,7 +1673,7 @@ def randomize(self, data: Optional[Any] = None) -> None:
super().randomize(None)
self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)]

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
# match the spatial dim of first item
self.randomize()
d = dict(data)
@@ -1704,7 +1708,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
)
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
@@ -1724,7 +1728,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
align_corners=None if align_corners == "none" else align_corners,
)
# Size might be out by 1 voxel so pad
d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE], mode="edge")(d[key])
d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE], mode="edge")(d[key]) # type: ignore
# Remove the applied transform
self.pop_transform(d, key)

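
The dictionary wrappers follow the same pattern: `Zoomd` and `RandZoomd` reuse `Zoom.backend`, operate on tensor values stored under their keys, and `inverse()` still edge-pads with `SpatialPad` in case the restored size is off by one voxel. A hedged round-trip sketch (the key name, shape, and zoom factor are made up for illustration):

```python
import torch

from monai.transforms import Zoomd

data = {"image": torch.rand(1, 16, 16, 16)}  # channel-first 3D volume (assumed)
zoomd = Zoomd(keys="image", zoom=0.7, mode="nearest", keep_size=False)

zoomed = zoomd(data)
print(zoomed["image"].shape)      # roughly (1, 11, 11, 11) after zooming out

restored = zoomd.inverse(zoomed)  # zooms back by 1/0.7, then SpatialPad fixes any 1-voxel gap
print(restored["image"].shape)    # spatial size restored to (1, 16, 16, 16)
```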
86 changes: 42 additions & 44 deletions tests/test_rand_zoom.py
@@ -17,44 +17,40 @@

from monai.transforms import RandZoom
from monai.utils import GridSampleMode, InterpolateMode
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose

VALID_CASES = [(0.8, 1.2, "nearest", False), (0.8, 1.2, InterpolateMode.NEAREST, False)]


class TestRandZoom(NumpyImageTestCase2D):
@parameterized.expand(VALID_CASES)
def test_correct_results(self, min_zoom, max_zoom, mode, keep_size):
random_zoom = RandZoom(
prob=1.0,
min_zoom=min_zoom,
max_zoom=max_zoom,
mode=mode,
keep_size=keep_size,
)
random_zoom.set_random_state(1234)
zoomed = random_zoom(self.imt[0])
expected = []
for channel in self.imt[0]:
expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False))
expected = np.stack(expected).astype(np.float32)
np.testing.assert_allclose(zoomed, expected, atol=1.0)
for p in TEST_NDARRAYS:
random_zoom = RandZoom(
prob=1.0,
min_zoom=min_zoom,
max_zoom=max_zoom,
mode=mode,
keep_size=keep_size,
)
random_zoom.set_random_state(1234)
zoomed = random_zoom(p(self.imt[0]))
expected = []
for channel in self.imt[0]:
expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False))
expected = np.stack(expected).astype(np.float32)
assert_allclose(zoomed, expected, atol=1.0)

def test_keep_size(self):
random_zoom = RandZoom(
prob=1.0,
min_zoom=0.6,
max_zoom=0.7,
keep_size=True,
padding_mode="constant",
constant_values=2,
)
zoomed = random_zoom(self.imt[0])
self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
zoomed = random_zoom(self.imt[0])
self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
zoomed = random_zoom(self.imt[0])
self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
for p in TEST_NDARRAYS:
im = p(self.imt[0])
random_zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True)
zoomed = random_zoom(im)
self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
zoomed = random_zoom(im)
self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
zoomed = random_zoom(im)
self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))

@parameterized.expand(
[
@@ -64,23 +60,25 @@ def test_keep_size(self):
]
)
def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises):
with self.assertRaises(raises):
random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode)
random_zoom(self.imt[0])
for p in TEST_NDARRAYS:
with self.assertRaises(raises):
random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode)
random_zoom(p(self.imt[0]))

def test_auto_expand_3d(self):
random_zoom = RandZoom(
prob=1.0,
min_zoom=[0.8, 0.7],
max_zoom=[1.2, 1.3],
mode="nearest",
keep_size=False,
)
random_zoom.set_random_state(1234)
test_data = np.random.randint(0, 2, size=[2, 2, 3, 4])
zoomed = random_zoom(test_data)
np.testing.assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2)
np.testing.assert_allclose(zoomed.shape, (2, 2, 3, 3))
for p in TEST_NDARRAYS:
random_zoom = RandZoom(
prob=1.0,
min_zoom=[0.8, 0.7],
max_zoom=[1.2, 1.3],
mode="nearest",
keep_size=False,
)
random_zoom.set_random_state(1234)
test_data = p(np.random.randint(0, 2, size=[2, 2, 3, 4]))
zoomed = random_zoom(test_data)
assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2)
assert_allclose(zoomed.shape, (2, 2, 3, 3))


if __name__ == "__main__":
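
The tests now loop over `TEST_NDARRAYS`, so each case runs once per supported container type (NumPy array and torch tensor), and `assert_allclose` from `tests.utils` compares the result regardless of which type the transform returns. A rough sketch of that parametrization pattern outside the MONAI test helpers (the converter tuple and comparison helper below are stand-ins, not the real `TEST_NDARRAYS` or `assert_allclose`):

```python
import numpy as np
import torch

# stand-in for tests.utils.TEST_NDARRAYS: one converter per supported container type
ARRAY_TYPES = (np.asarray, torch.as_tensor)


def assert_allclose_any(actual, expected, atol=1e-6):
    """Compare results whether they arrive as np.ndarray or torch.Tensor."""
    if isinstance(actual, torch.Tensor):
        actual = actual.detach().cpu().numpy()
    np.testing.assert_allclose(actual, np.asarray(expected), atol=atol)


data = np.random.rand(1, 8, 8).astype(np.float32)
for to_type in ARRAY_TYPES:
    img = to_type(data)      # same test body, different input container
    result = img * 2         # placeholder for the transform under test
    assert_allclose_any(result, data * 2)
```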