Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions monai/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""

import numpy as np
from scipy.ndimage import zoom as zoom_cpu
import torch

import monai
Expand Down Expand Up @@ -62,6 +63,65 @@ def __call__(self, img):
return rescale_array(img, self.minv, self.maxv, self.dtype)


@export
class Zoom:
""" Zooms a 3d image. Uses scipy.ndimage.zoom or cupyx.scipy.ndimage.zoom in case of gpu.
For details, please see https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.zoom.html.

Args:
zoom (float or sequence): The zoom factor along the axes. If a float, zoom is the same for each axis.
If a sequence, zoom should contain one value for each axis.
order (int): order of interpolation. Default=3.
mode (str): Determines how input is extended beyond boundaries. Default is 'constant'.
cval (scalar, optional): Value to fill past edges. Default is 0.
use_gpu (bool): Should use cpu or gpu.
keep_size (bool): Should keep original size (pad if needed).
"""
def __init__(self, zoom, order=3, mode='constant', cval=0, prefilter=True, use_gpu=False, keep_size=False):
assert isinstance(order, int), "Order must be integer."
self.zoom = zoom
self.order = order
self.mode = mode
self.cval = cval
self.prefilter = prefilter
self.use_gpu = use_gpu
self.keep_size = keep_size

def __call__(self, img):
zoomed = None
if self.use_gpu:
try:
import cupy
from cupyx.scipy.ndimage import zoom as zoom_gpu

zoomed_gpu = zoom_gpu(cupy.array(img), zoom=self.zoom, order=self.order,
mode=self.mode, cval=self.cval, prefilter=self.prefilter)
zoomed = cupy.asnumpy()
except Exception:
print('Warning: Zoom gpu failed. Defaulting to cpu.')
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Defaulting to cpu if gpu not found. Is this expected behavior?


if not zoomed or not self.use_gpu:
zoomed = zoom_cpu(img, zoom=self.zoom, order=self.order,
mode=self.mode, cval=self.cval, prefilter=self.prefilter)

# Crops to original size or pads.
if self.keep_size:
shape = img.shape
pad_vec = [[0, 0]] * len(shape)
crop_vec = list(zoomed.shape)
for d in range(len(shape)):
if zoomed.shape[d] > shape[d]:
crop_vec[d] = shape[d]
elif zoomed.shape[d] < shape[d]:
# pad_vec[d] = [0, shape[d] - zoomed.shape[d]]
pad_h = (float(shape[d]) - float(zoomed.shape[d])) / 2
pad_vec[d] = [int(np.floor(pad_h)), int(np.ceil(pad_h))]
zoomed = zoomed[0:crop_vec[0], 0:crop_vec[1], 0:crop_vec[2]]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To handle both 2d and 3d i need to way to write this sentence if crop_vec is of size 3 or 4. @wyli @Nic-Ma do you know how to do this?

zoomed = np.pad(zoomed, pad_vec, mode='constant', constant_values=self.cval)

return zoomed


@export
class ToTensor:
"""
Expand Down
65 changes: 65 additions & 0 deletions tests/test_zoom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from scipy.ndimage import zoom as zoom_scipy
from parameterized import parameterized

import torch
from monai.transforms import Zoom
from tests.utils import NumpyImageTestCase2D


class ZoomTest(NumpyImageTestCase2D):

@parameterized.expand([
(1.1, 3, 'constant', 0, True, False, False),
(0.9, 3, 'constant', 0, True, False, False),
(0.8, 1, 'reflect', 0, False, False, False)
])
def test_correct_results(self, zoom, order, mode, cval, prefilter, use_gpu, keep_size):
zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval,
prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size)
zoomed = zoom_fn(self.imt)
expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order,
cval=cval, prefilter=prefilter)
self.assertTrue(np.allclose(expected, zoomed))

@parameterized.expand([
("gpu_zoom", 0.6, 3, 'constant', 0, True)
])
def test_gpu_zoom(self, _, zoom, order, mode, cval, prefilter):
zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval,
prefilter=prefilter, use_gpu=True, keep_size=False)
zoomed = zoom_fn(self.imt)
expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order,
cval=cval, prefilter=prefilter)
self.assertTrue(np.allclose(expected, zoomed))

def test_keep_size(self):
zoom_fn = Zoom(zoom=0.6, keep_size=True)
zoomed = zoom_fn(self.imt)
self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape))

@parameterized.expand([
("no_zoom", None, 1, TypeError),
("invalid_order", 0.9, 's', AssertionError)
])
def test_invalid_inputs(self, _, zoom, order, raises):
with self.assertRaises(raises):
zoom_fn = Zoom(zoom=zoom, order=order)
zoomed = zoom_fn(self.imt)


if __name__ == '__main__':
unittest.main()