Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions monai/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ class Zoom:
order (int): order of interpolation. Default=3.
mode (str): Determines how input is extended beyond boundaries. Default is 'constant'.
cval (scalar, optional): Value to fill past edges. Default is 0.
use_gpu (bool): Should use cpu or gpu.
use_gpu (bool): Should use cpu or gpu. Uses cupyx which doesn't support order > 1 and modes
'wrap' and 'reflect'. Defaults to cpu for these cases or if cupyx not found.
keep_size (bool): Should keep original size (pad if needed).
"""
def __init__(self, zoom, order=3, mode='constant', cval=0, prefilter=True, use_gpu=False, keep_size=False):
Expand All @@ -210,13 +211,13 @@ def __call__(self, img):

zoomed_gpu = zoom_gpu(cupy.array(img), zoom=self.zoom, order=self.order,
mode=self.mode, cval=self.cval, prefilter=self.prefilter)
zoomed = cupy.asnumpy()
zoomed = cupy.asnumpy(zoomed_gpu)
except ModuleNotFoundError:
print('For GPU zoom, please install cupy. Defaulting to cpu.')
except Exception:
print('Warning: Zoom gpu failed. Defaulting to cpu.')
except NotImplementedError:
print("Defaulting to CPU. cupyx doesn't support order > 1 and modes 'wrap' or 'reflect'.")

if not zoomed or not self.use_gpu:
if zoomed is None:
zoomed = scipy.ndimage.zoom(img, zoom=self.zoom, order=self.order,
mode=self.mode, cval=self.cval, prefilter=self.prefilter)

Expand Down
17 changes: 10 additions & 7 deletions tests/test_zoom.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import unittest

import numpy as np
import importlib

from scipy.ndimage import zoom as zoom_scipy
from parameterized import parameterized

Expand All @@ -35,15 +37,16 @@ def test_correct_results(self, zoom, order, mode, cval, prefilter, use_gpu, keep
self.assertTrue(np.allclose(expected, zoomed))

@parameterized.expand([
("gpu_zoom", 0.6, 3, 'constant', 0, True)
("gpu_zoom", 0.6, 1, 'constant', 0, True)
])
def test_gpu_zoom(self, _, zoom, order, mode, cval, prefilter):
zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval,
prefilter=prefilter, use_gpu=True, keep_size=False)
zoomed = zoom_fn(self.imt)
expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order,
cval=cval, prefilter=prefilter)
self.assertTrue(np.allclose(expected, zoomed))
if importlib.util.find_spec('cupy'):
zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval,
prefilter=prefilter, use_gpu=True, keep_size=False)
zoomed = zoom_fn(self.imt)
expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order,
cval=cval, prefilter=prefilter)
self.assertTrue(np.allclose(expected, zoomed))

def test_keep_size(self):
zoom_fn = Zoom(zoom=0.6, keep_size=True)
Expand Down