Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 227 additions & 3 deletions monai/transforms/composables.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.transforms import (LoadNifti, AsChannelFirst, Orientation,
AddChannel, Spacing, Rotate90, SpatialCrop,
RandAffine, Rand2DElastic, Rand3DElastic)
RandAffine, Rand2DElastic, Rand3DElastic,
Flip, Rotate, Zoom)
from monai.utils.misc import ensure_tuple
from monai.transforms.utils import generate_pos_neg_label_crop_centers, create_grid
from monai.utils.aliases import alias
Expand Down Expand Up @@ -476,7 +477,6 @@ def __init__(self, keys,
as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies
whether to convert it back to numpy arrays.
device (torch.device): device on which the tensor will be allocated.

See also:
- ``RandAffineGrid`` for the random affine paramters configurations.
- ``Affine`` for the affine transformation parameters configurations.
Expand Down Expand Up @@ -551,7 +551,6 @@ def __init__(self, keys,
as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies
whether to convert it back to numpy arrays.
device (torch.device): device on which the tensor will be allocated.

See also:
- ``RandAffineGrid`` for the random affine paramters configurations.
- ``Affine`` for the affine transformation parameters configurations.
Expand Down Expand Up @@ -594,3 +593,228 @@ def __call__(self, data):
for key in self.keys: # same interpolation mode
d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=self.rand_3d_elastic.mode)
return d


@export
@alias('FlipD', 'FlipDict')
class Flipd(MapTransform):
"""Dictionary-based wrapper of Flip.

Args:
keys (dict): Keys to pick data for transformation.
axis (None, int or tuple of ints): Axes along which to flip over. Default is None.
"""

def __init__(self, keys, axis=None):
MapTransform.__init__(self, keys)
self.flipper = Flip(axis=axis)

def __call__(self, data):
d = dict(data)
for key in self.keys:
d[key] = self.flipper(d[key])
return d


@export
@alias('RandFlipD', 'RandFlipDict')
class RandFlipd(Randomizable, MapTransform):
"""Dict-based wrapper of RandFlip.

Args:
prob (float): Probability of flipping.
axis (None, int or tuple of ints): Axes along which to flip over. Default is None.
"""

def __init__(self, keys, prob=0.1, axis=None):
MapTransform.__init__(self, keys)
self.axis = axis
self.prob = prob

self._do_transform = False
self.flipper = Flip(axis=axis)

def randomize(self):
self._do_transform = self.R.random_sample() < self.prob

def __call__(self, data):
self.randomize()
d = dict(data)
if not self._do_transform:
return d
for key in self.keys:
d[key] = self.flipper(d[key])
return d


@export
@alias('RotateD', 'RotateDict')
class Rotated(MapTransform):
"""Dictionary-based wrapper of Rotate.

Args:
keys (dict): Keys to pick data for transformation.
angle (float): Rotation angle in degrees.
axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two
axis in spatial dimensions according to MONAI channel first shape assumption.
reshape (bool): If true, output shape is made same as input. Default: True.
order (int): Order of spline interpolation. Range 0-5. Default: 1. This is
different from scipy where default interpolation is 3.
mode (str): Points outside boundary filled according to this mode. Options are
'constant', 'nearest', 'reflect', 'wrap'. Default: 'constant'.
cval (scalar): Values to fill outside boundary. Default: 0.
prefiter (bool): Apply spline_filter before interpolation. Default: True.
"""

def __init__(self, keys, angle, axes=(1, 2), reshape=True, order=1,
mode='constant', cval=0, prefilter=True):
MapTransform.__init__(self, keys)
self.rotator = Rotate(angle=angle, axes=axes, reshape=reshape,
order=order, mode=mode, cval=cval, prefilter=prefilter)

def __call__(self, data):
d = dict(data)
for key in self.keys:
d[key] = self.rotator(d[key])
return d


@export
@alias('RandRotateD', 'RandRotateDict')
class RandRotated(Randomizable, MapTransform):
"""Randomly rotates the input arrays.

Args:
prob (float): Probability of rotation.
degrees (tuple of float or float): Range of rotation in degrees. If single number,
angle is picked from (-degrees, degrees).
axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two
axis in spatial dimensions according to MONAI channel first shape assumption.
reshape (bool): If true, output shape is made same as input. Default: True.
order (int): Order of spline interpolation. Range 0-5. Default: 1. This is
different from scipy where default interpolation is 3.
mode (str): Points outside boundary filled according to this mode. Options are
'constant', 'nearest', 'reflect', 'wrap'. Default: 'constant'.
cval (scalar): Value to fill outside boundary. Default: 0.
prefiter (bool): Apply spline_filter before interpolation. Default: True.
"""
def __init__(self, keys, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1,
mode='constant', cval=0, prefilter=True):
MapTransform.__init__(self, keys)
self.prob = prob
self.degrees = degrees
self.reshape = reshape
self.order = order
self.mode = mode
self.cval = cval
self.prefilter = prefilter
self.axes = axes

if not hasattr(self.degrees, '__iter__'):
self.degrees = (-self.degrees, self.degrees)
assert len(self.degrees) == 2, "degrees should be a number or pair of numbers."

self._do_transform = False
self.angle = None

def randomize(self):
self._do_transform = self.R.random_sample() < self.prob
self.angle = self.R.uniform(low=self.degrees[0], high=self.degrees[1])

def __call__(self, data):
self.randomize()
d = dict(data)
if not self._do_transform:
return d
rotator = Rotate(self.angle, self.axes, self.reshape, self.order,
self.mode, self.cval, self.prefilter)
for key in self.keys:
d[key] = self.flipper(d[key])
return d


@export
@alias('ZoomD', 'ZoomDict')
class Zoomd(MapTransform):
"""Dictionary-based wrapper of Zoom transform.

Args:
zoom (float or sequence): The zoom factor along the spatial axes.
If a float, zoom is the same for each spatial axis.
If a sequence, zoom should contain one value for each spatial axis.
order (int): order of interpolation. Default=3.
mode (str): Determines how input is extended beyond boundaries. Default is 'constant'.
cval (scalar, optional): Value to fill past edges. Default is 0.
use_gpu (bool): Should use cpu or gpu. Uses cupyx which doesn't support order > 1 and modes
'wrap' and 'reflect'. Defaults to cpu for these cases or if cupyx not found.
keep_size (bool): Should keep original size (pad if needed).
"""

def __init__(self, keys, zoom, order=3, mode='constant', cval=0,
prefilter=True, use_gpu=False, keep_size=False):
MapTransform.__init__(self, keys)
self.zoomer = Zoom(zoom=zoom, order=order, mode=mode, cval=cval,
prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size)

def __call__(self, data):
d = dict(data)
for key in self.keys:
d[key] = self.zoomer(d[key])
return d


@export
@alias('RandZoomD', 'RandZoomDict')
class RandZoomd(Randomizable, MapTransform):
"""Dict-based wrapper of RandZoom.

Args:
keys (dict): Keys to pick data for transformation.
prob (float): Probability of zooming.
min_zoom (float or sequence): Min zoom factor. Can be float or sequence same size as image.
max_zoom (float or sequence): Max zoom factor. Can be float or sequence same size as image.
order (int): order of interpolation. Default=3.
mode ('reflect', 'constant', 'nearest', 'mirror', 'wrap'): Determines how input is
extended beyond boundaries. Default: 'constant'.
cval (scalar, optional): Value to fill past edges. Default is 0.
use_gpu (bool): Should use cpu or gpu. Uses cupyx which doesn't support order > 1 and modes
'wrap' and 'reflect'. Defaults to cpu for these cases or if cupyx not found.
keep_size (bool): Should keep original size (pad if needed).
"""

def __init__(self, keys, prob=0.1, min_zoom=0.9,
max_zoom=1.1, order=3, mode='constant',
cval=0, prefilter=True, use_gpu=False, keep_size=False):
MapTransform.__init__(self, keys)
if hasattr(min_zoom, '__iter__') and \
hasattr(max_zoom, '__iter__'):
assert len(min_zoom) == len(max_zoom), "min_zoom and max_zoom must have same length."
self.min_zoom = min_zoom
self.max_zoom = max_zoom
self.prob = prob
self.order = order
self.mode = mode
self.cval = cval
self.prefilter = prefilter
self.use_gpu = use_gpu
self.keep_size = keep_size

self._do_transform = False
self._zoom = None

def randomize(self):
self._do_transform = self.R.random_sample() < self.prob
if hasattr(self.min_zoom, '__iter__'):
self._zoom = (self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom))
else:
self._zoom = self.R.uniform(self.min_zoom, self.max_zoom)

def __call__(self, data):
self.randomize()
d = dict(data)
if not self._do_transform:
return d
zoomer = Zoom(self._zoom, self.order, self.mode, self.cval, self.prefilter, self.use_gpu, self.keep_size)
for key in self.keys:
d[key] = zoomer(d[key])
return d
3 changes: 2 additions & 1 deletion monai/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def __call__(self, img):
pad_vec[idx] = [half, diff - half]
elif diff < 0: # need slicing
slice_vec[idx] = slice(half, half + od)
zoomed = np.pad(zoomed, pad_vec)
zoomed = np.pad(zoomed, pad_vec, mode='constant')
return zoomed[tuple(slice_vec)]


Expand Down Expand Up @@ -696,6 +696,7 @@ def __init__(self, prob=0.1, axis=None):
self.flipper = Flip(axis=axis)

self._do_transform = False
self.flipper = Flip(axis=axis)

def randomize(self):
self._do_transform = self.R.random_sample() < self.prob
Expand Down
33 changes: 23 additions & 10 deletions tests/test_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,44 @@
import numpy as np
from parameterized import parameterized

from monai.transforms import Flip
from monai.transforms import Flip, Flipd
from tests.utils import NumpyImageTestCase2D

INVALID_CASES = [("wrong_axis", ['s', 1], TypeError),
("not_numbers", 's', TypeError)]

VALID_CASES = [("no_axis", None),
("one_axis", 1),
("many_axis", [0, 1, 2])]


class FlipTest(NumpyImageTestCase2D):

@parameterized.expand([
("wrong_axis", ['s', 1], TypeError),
("not_numbers", 's', TypeError)
])
@parameterized.expand(INVALID_CASES)
def test_invalid_inputs(self, _, axis, raises):
with self.assertRaises(raises):
flip = Flip(axis)
flip(self.imt)

@parameterized.expand([
("no_axis", None),
("one_axis", 1),
("many_axis", [0, 1, 2])
])
@parameterized.expand(INVALID_CASES)
def test_invalid_cases_dict(self, _, axis, raises):
with self.assertRaises(raises):
flip = Flipd(keys='img', axis=axis)
flip({'img': self.imt})

@parameterized.expand(VALID_CASES)
def test_correct_results(self, _, axis):
flip = Flip(axis=axis)
expected = np.flip(self.imt, axis)
self.assertTrue(np.allclose(expected, flip(self.imt)))

@parameterized.expand(VALID_CASES)
def test_correct_results_dict(self, _, axis):
flip = Flipd(keys='img', axis=axis)
expected = np.flip(self.imt, axis)
res = flip({'img': self.imt})
assert np.allclose(expected, res['img'])


if __name__ == '__main__':
unittest.main()
29 changes: 18 additions & 11 deletions tests/test_random_flip.py → tests/test_rand_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,38 @@
import numpy as np
from parameterized import parameterized

from monai.transforms import RandFlip
from monai.transforms import RandFlip, RandFlipd
from tests.utils import NumpyImageTestCase2D

INVALID_CASES = [("wrong_axis", ['s', 1], TypeError),
("not_numbers", 's', TypeError)]

class RandomFlipTest(NumpyImageTestCase2D):
VALID_CASES = [("no_axis", None),
("one_axis", 1),
("many_axis", [0, 1, 2])]

@parameterized.expand([
("wrong_axis", ['s', 1], TypeError),
("not_numbers", 's', TypeError)
])
class RandFlipTest(NumpyImageTestCase2D):

@parameterized.expand(INVALID_CASES)
def test_invalid_inputs(self, _, axis, raises):
with self.assertRaises(raises):
flip = RandFlip(prob=1.0, axis=axis)
flip(self.imt)

@parameterized.expand([
("no_axis", None),
("one_axis", 1),
("many_axis", [0, 1, 2])
])
@parameterized.expand(VALID_CASES)
def test_correct_results(self, _, axis):
flip = RandFlip(prob=1.0, axis=axis)
expected = np.flip(self.imt, axis)
self.assertTrue(np.allclose(expected, flip(self.imt)))

@parameterized.expand(VALID_CASES)
def test_correct_results_dict(self, _, axis):
flip = RandFlipd(keys='img', prob=1.0, axis=axis)
res = flip({'img': self.imt})

expected = np.flip(self.imt, axis)
self.assertTrue(np.allclose(expected, res['img']))


if __name__ == '__main__':
unittest.main()
File renamed without changes.
Loading