From af7ed0c7aec3227562eeedaa0f4ef94e5f26119d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 17 Sep 2021 15:05:54 +0100 Subject: [PATCH 01/20] enhance affinegrid to use torch backend Signed-off-by: Wenqi Li --- monai/transforms/utils.py | 109 ++++++++++++++++++++++----- tests/test_create_grid_and_affine.py | 6 +- 2 files changed, 94 insertions(+), 21 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 05e45bf26f..0388b3ada2 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -576,7 +576,9 @@ def create_control_grid( return create_grid(grid_shape, spacing, homogeneous, dtype) -def create_rotate(spatial_dims: int, radians: Union[Sequence[float], float]) -> np.ndarray: +def create_rotate( + spatial_dims: int, radians: Union[Sequence[float], float], backend=TransformBackends.NUMPY +) -> NdarrayOrTensor: """ create a 2D or 3D rotation matrix @@ -585,48 +587,73 @@ def create_rotate(spatial_dims: int, radians: Union[Sequence[float], float]) -> radians: rotation radians when spatial_dims == 3, the `radians` sequence corresponds to rotation in the 1st, 2nd, and 3rd dim respectively. + backend: APIs to use, ``numpy`` or ``torch``. Raises: ValueError: When ``radians`` is empty. ValueError: When ``spatial_dims`` is not one of [2, 3]. """ + if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: + sin_func = np.sin + cos_func = np.cos + array_func = np.array + elif look_up_option(backend, TransformBackends) == TransformBackends.TORCH: + sin_func = lambda th: torch.sin(torch.as_tensor(th)) # type: ignore + cos_func = lambda th: torch.cos(torch.as_tensor(th)) # type: ignore + array_func = torch.tensor # type: ignore + else: + raise ValueError("backend {} is not supported".format(backend)) + return _create_rotate( + spatial_dims=spatial_dims, radians=radians, sin_func=sin_func, cos_func=cos_func, array_func=array_func + ) + + +def _create_rotate( + spatial_dims: int, + radians: Union[Sequence[float], float], + sin_func: Callable = np.sin, + cos_func: Callable = np.cos, + array_func: Callable = np.array, +) -> NdarrayOrTensor: radians = ensure_tuple(radians) if spatial_dims == 2: if len(radians) >= 1: - sin_, cos_ = np.sin(radians[0]), np.cos(radians[0]) - return np.array([[cos_, -sin_, 0.0], [sin_, cos_, 0.0], [0.0, 0.0, 1.0]]) + sin_, cos_ = sin_func(radians[0]), cos_func(radians[0]) + return array_func([[cos_, -sin_, 0.0], [sin_, cos_, 0.0], [0.0, 0.0, 1.0]]) # type: ignore raise ValueError("radians must be non empty.") if spatial_dims == 3: affine = None if len(radians) >= 1: - sin_, cos_ = np.sin(radians[0]), np.cos(radians[0]) - affine = np.array( + sin_, cos_ = sin_func(radians[0]), cos_func(radians[0]) + affine = array_func( [[1.0, 0.0, 0.0, 0.0], [0.0, cos_, -sin_, 0.0], [0.0, sin_, cos_, 0.0], [0.0, 0.0, 0.0, 1.0]] ) if len(radians) >= 2: - sin_, cos_ = np.sin(radians[1]), np.cos(radians[1]) + sin_, cos_ = sin_func(radians[1]), cos_func(radians[1]) if affine is None: raise ValueError("Affine should be a matrix.") - affine = affine @ np.array( + affine = affine @ array_func( [[cos_, 0.0, sin_, 0.0], [0.0, 1.0, 0.0, 0.0], [-sin_, 0.0, cos_, 0.0], [0.0, 0.0, 0.0, 1.0]] ) if len(radians) >= 3: - sin_, cos_ = np.sin(radians[2]), np.cos(radians[2]) + sin_, cos_ = sin_func(radians[2]), cos_func(radians[2]) if affine is None: raise ValueError("Affine should be a matrix.") - affine = affine @ np.array( + affine = affine @ array_func( [[cos_, -sin_, 0.0, 0.0], [sin_, cos_, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] ) 
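For orientation, the 2D branch of `_create_rotate` above builds the standard homogeneous rotation matrix. A minimal numpy check of that layout (an illustrative sketch, not part of the patch):

import numpy as np

theta = np.pi / 2
rot = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                [np.sin(theta), np.cos(theta), 0.0],
                [0.0, 0.0, 1.0]])
point = np.array([1.0, 0.0, 1.0])  # homogeneous (x, y, 1)
print(rot @ point)                 # approx. [0., 1., 1.]: a 90-degree rotation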
if affine is None: raise ValueError("radians must be non empty.") - return affine + return affine # type: ignore raise ValueError(f"Unsupported spatial_dims: {spatial_dims}, available options are [2, 3].") -def create_shear(spatial_dims: int, coefs: Union[Sequence[float], float]) -> np.ndarray: +def create_shear( + spatial_dims: int, coefs: Union[Sequence[float], float], backend=TransformBackends.NUMPY +) -> NdarrayOrTensor: """ create a shearing matrix @@ -642,16 +669,28 @@ def create_shear(spatial_dims: int, coefs: Union[Sequence[float], float]) -> np. [0.0, 0.0, 0.0, 1.0], ] + backend: APIs to use, ``numpy`` or ``torch``. + Raises: NotImplementedError: When ``spatial_dims`` is not one of [2, 3]. """ + if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: + array_func = np.array + elif look_up_option(backend, TransformBackends) == TransformBackends.TORCH: + array_func = torch.tensor # type: ignore + else: + raise ValueError("backend {} is not supported".format(backend)) + return _create_shear(spatial_dims=spatial_dims, coefs=coefs, array_func=array_func) + + +def _create_shear(spatial_dims: int, coefs: Union[Sequence[float], float], array_func=np.array) -> NdarrayOrTensor: if spatial_dims == 2: coefs = ensure_tuple_size(coefs, dim=2, pad_val=0.0) - return np.array([[1, coefs[0], 0.0], [coefs[1], 1.0, 0.0], [0.0, 0.0, 1.0]]) + return array_func([[1, coefs[0], 0.0], [coefs[1], 1.0, 0.0], [0.0, 0.0, 1.0]]) # type: ignore if spatial_dims == 3: coefs = ensure_tuple_size(coefs, dim=6, pad_val=0.0) - return np.array( + return array_func( # type: ignore [ [1.0, coefs[0], coefs[1], 0.0], [coefs[2], 1.0, coefs[3], 0.0], @@ -662,31 +701,63 @@ def create_shear(spatial_dims: int, coefs: Union[Sequence[float], float]) -> np. raise NotImplementedError("Currently only spatial_dims in [2, 3] are supported.") -def create_scale(spatial_dims: int, scaling_factor: Union[Sequence[float], float]): +def create_scale( + spatial_dims: int, scaling_factor: Union[Sequence[float], float], backend=TransformBackends.NUMPY +) -> NdarrayOrTensor: """ create a scaling matrix Args: spatial_dims: spatial rank scaling_factor: scaling factors for every spatial dim, defaults to 1. + backend: APIs to use, ``numpy`` or ``torch``. """ + if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: + array_func = np.diag + elif look_up_option(backend, TransformBackends) == TransformBackends.TORCH: + array_func = lambda x: torch.diag(torch.as_tensor(x)) # type: ignore + else: + raise ValueError("backend {} is not supported".format(backend)) + return _create_scale(spatial_dims=spatial_dims, scaling_factor=scaling_factor, array_func=array_func) + + +def _create_scale( + spatial_dims: int, scaling_factor: Union[Sequence[float], float], array_func=np.diag +) -> NdarrayOrTensor: scaling_factor = ensure_tuple_size(scaling_factor, dim=spatial_dims, pad_val=1.0) - return np.diag(scaling_factor[:spatial_dims] + (1.0,)) + return array_func(scaling_factor[:spatial_dims] + (1.0,)) # type: ignore -def create_translate(spatial_dims: int, shift: Union[Sequence[float], float]) -> np.ndarray: +def create_translate( + spatial_dims: int, shift: Union[Sequence[float], float], backend=TransformBackends.NUMPY +) -> NdarrayOrTensor: """ create a translation matrix Args: spatial_dims: spatial rank shift: translate pixel/voxel for every spatial dim, defaults to 0. - """ + backend: APIs to use, ``numpy`` or ``torch``. 
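The new `backend` argument can be exercised the same way the updated test below does: the same factory call returns a numpy array or a torch tensor with identical values. A usage sketch, assuming MONAI with this patch applied:

import numpy as np
import torch
from monai.transforms import create_shear

m_np = create_shear(2, coefs=(0.5, 0.0))                   # default numpy backend
m_pt = create_shear(2, coefs=(0.5, 0.0), backend="torch")  # torch backend
assert isinstance(m_np, np.ndarray) and isinstance(m_pt, torch.Tensor)
np.testing.assert_allclose(m_np, m_pt.numpy(), atol=1e-7)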
+ """ + if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: + eye_func = np.eye + array_func = np.asarray + elif look_up_option(backend, TransformBackends) == TransformBackends.TORCH: + eye_func = lambda x: torch.eye(torch.as_tensor(x)) # type: ignore + array_func = torch.as_tensor + else: + raise ValueError("backend {} is not supported".format(backend)) + return _create_translate(spatial_dims=spatial_dims, shift=shift, eye_func=eye_func, array_func=array_func) + + +def _create_translate( + spatial_dims: int, shift: Union[Sequence[float], float], eye_func=np.eye, array_func=np.asarray +) -> NdarrayOrTensor: shift = ensure_tuple(shift) - affine = np.eye(spatial_dims + 1) + affine = eye_func(spatial_dims + 1) for i, a in enumerate(shift[:spatial_dims]): affine[i, spatial_dims] = a - return np.asarray(affine) + return array_func(affine) def generate_spatial_bounding_box( diff --git a/tests/test_create_grid_and_affine.py b/tests/test_create_grid_and_affine.py index 0c0e52e04a..0e16d4bde5 100644 --- a/tests/test_create_grid_and_affine.py +++ b/tests/test_create_grid_and_affine.py @@ -21,6 +21,7 @@ create_shear, create_translate, ) +from tests.utils import assert_allclose class TestCreateGrid(unittest.TestCase): @@ -147,8 +148,9 @@ def test_create_control_grid(self): def test_assert(func, params, expected): - m = func(*params) - np.testing.assert_allclose(m, expected, atol=1e-7) + for b in ("torch", "numpy"): + m = func(*params, backend=b) + assert_allclose(m, expected, type_test=False, atol=1e-7) class TestCreateAffine(unittest.TestCase): From a2e0a9ee779e9e3387483736431af95f4981d4ae Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 17 Sep 2021 17:17:33 +0100 Subject: [PATCH 02/20] style fixes Signed-off-by: Wenqi Li --- monai/transforms/utils.py | 64 ++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 0388b3ada2..93a0a7aab7 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -595,18 +595,18 @@ def create_rotate( """ if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: - sin_func = np.sin - cos_func = np.cos - array_func = np.array - elif look_up_option(backend, TransformBackends) == TransformBackends.TORCH: - sin_func = lambda th: torch.sin(torch.as_tensor(th)) # type: ignore - cos_func = lambda th: torch.cos(torch.as_tensor(th)) # type: ignore - array_func = torch.tensor # type: ignore - else: - raise ValueError("backend {} is not supported".format(backend)) - return _create_rotate( - spatial_dims=spatial_dims, radians=radians, sin_func=sin_func, cos_func=cos_func, array_func=array_func - ) + return _create_rotate( + spatial_dims=spatial_dims, radians=radians, sin_func=np.sin, cos_func=np.cos, array_func=np.array + ) + if look_up_option(backend, TransformBackends) == TransformBackends.TORCH: + return _create_rotate( + spatial_dims=spatial_dims, + radians=radians, + sin_func=lambda th: torch.sin(torch.as_tensor(th)), + cos_func=lambda th: torch.cos(torch.as_tensor(th)), + array_func=torch.as_tensor, + ) + raise ValueError("backend {} is not supported".format(backend)) def _create_rotate( @@ -677,11 +677,10 @@ def create_shear( """ if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: array_func = np.array - elif look_up_option(backend, TransformBackends) == TransformBackends.TORCH: - array_func = torch.tensor # type: ignore - else: - raise ValueError("backend {} is not supported".format(backend)) - 
return _create_shear(spatial_dims=spatial_dims, coefs=coefs, array_func=array_func) + return _create_shear(spatial_dims=spatial_dims, coefs=coefs, array_func=np.array) + if look_up_option(backend, TransformBackends) == TransformBackends.TORCH: + return _create_shear(spatial_dims=spatial_dims, coefs=coefs, array_func=torch.as_tensor) + raise ValueError("backend {} is not supported".format(backend)) def _create_shear(spatial_dims: int, coefs: Union[Sequence[float], float], array_func=np.array) -> NdarrayOrTensor: @@ -713,12 +712,14 @@ def create_scale( backend: APIs to use, ``numpy`` or ``torch``. """ if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: - array_func = np.diag - elif look_up_option(backend, TransformBackends) == TransformBackends.TORCH: - array_func = lambda x: torch.diag(torch.as_tensor(x)) # type: ignore - else: - raise ValueError("backend {} is not supported".format(backend)) - return _create_scale(spatial_dims=spatial_dims, scaling_factor=scaling_factor, array_func=array_func) + return _create_scale(spatial_dims=spatial_dims, scaling_factor=scaling_factor, array_func=np.diag) + if look_up_option(backend, TransformBackends) == TransformBackends.TORCH: + return _create_scale( + spatial_dims=spatial_dims, + scaling_factor=scaling_factor, + array_func=lambda x: torch.diag(torch.as_tensor(x)), + ) + raise ValueError("backend {} is not supported".format(backend)) def _create_scale( @@ -740,14 +741,15 @@ def create_translate( backend: APIs to use, ``numpy`` or ``torch``. """ if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: - eye_func = np.eye - array_func = np.asarray - elif look_up_option(backend, TransformBackends) == TransformBackends.TORCH: - eye_func = lambda x: torch.eye(torch.as_tensor(x)) # type: ignore - array_func = torch.as_tensor - else: - raise ValueError("backend {} is not supported".format(backend)) - return _create_translate(spatial_dims=spatial_dims, shift=shift, eye_func=eye_func, array_func=array_func) + return _create_translate(spatial_dims=spatial_dims, shift=shift, eye_func=np.eye, array_func=np.asarray) + if look_up_option(backend, TransformBackends) == TransformBackends.TORCH: + return _create_translate( + spatial_dims=spatial_dims, + shift=shift, + eye_func=lambda x: torch.eye(torch.as_tensor(x)), # type: ignore + array_func=torch.as_tensor, + ) + raise ValueError("backend {} is not supported".format(backend)) def _create_translate( From fe663ab29d6415d923683f2c79a625b7432fbab1 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 17 Sep 2021 17:24:33 +0100 Subject: [PATCH 03/20] codeformat fixes Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 6 +++--- monai/transforms/utils.py | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index df3d3eb093..9f65facf29 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -462,7 +462,7 @@ def __init__( self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) self.align_corners = align_corners self.dtype = dtype - self._rotation_matrix: Optional[np.ndarray] = None + self._rotation_matrix: Optional[NdarrayOrTensor] = None def __call__( self, @@ -511,7 +511,7 @@ def __call__( corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape( (len(im_shape), -1) ) - corners = transform[:-1, :-1] @ corners + corners = transform[:-1, :-1] @ corners # type: ignore output_shape = 
np.asarray(corners.ptp(axis=1) + 0.5, dtype=int) shift_1 = create_translate(input_ndim, (-(output_shape - 1) / 2).tolist()) transform = shift @ transform @ shift_1 @@ -532,7 +532,7 @@ def __call__( out, *_ = convert_to_dst_type(output, dst=img, dtype=output.dtype) return out - def get_rotation_matrix(self) -> Optional[np.ndarray]: + def get_rotation_matrix(self) -> Optional[NdarrayOrTensor]: """ Get the most recently applied rotation matrix This is not thread-safe. diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 93a0a7aab7..bb6903fb86 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -602,8 +602,8 @@ def create_rotate( return _create_rotate( spatial_dims=spatial_dims, radians=radians, - sin_func=lambda th: torch.sin(torch.as_tensor(th)), - cos_func=lambda th: torch.cos(torch.as_tensor(th)), + sin_func=lambda th: torch.sin(torch.as_tensor(th, dtype=torch.float32)), + cos_func=lambda th: torch.cos(torch.as_tensor(th, dtype=torch.float32)), array_func=torch.as_tensor, ) raise ValueError("backend {} is not supported".format(backend)) @@ -676,7 +676,6 @@ def create_shear( """ if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: - array_func = np.array return _create_shear(spatial_dims=spatial_dims, coefs=coefs, array_func=np.array) if look_up_option(backend, TransformBackends) == TransformBackends.TORCH: return _create_shear(spatial_dims=spatial_dims, coefs=coefs, array_func=torch.as_tensor) @@ -759,7 +758,7 @@ def _create_translate( affine = eye_func(spatial_dims + 1) for i, a in enumerate(shift[:spatial_dims]): affine[i, spatial_dims] = a - return array_func(affine) + return array_func(affine) # type: ignore def generate_spatial_bounding_box( From e2fa593c25ce9481816e0be82c63d49cd03e8f52 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 17 Sep 2021 17:37:04 +0100 Subject: [PATCH 04/20] backend affien grid Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 9f65facf29..f4520ed1c8 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1069,18 +1069,19 @@ def __call__( else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") + _b = TransformBackends.TORCH if isinstance(grid, torch.Tensor) else TransformBackends.NUMPY affine: NdarrayOrTensor if self.affine is None: spatial_dims = len(grid.shape) - 1 - affine = np.eye(spatial_dims + 1) + affine = torch.eye(spatial_dims + 1) if _b == TransformBackends.TORCH else np.eye(spatial_dims + 1) if self.rotate_params: - affine = affine @ create_rotate(spatial_dims, self.rotate_params) + affine = affine @ create_rotate(spatial_dims, self.rotate_params, backend=_b) if self.shear_params: - affine = affine @ create_shear(spatial_dims, self.shear_params) + affine = affine @ create_shear(spatial_dims, self.shear_params, backend=_b) if self.translate_params: - affine = affine @ create_translate(spatial_dims, self.translate_params) + affine = affine @ create_translate(spatial_dims, self.translate_params, backend=_b) if self.scale_params: - affine = affine @ create_scale(spatial_dims, self.scale_params) + affine = affine @ create_scale(spatial_dims, self.scale_params, backend=_b) else: affine = self.affine From fa9930080fffda577831c89cea26289547480f76 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 18 Sep 2021 13:11:48 +0100 Subject: [PATCH 05/20] device support 
Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 15 +++-- monai/transforms/utils.py | 90 +++++++++++++++++----------- tests/test_create_grid_and_affine.py | 8 ++- 3 files changed, 72 insertions(+), 41 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 98fc9b2e5d..6ae9ec7627 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1070,18 +1070,23 @@ def __call__( raise ValueError("Incompatible values: grid=None and spatial_size=None.") _b = TransformBackends.TORCH if isinstance(grid, torch.Tensor) else TransformBackends.NUMPY + _device = self.device or (grid.device if isinstance(grid, torch.Tensor) else None) affine: NdarrayOrTensor if self.affine is None: spatial_dims = len(grid.shape) - 1 - affine = torch.eye(spatial_dims + 1) if _b == TransformBackends.TORCH else np.eye(spatial_dims + 1) + affine = ( + torch.eye(spatial_dims + 1, device=_device) + if _b == TransformBackends.TORCH + else np.eye(spatial_dims + 1) + ) if self.rotate_params: - affine = affine @ create_rotate(spatial_dims, self.rotate_params, backend=_b) + affine = affine @ create_rotate(spatial_dims, self.rotate_params, device=_device, backend=_b) if self.shear_params: - affine = affine @ create_shear(spatial_dims, self.shear_params, backend=_b) + affine = affine @ create_shear(spatial_dims, self.shear_params, device=_device, backend=_b) if self.translate_params: - affine = affine @ create_translate(spatial_dims, self.translate_params, backend=_b) + affine = affine @ create_translate(spatial_dims, self.translate_params, device=_device, backend=_b) if self.scale_params: - affine = affine @ create_scale(spatial_dims, self.scale_params, backend=_b) + affine = affine @ create_scale(spatial_dims, self.scale_params, device=_device, backend=_b) else: affine = self.affine diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index bb6903fb86..1d1af8b26e 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -577,7 +577,10 @@ def create_control_grid( def create_rotate( - spatial_dims: int, radians: Union[Sequence[float], float], backend=TransformBackends.NUMPY + spatial_dims: int, + radians: Union[Sequence[float], float], + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, ) -> NdarrayOrTensor: """ create a 2D or 3D rotation matrix @@ -587,6 +590,7 @@ def create_rotate( radians: rotation radians when spatial_dims == 3, the `radians` sequence corresponds to rotation in the 1st, 2nd, and 3rd dim respectively. + device: device to compute and store the output. backend: APIs to use, ``numpy`` or ``torch``. 
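With the `device` argument added here, the torch backend builds the affine directly on the requested device instead of creating it on CPU and moving it afterwards. A usage sketch, assuming this patch is applied:

import torch
from monai.transforms import create_rotate

dev = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
rot = create_rotate(3, radians=(0.1, 0.2, 0.3), device=dev, backend="torch")
print(rot.shape, rot.device)  # torch.Size([4, 4]) on `dev`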
Raises: @@ -596,15 +600,15 @@ def create_rotate( """ if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: return _create_rotate( - spatial_dims=spatial_dims, radians=radians, sin_func=np.sin, cos_func=np.cos, array_func=np.array + spatial_dims=spatial_dims, radians=radians, sin_func=np.sin, cos_func=np.cos, eye_func=np.eye ) if look_up_option(backend, TransformBackends) == TransformBackends.TORCH: return _create_rotate( spatial_dims=spatial_dims, radians=radians, - sin_func=lambda th: torch.sin(torch.as_tensor(th, dtype=torch.float32)), - cos_func=lambda th: torch.cos(torch.as_tensor(th, dtype=torch.float32)), - array_func=torch.as_tensor, + sin_func=lambda th: torch.sin(torch.as_tensor(th, dtype=torch.float32, device=device)), + cos_func=lambda th: torch.cos(torch.as_tensor(th, dtype=torch.float32, device=device)), + eye_func=lambda rank: torch.eye(rank, device=device), ) raise ValueError("backend {} is not supported".format(backend)) @@ -614,36 +618,41 @@ def _create_rotate( radians: Union[Sequence[float], float], sin_func: Callable = np.sin, cos_func: Callable = np.cos, - array_func: Callable = np.array, + eye_func: Callable = np.eye, ) -> NdarrayOrTensor: radians = ensure_tuple(radians) if spatial_dims == 2: if len(radians) >= 1: sin_, cos_ = sin_func(radians[0]), cos_func(radians[0]) - return array_func([[cos_, -sin_, 0.0], [sin_, cos_, 0.0], [0.0, 0.0, 1.0]]) # type: ignore + out = eye_func(3) + out[0, 0], out[0, 1] = cos_, -sin_ + out[1, 0], out[1, 1] = sin_, cos_ + return out # type: ignore raise ValueError("radians must be non empty.") if spatial_dims == 3: affine = None if len(radians) >= 1: sin_, cos_ = sin_func(radians[0]), cos_func(radians[0]) - affine = array_func( - [[1.0, 0.0, 0.0, 0.0], [0.0, cos_, -sin_, 0.0], [0.0, sin_, cos_, 0.0], [0.0, 0.0, 0.0, 1.0]] - ) + affine = eye_func(4) + affine[1, 1], affine[1, 2] = cos_, -sin_ + affine[2, 1], affine[2, 2] = sin_, cos_ if len(radians) >= 2: sin_, cos_ = sin_func(radians[1]), cos_func(radians[1]) if affine is None: raise ValueError("Affine should be a matrix.") - affine = affine @ array_func( - [[cos_, 0.0, sin_, 0.0], [0.0, 1.0, 0.0, 0.0], [-sin_, 0.0, cos_, 0.0], [0.0, 0.0, 0.0, 1.0]] - ) + _affine = eye_func(4) + _affine[0, 0], _affine[0, 2] = cos_, sin_ + _affine[2, 0], _affine[2, 2] = -sin_, cos_ + affine = affine @ _affine if len(radians) >= 3: sin_, cos_ = sin_func(radians[2]), cos_func(radians[2]) if affine is None: raise ValueError("Affine should be a matrix.") - affine = affine @ array_func( - [[cos_, -sin_, 0.0, 0.0], [sin_, cos_, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] - ) + _affine = eye_func(4) + _affine[0, 0], _affine[0, 1] = cos_, -sin_ + _affine[1, 0], _affine[1, 1] = sin_, cos_ + affine = affine @ _affine if affine is None: raise ValueError("radians must be non empty.") return affine # type: ignore @@ -652,7 +661,10 @@ def _create_rotate( def create_shear( - spatial_dims: int, coefs: Union[Sequence[float], float], backend=TransformBackends.NUMPY + spatial_dims: int, + coefs: Union[Sequence[float], float], + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, ) -> NdarrayOrTensor: """ create a shearing matrix @@ -669,6 +681,7 @@ def create_shear( [0.0, 0.0, 0.0, 1.0], ] + device: device to compute and store the output. backend: APIs to use, ``numpy`` or ``torch``. 
Raises: @@ -676,31 +689,35 @@ def create_shear( """ if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: - return _create_shear(spatial_dims=spatial_dims, coefs=coefs, array_func=np.array) + return _create_shear(spatial_dims=spatial_dims, coefs=coefs, eye_func=np.eye) if look_up_option(backend, TransformBackends) == TransformBackends.TORCH: - return _create_shear(spatial_dims=spatial_dims, coefs=coefs, array_func=torch.as_tensor) + return _create_shear( + spatial_dims=spatial_dims, coefs=coefs, eye_func=lambda rank: torch.eye(rank, device=device) + ) raise ValueError("backend {} is not supported".format(backend)) -def _create_shear(spatial_dims: int, coefs: Union[Sequence[float], float], array_func=np.array) -> NdarrayOrTensor: +def _create_shear(spatial_dims: int, coefs: Union[Sequence[float], float], eye_func=np.eye) -> NdarrayOrTensor: if spatial_dims == 2: coefs = ensure_tuple_size(coefs, dim=2, pad_val=0.0) - return array_func([[1, coefs[0], 0.0], [coefs[1], 1.0, 0.0], [0.0, 0.0, 1.0]]) # type: ignore + out = eye_func(3) + out[0, 1], out[1, 0] = coefs[0], coefs[1] + return out # type: ignore if spatial_dims == 3: coefs = ensure_tuple_size(coefs, dim=6, pad_val=0.0) - return array_func( # type: ignore - [ - [1.0, coefs[0], coefs[1], 0.0], - [coefs[2], 1.0, coefs[3], 0.0], - [coefs[4], coefs[5], 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - ] - ) + out = eye_func(4) + out[0, 1], out[0, 2] = coefs[0], coefs[1] + out[1, 0], out[1, 2] = coefs[2], coefs[3] + out[2, 0], out[2, 1] = coefs[4], coefs[5] + return out # type: ignore raise NotImplementedError("Currently only spatial_dims in [2, 3] are supported.") def create_scale( - spatial_dims: int, scaling_factor: Union[Sequence[float], float], backend=TransformBackends.NUMPY + spatial_dims: int, + scaling_factor: Union[Sequence[float], float], + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, ) -> NdarrayOrTensor: """ create a scaling matrix @@ -708,6 +725,7 @@ def create_scale( Args: spatial_dims: spatial rank scaling_factor: scaling factors for every spatial dim, defaults to 1. + device: device to compute and store the output. backend: APIs to use, ``numpy`` or ``torch``. """ if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: @@ -716,7 +734,7 @@ def create_scale( return _create_scale( spatial_dims=spatial_dims, scaling_factor=scaling_factor, - array_func=lambda x: torch.diag(torch.as_tensor(x)), + array_func=lambda x: torch.diag(torch.as_tensor(x, device=device)), ) raise ValueError("backend {} is not supported".format(backend)) @@ -729,7 +747,10 @@ def _create_scale( def create_translate( - spatial_dims: int, shift: Union[Sequence[float], float], backend=TransformBackends.NUMPY + spatial_dims: int, + shift: Union[Sequence[float], float], + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, ) -> NdarrayOrTensor: """ create a translation matrix @@ -737,6 +758,7 @@ def create_translate( Args: spatial_dims: spatial rank shift: translate pixel/voxel for every spatial dim, defaults to 0. + device: device to compute and store the output. backend: APIs to use, ``numpy`` or ``torch``. 
""" if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: @@ -745,8 +767,8 @@ def create_translate( return _create_translate( spatial_dims=spatial_dims, shift=shift, - eye_func=lambda x: torch.eye(torch.as_tensor(x)), # type: ignore - array_func=torch.as_tensor, + eye_func=lambda x: torch.eye(torch.as_tensor(x), device=device), # type: ignore + array_func=lambda x: torch.as_tensor(x, device=device), # type: ignore ) raise ValueError("backend {} is not supported".format(backend)) diff --git a/tests/test_create_grid_and_affine.py b/tests/test_create_grid_and_affine.py index 0e16d4bde5..42cb0a9b0e 100644 --- a/tests/test_create_grid_and_affine.py +++ b/tests/test_create_grid_and_affine.py @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from monai.transforms import ( create_control_grid, @@ -148,8 +149,11 @@ def test_create_control_grid(self): def test_assert(func, params, expected): - for b in ("torch", "numpy"): - m = func(*params, backend=b) + for b in ("torch", "numpy") + ("torch_gpu",) if torch.cuda.is_available() else (): + if b == "torch_gpu": + m = func(*params, device="cuda:0", backend="torch") + else: + m = func(*params, backend=b) assert_allclose(m, expected, type_test=False, atol=1e-7) From a0ee97d145057d934e2c2237a7dd97d8e8337369 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 18 Sep 2021 14:07:19 +0100 Subject: [PATCH 06/20] less lookup Signed-off-by: Wenqi Li --- monai/transforms/utils.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 1d1af8b26e..130aa531dd 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -598,11 +598,12 @@ def create_rotate( ValueError: When ``spatial_dims`` is not one of [2, 3]. """ - if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: return _create_rotate( spatial_dims=spatial_dims, radians=radians, sin_func=np.sin, cos_func=np.cos, eye_func=np.eye ) - if look_up_option(backend, TransformBackends) == TransformBackends.TORCH: + if _backend == TransformBackends.TORCH: return _create_rotate( spatial_dims=spatial_dims, radians=radians, @@ -688,9 +689,10 @@ def create_shear( NotImplementedError: When ``spatial_dims`` is not one of [2, 3]. """ - if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: return _create_shear(spatial_dims=spatial_dims, coefs=coefs, eye_func=np.eye) - if look_up_option(backend, TransformBackends) == TransformBackends.TORCH: + if _backend == TransformBackends.TORCH: return _create_shear( spatial_dims=spatial_dims, coefs=coefs, eye_func=lambda rank: torch.eye(rank, device=device) ) @@ -728,9 +730,10 @@ def create_scale( device: device to compute and store the output. backend: APIs to use, ``numpy`` or ``torch``. 
""" - if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: return _create_scale(spatial_dims=spatial_dims, scaling_factor=scaling_factor, array_func=np.diag) - if look_up_option(backend, TransformBackends) == TransformBackends.TORCH: + if _backend == TransformBackends.TORCH: return _create_scale( spatial_dims=spatial_dims, scaling_factor=scaling_factor, @@ -761,9 +764,10 @@ def create_translate( device: device to compute and store the output. backend: APIs to use, ``numpy`` or ``torch``. """ - if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY: + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: return _create_translate(spatial_dims=spatial_dims, shift=shift, eye_func=np.eye, array_func=np.asarray) - if look_up_option(backend, TransformBackends) == TransformBackends.TORCH: + if _backend == TransformBackends.TORCH: return _create_translate( spatial_dims=spatial_dims, shift=shift, From 7171eb10b175dbd2d2a1cb6c4a7e652272cb6cf9 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 19 Sep 2021 13:31:06 +0100 Subject: [PATCH 07/20] create grid with torch backend Signed-off-by: Wenqi Li --- monai/transforms/utils.py | 57 ++++++++++++++++++++++--- tests/test_create_grid_and_affine.py | 63 +++++++++++++--------------- 2 files changed, 82 insertions(+), 38 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 130aa531dd..ad535f925b 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -541,7 +541,9 @@ def create_grid( spatial_size: Sequence[int], spacing: Optional[Sequence[float]] = None, homogeneous: bool = True, - dtype: DtypeLike = float, + dtype=None, + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, ): """ compute a `spatial_size` mesh. @@ -551,6 +553,30 @@ def create_grid( spacing: same len as ``spatial_size``, defaults to 1.0 (dense grid). homogeneous: whether to make homogeneous coordinates. dtype: output grid data type. + device: device to compute and store the output (when the backend is "torch"). + backend: APIs to use, ``numpy`` or ``torch``. + + """ + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + if dtype is None: + dtype = np.float32 + return _create_grid_numpy(spatial_size, spacing, homogeneous, dtype) + if _backend == TransformBackends.TORCH: + if dtype is None: + dtype = torch.float32 + return _create_grid_torch(spatial_size, spacing, homogeneous, dtype, device) + raise ValueError("backend {} is not supported".format(backend)) + + +def _create_grid_numpy( + spatial_size: Sequence[int], + spacing: Optional[Sequence[float]] = None, + homogeneous: bool = True, + dtype: DtypeLike = float, +): + """ + compute a `spatial_size` mesh with the numpy API. """ spacing = spacing or tuple(1.0 for _ in spatial_size) ranges = [np.linspace(-(d - 1.0) / 2.0 * s, (d - 1.0) / 2.0 * s, int(d)) for d, s in zip(spatial_size, spacing)] @@ -560,6 +586,27 @@ def create_grid( return np.concatenate([coords, np.ones_like(coords[:1])]) +def _create_grid_torch( + spatial_size: Sequence[int], + spacing: Optional[Sequence[float]] = None, + homogeneous: bool = True, + dtype=torch.float32, + device: Optional[torch.device] = None, +): + """ + compute a `spatial_size` mesh with the torch API. 
+ """ + spacing = spacing or tuple(1.0 for _ in spatial_size) + ranges = [ + torch.linspace(-(d - 1.0) / 2.0 * s, (d - 1.0) / 2.0 * s, int(d), device=device, dtype=dtype) + for d, s in zip(spatial_size, spacing) + ] + coords = torch.meshgrid(*ranges) + if not homogeneous: + return torch.stack(coords) + return torch.stack([*coords, torch.ones_like(coords[0])]) + + def create_control_grid( spatial_shape: Sequence[int], spacing: Sequence[float], homogeneous: bool = True, dtype: DtypeLike = float ): @@ -590,7 +637,7 @@ def create_rotate( radians: rotation radians when spatial_dims == 3, the `radians` sequence corresponds to rotation in the 1st, 2nd, and 3rd dim respectively. - device: device to compute and store the output. + device: device to compute and store the output (when the backend is "torch"). backend: APIs to use, ``numpy`` or ``torch``. Raises: @@ -682,7 +729,7 @@ def create_shear( [0.0, 0.0, 0.0, 1.0], ] - device: device to compute and store the output. + device: device to compute and store the output (when the backend is "torch"). backend: APIs to use, ``numpy`` or ``torch``. Raises: @@ -727,7 +774,7 @@ def create_scale( Args: spatial_dims: spatial rank scaling_factor: scaling factors for every spatial dim, defaults to 1. - device: device to compute and store the output. + device: device to compute and store the output (when the backend is "torch"). backend: APIs to use, ``numpy`` or ``torch``. """ _backend = look_up_option(backend, TransformBackends) @@ -761,7 +808,7 @@ def create_translate( Args: spatial_dims: spatial rank shift: translate pixel/voxel for every spatial dim, defaults to 0. - device: device to compute and store the output. + device: device to compute and store the output (when the backend is "torch"). backend: APIs to use, ``numpy`` or ``torch``. 
""" _backend = look_up_option(backend, TransformBackends) diff --git a/tests/test_create_grid_and_affine.py b/tests/test_create_grid_and_affine.py index 42cb0a9b0e..69cc8a1c00 100644 --- a/tests/test_create_grid_and_affine.py +++ b/tests/test_create_grid_and_affine.py @@ -34,50 +34,47 @@ def test_create_grid(self): with self.assertRaisesRegex(TypeError, ""): create_grid((1, 1), spacing=2.0) - g = create_grid((1, 1)) - expected = np.array([[[0.0]], [[0.0]], [[1.0]]]) - np.testing.assert_allclose(g, expected) + test_assert(create_grid, ((1, 1),), np.array([[[0.0]], [[0.0]], [[1.0]]])) - g = create_grid((1, 1), homogeneous=False) - expected = np.array([[[0.0]], [[0.0]]]) - np.testing.assert_allclose(g, expected) + test_assert(create_grid, ((1, 1), None, False), np.array([[[0.0]], [[0.0]]])) - g = create_grid((1, 1), spacing=(1.2, 1.3)) - expected = np.array([[[0.0]], [[0.0]], [[1.0]]]) - np.testing.assert_allclose(g, expected) + test_assert(create_grid, ((1, 1), (1.2, 1.3)), np.array([[[0.0]], [[0.0]], [[1.0]]])) - g = create_grid((1, 1, 1), spacing=(1.2, 1.3, 1.0)) - expected = np.array([[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[1.0]]]]) - np.testing.assert_allclose(g, expected) + test_assert(create_grid, ((1, 1, 1), (1.2, 1.3, 1.0)), np.array([[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[1.0]]]])) - g = create_grid((1, 1, 1), spacing=(1.2, 1.3, 1.0), homogeneous=False) - expected = np.array([[[[0.0]]], [[[0.0]]], [[[0.0]]]]) - np.testing.assert_allclose(g, expected) + test_assert(create_grid, ((1, 1, 1), (1.2, 1.3, 1.0), False), np.array([[[[0.0]]], [[[0.0]]], [[[0.0]]]])) g = create_grid((1, 1, 1), spacing=(1.2, 1.3, 1.0), dtype=np.int32) np.testing.assert_equal(g.dtype, np.int32) - g = create_grid((2, 2, 2)) - expected = np.array( - [ - [[[-0.5, -0.5], [-0.5, -0.5]], [[0.5, 0.5], [0.5, 0.5]]], - [[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, -0.5], [0.5, 0.5]]], - [[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]]], - [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], - ] + g = create_grid((1, 1, 1), spacing=(1.2, 1.3, 1.0), dtype=torch.float64, backend="torch") + np.testing.assert_equal(g.dtype, torch.float64) + + test_assert( + create_grid, + (2, 2, 2), + np.array( + [ + [[[-0.5, -0.5], [-0.5, -0.5]], [[0.5, 0.5], [0.5, 0.5]]], + [[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, -0.5], [0.5, 0.5]]], + [[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]]], + [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], + ] + ), ) - np.testing.assert_allclose(g, expected) - g = create_grid((2, 2, 2), spacing=(1.2, 1.3, 1.0)) - expected = np.array( - [ - [[[-0.6, -0.6], [-0.6, -0.6]], [[0.6, 0.6], [0.6, 0.6]]], - [[[-0.65, -0.65], [0.65, 0.65]], [[-0.65, -0.65], [0.65, 0.65]]], - [[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]]], - [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], - ] + test_assert( + create_grid, + ((2, 2, 2), (1.2, 1.3, 1.0)), + np.array( + [ + [[[-0.6, -0.6], [-0.6, -0.6]], [[0.6, 0.6], [0.6, 0.6]]], + [[[-0.65, -0.65], [0.65, 0.65]], [[-0.65, -0.65], [0.65, 0.65]]], + [[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]]], + [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], + ] + ), ) - np.testing.assert_allclose(g, expected) def test_create_control_grid(self): with self.assertRaisesRegex(TypeError, ""): From 7c8808be4212da549a00ca3f11e93bf3d9a41d94 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 19 Sep 2021 13:44:23 +0100 Subject: [PATCH 08/20] fixes tests Signed-off-by: Wenqi Li --- tests/test_create_grid_and_affine.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff 
--git a/tests/test_create_grid_and_affine.py b/tests/test_create_grid_and_affine.py index 69cc8a1c00..0b1243a07a 100644 --- a/tests/test_create_grid_and_affine.py +++ b/tests/test_create_grid_and_affine.py @@ -52,7 +52,7 @@ def test_create_grid(self): test_assert( create_grid, - (2, 2, 2), + ((2, 2, 2),), np.array( [ [[[-0.5, -0.5], [-0.5, -0.5]], [[0.5, 0.5], [0.5, 0.5]]], @@ -146,7 +146,8 @@ def test_create_control_grid(self): def test_assert(func, params, expected): - for b in ("torch", "numpy") + ("torch_gpu",) if torch.cuda.is_available() else (): + gpu_test = ("torch_gpu",) if torch.cuda.is_available() else () + for b in ("torch", "numpy") + gpu_test: if b == "torch_gpu": m = func(*params, device="cuda:0", backend="torch") else: From 98188335acdac743b88a821b8092c84c3ffd8e9f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 19 Sep 2021 13:58:11 +0100 Subject: [PATCH 09/20] default to torch create grid Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 6 +++++- monai/utils/type_conversion.py | 7 +++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 6ae9ec7627..07d2ae7143 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1055,6 +1055,10 @@ def __call__( grid: Optional[NdarrayOrTensor] = None, ) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: """ + The grid can be initialized with a `spatial_size` parameter, or provided directly as `grid`. + Therefore, either `spatial_size` or `grid` must be provided. + When initialising from `spatial_size`, the backend "torch" will be used. + Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. @@ -1065,7 +1069,7 @@ def __call__( """ if grid is None: if spatial_size is not None: - grid = create_grid(spatial_size, dtype=float) + grid = create_grid(spatial_size, dtype=torch.float32, device=self.device, backend="torch") else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 3636dbc6c0..3c8d81055c 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -193,12 +193,11 @@ def convert_to_cupy(data, dtype, wrap_sequence: bool = True): elif isinstance(data, dict): return {k: convert_to_cupy(v, dtype) for k, v in data.items()} # make it contiguous - if isinstance(data, cp.ndarray): - if data.ndim > 0: - data = cp.ascontiguousarray(data) - else: + if not isinstance(data, cp.ndarray): raise ValueError(f"The input data type [{type(data)}] cannot be converted into cupy arrays!") + if data.ndim > 0: + data = cp.ascontiguousarray(data) return data From 0c700f9638c7641f1e2ef9139d7b8714ba34fd10 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 19 Sep 2021 14:39:30 +0100 Subject: [PATCH 10/20] fixes unit test Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 2 +- monai/transforms/utils.py | 6 +----- monai/utils/type_conversion.py | 16 +++++++++------- tests/test_get_equivalent_dtype.py | 8 ++++++++ 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 07d2ae7143..ca5049191c 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1069,7 +1069,7 @@ def __call__( """ if grid is None: if spatial_size is not None: - grid = create_grid(spatial_size, dtype=torch.float32, device=self.device, 
backend="torch") + grid = create_grid(spatial_size, device=self.device, backend="torch") else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index ad535f925b..acdf7fac2c 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -541,7 +541,7 @@ def create_grid( spatial_size: Sequence[int], spacing: Optional[Sequence[float]] = None, homogeneous: bool = True, - dtype=None, + dtype=float, device: Optional[torch.device] = None, backend=TransformBackends.NUMPY, ): @@ -559,12 +559,8 @@ def create_grid( """ _backend = look_up_option(backend, TransformBackends) if _backend == TransformBackends.NUMPY: - if dtype is None: - dtype = np.float32 return _create_grid_numpy(spatial_size, spacing, homogeneous, dtype) if _backend == TransformBackends.TORCH: - if dtype is None: - dtype = torch.float32 return _create_grid_torch(spatial_size, spacing, homogeneous, dtype, device) raise ValueError("backend {} is not supported".format(backend)) diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 3c8d81055c..98950ef8bc 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -6,6 +6,7 @@ from monai.config.type_definitions import DtypeLike, NdarrayOrTensor from monai.utils import optional_import +from monai.utils.module import look_up_option cp, has_cp = optional_import("cupy") cp_ndarray, _ = optional_import("cupy", name="ndarray") @@ -41,33 +42,34 @@ def dtype_torch_to_numpy(dtype): """Convert a torch dtype to its numpy equivalent.""" - if dtype not in _torch_to_np_dtype: - raise ValueError(f"Unsupported torch to numpy dtype '{dtype}'.") - return _torch_to_np_dtype[dtype] + return look_up_option(dtype, _torch_to_np_dtype) def dtype_numpy_to_torch(dtype): """Convert a numpy dtype to its torch equivalent.""" # np dtypes can be given as np.float32 and np.dtype(np.float32) so unify them dtype = np.dtype(dtype) if type(dtype) is type else dtype - if dtype not in _np_to_torch_dtype: - raise ValueError(f"Unsupported numpy to torch dtype '{dtype}'.") - return _np_to_torch_dtype[dtype] + return look_up_option(dtype, _np_to_torch_dtype) def get_equivalent_dtype(dtype, data_type): """Convert to the `dtype` that corresponds to `data_type`. 
- Example: + + Example:: + im = torch.tensor(1) dtype = get_equivalent_dtype(np.float32, type(im)) + """ if dtype is None: return None if data_type is torch.Tensor: if type(dtype) is torch.dtype: + # already a torch dtype and target `data_type` is torch.Tensor return dtype return dtype_numpy_to_torch(dtype) if type(dtype) is not torch.dtype: + # assuming the dtype is ok if it is not a torch dtype return dtype return dtype_torch_to_numpy(dtype) diff --git a/tests/test_get_equivalent_dtype.py b/tests/test_get_equivalent_dtype.py index 04ba5ae5fb..de2379b15b 100644 --- a/tests/test_get_equivalent_dtype.py +++ b/tests/test_get_equivalent_dtype.py @@ -32,6 +32,14 @@ def test_get_equivalent_dtype(self, im, input_dtype): out_dtype = get_equivalent_dtype(input_dtype, type(im)) self.assertEqual(out_dtype, im.dtype) + def test_native_type(self): + """the get_equivalent_dtype currently doesn't change the build-in type""" + n_type = [float, int, bool] + for n in n_type: + for im_dtype in DTYPES: + out_dtype = get_equivalent_dtype(n, type(im_dtype)) + self.assertEqual(out_dtype, n) + if __name__ == "__main__": unittest.main() From 6ca3ae711e54fc95f6c0c0e05effb1cce7622058 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 19 Sep 2021 14:49:07 +0100 Subject: [PATCH 11/20] randaffine with create grid Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 8 ++++++-- monai/utils/type_conversion.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index ca5049191c..8e122a9511 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1563,7 +1563,7 @@ def _init_identity_cache(self): f"'spatial_size={self.spatial_size}', please specify 'spatial_size'." 
) return None - return torch.tensor(create_grid(spatial_size=_sp_size)).to(self.rand_affine_grid.device) + return create_grid(spatial_size=_sp_size, device=self.rand_affine_grid.device, backend="torch") def get_identity_grid(self, spatial_size: Sequence[int]): """ @@ -1577,7 +1577,11 @@ def get_identity_grid(self, spatial_size: Sequence[int]): spatial_size, [2] * ndim ): raise RuntimeError(f"spatial_size should not be dynamic, got {spatial_size}.") - return create_grid(spatial_size=spatial_size) if self._cached_grid is None else self._cached_grid + return ( + create_grid(spatial_size=spatial_size, device=self.rand_affine_grid.device, backend="torch") + if self._cached_grid is None + else self._cached_grid + ) def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 98950ef8bc..5e781dc4ee 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -69,7 +69,7 @@ def get_equivalent_dtype(dtype, data_type): return dtype return dtype_numpy_to_torch(dtype) if type(dtype) is not torch.dtype: - # assuming the dtype is ok if it is not a torch dtype + # assuming the dtype is ok if it is not a torch dtype and target `data_type` is not torch.Tensor return dtype return dtype_torch_to_numpy(dtype) From 125dcecca687275be1735e2ff52d5dad5b4a4e4a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 19 Sep 2021 15:09:50 +0100 Subject: [PATCH 12/20] create_grid backend change for spatial transforms Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 15 +++++++++------ monai/transforms/spatial/dictionary.py | 6 ++++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 8e122a9511..b919eecc57 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1074,7 +1074,7 @@ def __call__( raise ValueError("Incompatible values: grid=None and spatial_size=None.") _b = TransformBackends.TORCH if isinstance(grid, torch.Tensor) else TransformBackends.NUMPY - _device = self.device or (grid.device if isinstance(grid, torch.Tensor) else None) + _device = grid.device if isinstance(grid, torch.Tensor) else self.device affine: NdarrayOrTensor if self.affine is None: spatial_dims = len(grid.shape) - 1 @@ -1310,8 +1310,9 @@ def __call__( """ if grid is None: raise ValueError("Unknown grid.") + _device = img.device if isinstance(img, torch.Tensor) else self.device img_t: torch.Tensor - img_t, *_ = convert_data_type(img, torch.Tensor, device=self.device, dtype=torch.float32) # type: ignore + img_t, *_ = convert_data_type(img, torch.Tensor, device=_device, dtype=torch.float32) # type: ignore grid, *_ = convert_to_dst_type(grid, img_t) if USE_COMPILED: @@ -1715,6 +1716,7 @@ def __init__( ) self.resampler = Resample(device=device) + self.device = device self.spatial_size = spatial_size self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) @@ -1766,7 +1768,8 @@ def __call__( ) grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) else: - grid = create_grid(spatial_size=sp_size) + _device = img.device if isinstance(img, torch.Tensor) else self.device + grid = create_grid(spatial_size=sp_size, device=_device, backend="torch") out: NdarrayOrTensor = self.resampler( img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) @@ -1904,11 +1907,11 
@@ def __call__( """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) self.randomize(grid_size=sp_size) - grid = create_grid(spatial_size=sp_size) + _device = img.device if isinstance(img, torch.Tensor) else self.device + grid = create_grid(spatial_size=sp_size, device=_device, backend="torch") if self._do_transform: if self.rand_offset is None: - raise AssertionError - grid = torch.as_tensor(np.ascontiguousarray(grid), device=self.device) + raise RuntimeError("rand_offset is not initialized.") gaussian = GaussianFilter(3, self.sigma, 3.0).to(device=self.device) offset = torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.magnitude diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 487225cb60..6438ed5ae0 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -968,7 +968,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N ) grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) else: - grid = create_grid(spatial_size=sp_size) + _device = self.rand_2d_elastic.deform_grid.device + grid = create_grid(spatial_size=sp_size, device=_device, backend="torch") for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) @@ -1084,7 +1085,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:]) self.randomize(grid_size=sp_size) - grid = create_grid(spatial_size=sp_size) + _device = self.rand_3d_elastic.device + grid = create_grid(spatial_size=sp_size, device=_device, backend="torch") if self._do_transform: device = self.rand_3d_elastic.device grid = torch.tensor(grid).to(device) From f4b6da4452946f7cb4d41be9addd4faf63834e3c Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 19 Sep 2021 15:28:25 +0100 Subject: [PATCH 13/20] create control grid tensor Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 2 +- monai/transforms/utils.py | 19 +++-- tests/test_create_grid_and_affine.py | 112 +++++++++++++++------------ 3 files changed, 76 insertions(+), 57 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index b919eecc57..16c421d10b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1253,7 +1253,7 @@ def __call__(self, spatial_size: Sequence[int]): spatial_size: spatial size of the grid. 
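`RandDeformGrid` now builds its control grid through the torch backend. For reference, a usage sketch of `create_control_grid` itself with the new arguments (mirrors the updated tests; not part of the patch):

import torch
from monai.transforms import create_control_grid

cg = create_control_grid((2.0, 2.0), (1.0, 1.0), backend="torch")
print(type(cg), cg.shape)  # <class 'torch.Tensor'> torch.Size([3, 4, 4])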
""" self.spacing = fall_back_tuple(self.spacing, (1.0,) * len(spatial_size)) - control_grid = create_control_grid(spatial_size, self.spacing) + control_grid = create_control_grid(spatial_size, self.spacing, device=self.device, backend="torch") self.randomize(control_grid.shape[1:]) control_grid[: len(spatial_size)] += self.rand_mag * self.random_offset if self.as_tensor_output: diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index acdf7fac2c..a627a7544a 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -604,19 +604,28 @@ def _create_grid_torch( def create_control_grid( - spatial_shape: Sequence[int], spacing: Sequence[float], homogeneous: bool = True, dtype: DtypeLike = float + spatial_shape: Sequence[int], + spacing: Sequence[float], + homogeneous: bool = True, + dtype: DtypeLike = float, + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, ): """ control grid with two additional point in each direction """ + torch_backend = look_up_option(backend, TransformBackends) == TransformBackends.TORCH + ceil_func: Callable = torch.ceil if torch_backend else np.ceil # type: ignore grid_shape = [] for d, s in zip(spatial_shape, spacing): - d = int(d) + d = torch.as_tensor(d, device=device) if torch_backend else int(d) # type: ignore if d % 2 == 0: - grid_shape.append(np.ceil((d - 1.0) / (2.0 * s) + 0.5) * 2.0 + 2.0) + grid_shape.append(ceil_func((d - 1.0) / (2.0 * s) + 0.5) * 2.0 + 2.0) else: - grid_shape.append(np.ceil((d - 1.0) / (2.0 * s)) * 2.0 + 3.0) - return create_grid(grid_shape, spacing, homogeneous, dtype) + grid_shape.append(ceil_func((d - 1.0) / (2.0 * s)) * 2.0 + 3.0) + return create_grid( + spatial_size=grid_shape, spacing=spacing, homogeneous=homogeneous, dtype=dtype, device=device, backend=backend + ) def create_rotate( diff --git a/tests/test_create_grid_and_affine.py b/tests/test_create_grid_and_affine.py index 0b1243a07a..b53eaa5b9d 100644 --- a/tests/test_create_grid_and_affine.py +++ b/tests/test_create_grid_and_affine.py @@ -82,67 +82,77 @@ def test_create_control_grid(self): with self.assertRaisesRegex(TypeError, ""): create_control_grid((1, 1), 2.0) - g = create_control_grid((1.0, 1.0), (1.0, 1.0)) - expected = np.array( - [ - [[-1.0, -1.0, -1.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], - [[-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ] + test_assert( + create_control_grid, + ((1.0, 1.0), (1.0, 1.0)), + np.array( + [ + [[-1.0, -1.0, -1.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], + [[-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ), ) - np.testing.assert_allclose(g, expected) - g = create_control_grid((1.0, 1.0), (2.0, 2.0)) - expected = np.array( - [ - [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], - [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ] + test_assert( + create_control_grid, + ((1.0, 1.0), (2.0, 2.0)), + np.array( + [ + [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], + [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ), ) - np.testing.assert_allclose(g, expected) - g = create_control_grid((2.0, 2.0), (1.0, 1.0)) - expected = np.array( - [ - [[-1.5, -1.5, -1.5, -1.5], [-0.5, -0.5, -0.5, -0.5], [0.5, 0.5, 0.5, 0.5], [1.5, 1.5, 1.5, 1.5]], - [[-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 
1.5]], - [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]], - ] + test_assert( + create_control_grid, + ((2.0, 2.0), (1.0, 1.0)), + np.array( + [ + [[-1.5, -1.5, -1.5, -1.5], [-0.5, -0.5, -0.5, -0.5], [0.5, 0.5, 0.5, 0.5], [1.5, 1.5, 1.5, 1.5]], + [[-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5]], + [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]], + ] + ), ) - np.testing.assert_allclose(g, expected) - g = create_control_grid((2.0, 2.0), (2.0, 2.0)) - expected = np.array( - [ - [[-3.0, -3.0, -3.0, -3.0], [-1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0], [3.0, 3.0, 3.0, 3.0]], - [[-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0]], - [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]], - ] + test_assert( + create_control_grid, + ((2.0, 2.0), (2.0, 2.0)), + np.array( + [ + [[-3.0, -3.0, -3.0, -3.0], [-1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0], [3.0, 3.0, 3.0, 3.0]], + [[-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0]], + [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]], + ] + ), ) - np.testing.assert_allclose(g, expected) - g = create_control_grid((1.0, 1.0, 1.0), (2.0, 2.0, 2.0), homogeneous=False) - expected = np.array( - [ - [ - [[-2.0, -2.0, -2.0], [-2.0, -2.0, -2.0], [-2.0, -2.0, -2.0]], - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], - ], - [ - [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], - [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], - [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], - ], + test_assert( + create_control_grid, + ((1.0, 1.0, 1.0), (2.0, 2.0, 2.0), False), + np.array( [ - [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], - [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], - [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], - ], - ] + [ + [[-2.0, -2.0, -2.0], [-2.0, -2.0, -2.0], [-2.0, -2.0, -2.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], + ], + [ + [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], + [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], + [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], + ], + [ + [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], + [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], + [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], + ], + ] + ), ) - np.testing.assert_allclose(g, expected) def test_assert(func, params, expected): From 2f1078f5072a06419bd987ed607f97c05bd0ecc9 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 19 Sep 2021 15:53:04 +0100 Subject: [PATCH 14/20] enhance spatial, and crop xforms - Update test_rand_deform_grid.py - center_scale_crop - center_spatial_crop - rand_scale_crop - rand_spatial_crop - rand_spatial_crop_samples Signed-off-by: Wenqi Li --- monai/transforms/croppad/array.py | 30 +++++------ monai/transforms/croppad/dictionary.py | 28 +++++++---- monai/transforms/spatial/array.py | 9 ++-- monai/transforms/spatial/dictionary.py | 26 +++++----- monai/utils/type_conversion.py | 9 ++++ tests/test_center_scale_crop.py | 2 + tests/test_center_spatial_crop.py | 2 + tests/test_center_spatial_cropd.py | 49 +++++++++++------- tests/test_rand_deform_grid.py | 8 +-- tests/test_rand_scale_crop.py | 24 +++++---- 
tests/test_rand_scale_cropd.py | 13 +++-- tests/test_rand_spatial_crop.py | 10 ++-- tests/test_rand_spatial_crop_samples.py | 16 +++--- tests/test_rand_spatial_crop_samplesd.py | 64 +++++++++++++++--------- tests/test_rand_spatial_cropd.py | 11 ++-- 15 files changed, 184 insertions(+), 117 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index e1c915cc93..cc47972f3c 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -128,10 +128,7 @@ def __call__( # all zeros, skip padding return img mode = convert_pad_mode(dst=img, mode=mode or self.mode).value - if isinstance(img, torch.Tensor): - pad = self._pt_pad - else: - pad = self._np_pad # type: ignore + pad = self._pt_pad if isinstance(img, torch.Tensor) else self._np_pad return pad(img, self.to_pad, mode, **self.kwargs) # type: ignore @@ -449,15 +446,16 @@ class CenterSpatialCrop(Transform): the spatial size of output data will be [32, 40, 40]. """ + backend = SpatialCrop.backend + def __init__(self, roi_size: Union[Sequence[int], int]) -> None: self.roi_size = roi_size - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ - img, *_ = convert_data_type(img, np.ndarray) # type: ignore roi_size = fall_back_tuple(self.roi_size, img.shape[1:]) center = [i // 2 for i in img.shape[1:]] cropper = SpatialCrop(roi_center=center, roi_size=roi_size) @@ -474,11 +472,12 @@ class CenterScaleCrop(Transform): """ + backend = CenterSpatialCrop.backend + def __init__(self, roi_scale: Union[Sequence[float], float]): self.roi_scale = roi_scale - def __call__(self, img: np.ndarray): - img, *_ = convert_data_type(img, np.ndarray) # type: ignore + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: img_size = img.shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] @@ -510,6 +509,8 @@ class RandSpatialCrop(Randomizable, Transform): if True, the actual size is sampled from `randint(roi_size, max_roi_size + 1)`. """ + backend = CenterSpatialCrop.backend + def __init__( self, roi_size: Union[Sequence[int], int], @@ -535,15 +536,14 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, self._size) self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ - img, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize(img.shape[1:]) if self._size is None: - raise AssertionError + raise RuntimeError("self._size not specified.") if self.random_center: return img[self._slices] cropper = CenterSpatialCrop(self._size) @@ -582,12 +582,11 @@ def __init__( self.roi_scale = roi_scale self.max_roi_scale = max_roi_scale - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. 
""" - img, *_ = convert_data_type(img, np.ndarray) # type: ignore img_size = img.shape[1:] ndim = len(img_size) self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] @@ -629,6 +628,8 @@ class RandSpatialCropSamples(Randomizable, Transform): """ + backend = RandScaleCrop.backend + def __init__( self, roi_size: Union[Sequence[int], int], @@ -652,12 +653,11 @@ def set_random_state( def randomize(self, data: Optional[Any] = None) -> None: pass - def __call__(self, img: np.ndarray) -> List[np.ndarray]: + def __call__(self, img: NdarrayOrTensor) -> List[NdarrayOrTensor]: """ Apply the transform to `img`, assuming `img` is channel-first and cropping doesn't change the channel dim. """ - img, *_ = convert_data_type(img, np.ndarray) # type: ignore return [self.cropper(img) for _ in range(self.num_samples)] diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 488b832450..f06be52e85 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -416,13 +416,15 @@ class CenterSpatialCropd(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = CenterSpatialCrop.backend + def __init__( self, keys: KeysCollection, roi_size: Union[Sequence[int], int], allow_missing_keys: bool = False ) -> None: super().__init__(keys, allow_missing_keys) self.cropper = CenterSpatialCrop(roi_size) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): orig_size = d[key].shape[1:] @@ -466,13 +468,15 @@ class CenterScaleCropd(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = CenterSpatialCrop.backend + def __init__( self, keys: KeysCollection, roi_scale: Union[Sequence[float], float], allow_missing_keys: bool = False ) -> None: super().__init__(keys, allow_missing_keys=allow_missing_keys) self.roi_scale = roi_scale - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) # use the spatial size of first image to scale, expect all images have the same spatial size img_size = data[self.keys[0]].shape[1:] @@ -537,6 +541,8 @@ class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. 
""" + backend = CenterSpatialCrop.backend + def __init__( self, keys: KeysCollection, @@ -565,11 +571,11 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, self._size) self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize(d[self.keys[0]].shape[1:]) # image shape from the first data key if self._size is None: - raise AssertionError + raise RuntimeError("self._size not specified.") for key in self.key_iterator(d): if self.random_center: self.push_transform(d, key, {"slices": [(i.start, i.stop) for i in self._slices[1:]]}) # type: ignore @@ -638,6 +644,8 @@ class RandScaleCropd(RandSpatialCropd): allow_missing_keys: don't raise exception if key is missing. """ + backend = RandSpatialCropd.backend + def __init__( self, keys: KeysCollection, @@ -659,7 +667,7 @@ def __init__( self.roi_scale = roi_scale self.max_roi_scale = max_roi_scale - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: img_size = data[self.keys[0]].shape[1:] ndim = len(img_size) self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] @@ -723,6 +731,8 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform, InvertibleTransform): """ + backend = RandSpatialCropd.backend + def __init__( self, keys: KeysCollection, @@ -755,7 +765,7 @@ def set_random_state( def randomize(self, data: Optional[Any] = None) -> None: pass - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: ret = [] for i in range(self.num_samples): d = dict(data) @@ -765,14 +775,14 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n cropped = self.cropper(d) # self.cropper will have added RandSpatialCropd to the list. 
Change to RandSpatialCropSamplesd for key in self.key_iterator(cropped): - cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.__class__.__name__ - cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self) + cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.__class__.__name__ # type: ignore + cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self) # type: ignore # add `patch_index` to the meta data for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): meta_key = meta_key or f"{key}_{meta_key_postfix}" if meta_key not in cropped: cropped[meta_key] = {} # type: ignore - cropped[meta_key][Key.PATCH_INDEX] = i + cropped[meta_key][Key.PATCH_INDEX] = i # type: ignore ret.append(cropped) return ret diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 16c421d10b..e88dd394c4 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1255,9 +1255,10 @@ def __call__(self, spatial_size: Sequence[int]): self.spacing = fall_back_tuple(self.spacing, (1.0,) * len(spatial_size)) control_grid = create_control_grid(spatial_size, self.spacing, device=self.device, backend="torch") self.randomize(control_grid.shape[1:]) - control_grid[: len(spatial_size)] += self.rand_mag * self.random_offset - if self.as_tensor_output: - control_grid = torch.as_tensor(np.ascontiguousarray(control_grid), device=self.device) + _offset, *_ = convert_to_dst_type(self.rand_mag * self.random_offset, control_grid) + control_grid[: len(spatial_size)] += _offset + if not self.as_tensor_output: + control_grid, *_ = convert_data_type(control_grid, output_type=np.ndarray, dtype=np.float32) return control_grid @@ -1761,7 +1762,7 @@ def __call__( grid = self.rand_affine_grid(grid=grid) grid = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, - input=torch.as_tensor(grid).unsqueeze(0), + input=grid.unsqueeze(0), scale_factor=list(ensure_tuple(self.deform_grid.spacing)), mode=InterpolateMode.BICUBIC.value, align_corners=False, diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 6438ed5ae0..e4252f245c 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -827,20 +827,19 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd transform = self.get_most_recent_transform(d, key) # if transform was not performed and spatial size is None, nothing to do. 
if not transform[InverseKeys.DO_TRANSFORM] and self.rand_affine.spatial_size is None: - out: NdarrayOrTensor = d[key] - else: - orig_size = transform[InverseKeys.ORIG_SIZE] - # Create inverse transform - fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - inv_affine = np.linalg.inv(fwd_affine) + continue + orig_size = transform[InverseKeys.ORIG_SIZE] + # Create inverse transform + fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + inv_affine = np.linalg.inv(fwd_affine) - affine_grid = AffineGrid(affine=inv_affine) - grid, _ = affine_grid(orig_size) # type: ignore + affine_grid = AffineGrid(affine=inv_affine) + grid, _ = affine_grid(orig_size) # type: ignore - # Apply inverse transform - d[key] = self.rand_affine.resampler(d[key], grid, mode, padding_mode) + # Apply inverse transform + d[key] = self.rand_affine.resampler(d[key], grid, mode, padding_mode) # Remove the applied transform self.pop_transform(d, key) @@ -1089,9 +1088,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N grid = create_grid(spatial_size=sp_size, device=_device, backend="torch") if self._do_transform: device = self.rand_3d_elastic.device - grid = torch.tensor(grid).to(device) gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device) - offset = torch.tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0) + offset = torch.as_tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude grid = self.rand_3d_elastic.rand_affine_grid(grid=grid) diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 5e781dc4ee..87095fef99 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -221,6 +221,15 @@ def convert_data_type( If left blank, it remains unchanged. 
Returns: modified data, orig_type, orig_device + + Note: + When both `output_type` and `dtype` are specified with different backend + (e.g., `torch.Tensor` and `np.float32`), the `output_type` will be used as the primary type, + for example:: + + >>> convert_data_type(1, torch.Tensor, dtype=np.float32) + (1.0, , None) + """ orig_type: Any if isinstance(data, torch.Tensor): diff --git a/tests/test_center_scale_crop.py b/tests/test_center_scale_crop.py index e28849ce90..4c5bfc4fac 100644 --- a/tests/test_center_scale_crop.py +++ b/tests/test_center_scale_crop.py @@ -38,11 +38,13 @@ class TestCenterScaleCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) def test_shape(self, input_param, input_data, expected_shape): result = CenterScaleCrop(**input_param)(input_data) + self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) np.testing.assert_allclose(result.shape, expected_shape) @parameterized.expand([TEST_CASE_2]) def test_value(self, input_param, input_data, expected_value): result = CenterScaleCrop(**input_param)(input_data) + self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) np.testing.assert_allclose(result, expected_value) diff --git a/tests/test_center_spatial_crop.py b/tests/test_center_spatial_crop.py index 3e828176a5..d6a7edb305 100644 --- a/tests/test_center_spatial_crop.py +++ b/tests/test_center_spatial_crop.py @@ -38,11 +38,13 @@ class TestCenterSpatialCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) def test_shape(self, input_param, input_data, expected_shape): result = CenterSpatialCrop(**input_param)(input_data) + self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) np.testing.assert_allclose(result.shape, expected_shape) @parameterized.expand([TEST_CASE_2]) def test_value(self, input_param, input_data, expected_value): result = CenterSpatialCrop(**input_param)(input_data) + self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) np.testing.assert_allclose(result, expected_value) diff --git a/tests/test_center_spatial_cropd.py b/tests/test_center_spatial_cropd.py index 349253ab56..ed33e5bf88 100644 --- a/tests/test_center_spatial_cropd.py +++ b/tests/test_center_spatial_cropd.py @@ -15,33 +15,48 @@ from parameterized import parameterized from monai.transforms import CenterSpatialCropd +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - {"keys": "img", "roi_size": [2, -1, -1]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 3, 3), -] +TEST_SHAPES = [] +for p in TEST_NDARRAYS: + TEST_SHAPES.append( + [ + {"keys": "img", "roi_size": [2, -1, -1]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 2, 3, 3), + ] + ) -TEST_CASE_1 = [ - {"keys": "img", "roi_size": [2, 2, 2]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 2, 2), -] + TEST_SHAPES.append( + [ + {"keys": "img", "roi_size": [2, 2, 2]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 2, 2, 2), + ] + ) -TEST_CASE_2 = [ - {"keys": "img", "roi_size": [2, 2]}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[1, 2], [2, 3]]]), -] +TEST_CASES = [] +for p in TEST_NDARRAYS: + TEST_CASES.append( + [ + {"keys": "img", "roi_size": [2, 2]}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ) + }, + 
p(np.array([[[1, 2], [2, 3]]])), + ] + ) class TestCenterSpatialCropd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + @parameterized.expand(TEST_SHAPES) def test_shape(self, input_param, input_data, expected_shape): result = CenterSpatialCropd(**input_param)(input_data) self.assertTupleEqual(result["img"].shape, expected_shape) - @parameterized.expand([TEST_CASE_2]) + @parameterized.expand(TEST_CASES) def test_value(self, input_param, input_data, expected_value): result = CenterSpatialCropd(**input_param)(input_data) np.testing.assert_allclose(result["img"], expected_value) diff --git a/tests/test_rand_deform_grid.py b/tests/test_rand_deform_grid.py index 7c12c263d2..4725e28339 100644 --- a/tests/test_rand_deform_grid.py +++ b/tests/test_rand_deform_grid.py @@ -12,10 +12,10 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import RandDeformGrid +from tests.utils import assert_allclose TEST_CASES = [ [ @@ -129,11 +129,7 @@ def test_rand_deform_grid(self, input_param, input_data, expected_val): g = RandDeformGrid(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, type_test=False, rtol=1e-3, atol=1e-3) if __name__ == "__main__": diff --git a/tests/test_rand_scale_crop.py b/tests/test_rand_scale_crop.py index db5487ebff..a0c5471ffb 100644 --- a/tests/test_rand_scale_crop.py +++ b/tests/test_rand_scale_crop.py @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import RandScaleCrop +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ {"roi_scale": [1.0, 1.0, -1.0], "random_center": True}, @@ -55,22 +56,25 @@ class TestRandScaleCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape): - result = RandScaleCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) + for p in TEST_NDARRAYS: + result = RandScaleCrop(**input_param)(p(input_data)) + self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - cropper = RandScaleCrop(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + for p in TEST_NDARRAYS: + cropper = RandScaleCrop(**input_param) + result = cropper(p(input_data)) + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandScaleCrop(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - self.assertTupleEqual(result.shape, expected_shape) + for p in TEST_NDARRAYS: + cropper = RandScaleCrop(**input_param) + cropper.set_random_state(seed=123) + result = cropper(p(input_data)) + self.assertTupleEqual(result.shape, expected_shape) if __name__ == "__main__": diff --git 
a/tests/test_rand_scale_cropd.py b/tests/test_rand_scale_cropd.py index 265c6c467d..f78a81d339 100644 --- a/tests/test_rand_scale_cropd.py +++ b/tests/test_rand_scale_cropd.py @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import RandScaleCropd +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ {"keys": "img", "roi_scale": [1.0, 1.0, -1.0], "random_center": True}, @@ -66,10 +67,14 @@ def test_shape(self, input_param, input_data, expected_shape): @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - cropper = RandScaleCropd(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + for p in TEST_NDARRAYS: + cropper = RandScaleCropd(**input_param) + input_data["img"] = p(input_data["img"]) + result = cropper(input_data) + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + assert_allclose( + result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False + ) @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_random_shape(self, input_param, input_data, expected_shape): diff --git a/tests/test_rand_spatial_crop.py b/tests/test_rand_spatial_crop.py index 01e057e589..19b1841c6d 100644 --- a/tests/test_rand_spatial_crop.py +++ b/tests/test_rand_spatial_crop.py @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import RandSpatialCrop +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_0 = [ {"roi_size": [3, 3, -1], "random_center": True}, @@ -56,10 +57,11 @@ def test_shape(self, input_param, input_data, expected_shape): @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - cropper = RandSpatialCrop(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + for p in TEST_NDARRAYS: + cropper = RandSpatialCrop(**input_param) + result = cropper(p(input_data)) + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) def test_random_shape(self, input_param, input_data, expected_shape): diff --git a/tests/test_rand_spatial_crop_samples.py b/tests/test_rand_spatial_crop_samples.py index 0ade9bbbba..eefe7d0e0a 100644 --- a/tests/test_rand_spatial_crop_samples.py +++ b/tests/test_rand_spatial_crop_samples.py @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import RandSpatialCropSamples +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ {"roi_size": [3, 3, 3], "num_samples": 4, "random_center": True, "random_size": False}, @@ -70,14 +71,15 @@ class TestRandSpatialCropSamples(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape, expected_last_item): - xform = RandSpatialCropSamples(**input_param) - xform.set_random_state(1234) - result = xform(input_data) + for p in TEST_NDARRAYS: + xform = RandSpatialCropSamples(**input_param) + xform.set_random_state(1234) + result = xform(p(input_data)) - np.testing.assert_equal(len(result), input_param["num_samples"]) - for item, expected in zip(result, 
expected_shape): - self.assertTupleEqual(item.shape, expected) - np.testing.assert_allclose(result[-1], expected_last_item) + np.testing.assert_equal(len(result), input_param["num_samples"]) + for item, expected in zip(result, expected_shape): + self.assertTupleEqual(item.shape, expected) + assert_allclose(result[-1], expected_last_item, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_spatial_crop_samplesd.py b/tests/test_rand_spatial_crop_samplesd.py index 3f5eee7b27..4b41ce3344 100644 --- a/tests/test_rand_spatial_crop_samplesd.py +++ b/tests/test_rand_spatial_crop_samplesd.py @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import Compose, RandSpatialCropSamplesd, ToTensord +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ {"keys": ["img", "seg"], "num_samples": 4, "roi_size": [2, 2, 2], "random_center": True}, @@ -38,31 +39,48 @@ }, ] -TEST_CASE_2 = [ - {"keys": ["img", "seg"], "num_samples": 8, "roi_size": [2, 2, 3], "random_center": False}, - {"img": np.arange(81).reshape(3, 3, 3, 3), "seg": np.arange(81, 0, -1).reshape(3, 3, 3, 3)}, - [(3, 3, 3, 3), (3, 2, 3, 3), (3, 2, 2, 3), (3, 2, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 2, 2, 3), (3, 3, 2, 3)], - { - "img": np.array( +TEST_CASE_2 = [] +for p in TEST_NDARRAYS: + TEST_CASE_2.append( + [ + {"keys": ["img", "seg"], "num_samples": 8, "roi_size": [2, 2, 3], "random_center": False}, + {"img": p(np.arange(81).reshape(3, 3, 3, 3)), "seg": p(np.arange(81, 0, -1).reshape(3, 3, 3, 3))}, [ - [[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]], - [[[27, 28, 29], [30, 31, 32]], [[36, 37, 38], [39, 40, 41]], [[45, 46, 47], [48, 49, 50]]], - [[[54, 55, 56], [57, 58, 59]], [[63, 64, 65], [66, 67, 68]], [[72, 73, 74], [75, 76, 77]]], - ] - ), - "seg": np.array( - [ - [[[81, 80, 79], [78, 77, 76]], [[72, 71, 70], [69, 68, 67]], [[63, 62, 61], [60, 59, 58]]], - [[[54, 53, 52], [51, 50, 49]], [[45, 44, 43], [42, 41, 40]], [[36, 35, 34], [33, 32, 31]]], - [[[27, 26, 25], [24, 23, 22]], [[18, 17, 16], [15, 14, 13]], [[9, 8, 7], [6, 5, 4]]], - ] - ), - }, -] + (3, 3, 3, 3), + (3, 2, 3, 3), + (3, 2, 2, 3), + (3, 2, 3, 3), + (3, 3, 3, 3), + (3, 3, 3, 3), + (3, 2, 2, 3), + (3, 3, 2, 3), + ], + { + "img": p( + np.array( + [ + [[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]], + [[[27, 28, 29], [30, 31, 32]], [[36, 37, 38], [39, 40, 41]], [[45, 46, 47], [48, 49, 50]]], + [[[54, 55, 56], [57, 58, 59]], [[63, 64, 65], [66, 67, 68]], [[72, 73, 74], [75, 76, 77]]], + ] + ) + ), + "seg": p( + np.array( + [ + [[[81, 80, 79], [78, 77, 76]], [[72, 71, 70], [69, 68, 67]], [[63, 62, 61], [60, 59, 58]]], + [[[54, 53, 52], [51, 50, 49]], [[45, 44, 43], [42, 41, 40]], [[36, 35, 34], [33, 32, 31]]], + [[[27, 26, 25], [24, 23, 22]], [[18, 17, 16], [15, 14, 13]], [[9, 8, 7], [6, 5, 4]]], + ] + ) + ), + }, + ] + ) class TestRandSpatialCropSamplesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, *TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape, expected_last): xform = RandSpatialCropSamplesd(**input_param) xform.set_random_state(1234) @@ -73,8 +91,8 @@ def test_shape(self, input_param, input_data, expected_shape, expected_last): for i, item in enumerate(result): self.assertEqual(item["img_meta_dict"]["patch_index"], i) self.assertEqual(item["seg_meta_dict"]["patch_index"], i) - np.testing.assert_allclose(item["img"], expected_last["img"]) - 
np.testing.assert_allclose(item["seg"], expected_last["seg"]) + assert_allclose(item["img"], expected_last["img"], type_test=True) + assert_allclose(item["seg"], expected_last["seg"], type_test=True) def test_deep_copy(self): data = {"img": np.ones((1, 10, 11, 12))} diff --git a/tests/test_rand_spatial_cropd.py b/tests/test_rand_spatial_cropd.py index 610c1974aa..edcb61dc99 100644 --- a/tests/test_rand_spatial_cropd.py +++ b/tests/test_rand_spatial_cropd.py @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import RandSpatialCropd +from tests.utils import TEST_NDARRAYS TEST_CASE_0 = [ {"keys": "img", "roi_size": [3, 3, -1], "random_center": True}, @@ -67,10 +68,12 @@ def test_value(self, input_param, input_data): @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandSpatialCropd(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + for p in TEST_NDARRAYS: + cropper = RandSpatialCropd(**input_param) + cropper.set_random_state(seed=123) + input_data["img"] = p(input_data["img"]) + result = cropper(input_data) + self.assertTupleEqual(result["img"].shape, expected_shape) if __name__ == "__main__": From e72f5d2fb91d1f080a9965dcc2ca022babfb8400 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 20 Sep 2021 00:21:27 +0100 Subject: [PATCH 15/20] fixes invert Signed-off-by: Wenqi Li --- monai/transforms/spatial/dictionary.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index e4252f245c..195a68605c 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -826,20 +826,19 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # if transform was not performed and spatial size is None, nothing to do. 
- if not transform[InverseKeys.DO_TRANSFORM] and self.rand_affine.spatial_size is None: - continue - orig_size = transform[InverseKeys.ORIG_SIZE] - # Create inverse transform - fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - inv_affine = np.linalg.inv(fwd_affine) + if transform[InverseKeys.DO_TRANSFORM] or self.rand_affine.spatial_size is not None: + orig_size = transform[InverseKeys.ORIG_SIZE] + # Create inverse transform + fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + inv_affine = np.linalg.inv(fwd_affine) - affine_grid = AffineGrid(affine=inv_affine) - grid, _ = affine_grid(orig_size) # type: ignore + affine_grid = AffineGrid(affine=inv_affine) + grid, _ = affine_grid(orig_size) # type: ignore - # Apply inverse transform - d[key] = self.rand_affine.resampler(d[key], grid, mode, padding_mode) + # Apply inverse transform + d[key] = self.rand_affine.resampler(d[key], grid, mode, padding_mode) # Remove the applied transform self.pop_transform(d, key) From d9b6b6b327071b81f7cd6e4768257475dea154e9 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 20 Sep 2021 00:54:18 +0100 Subject: [PATCH 16/20] tensor resize, unit test fixes Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 32 +++++++++++++++----------- monai/transforms/spatial/dictionary.py | 6 +++-- tests/test_resize.py | 24 +++++++++++-------- tests/test_resized.py | 24 +++++++++++-------- 4 files changed, 52 insertions(+), 34 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index e88dd394c4..6f662f2dce 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -362,6 +362,8 @@ class Resize(Transform): See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate """ + backend = [TransformBackends.TORCH] + def __init__( self, spatial_size: Union[Sequence[int], int], @@ -376,10 +378,10 @@ def __init__( def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[InterpolateMode, str]] = None, align_corners: Optional[bool] = None, - ) -> np.ndarray: + ) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]). @@ -394,33 +396,33 @@ def __call__( ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. """ - img, *_ = convert_data_type(img, np.ndarray) # type: ignore + img_, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) # type: ignore if self.size_mode == "all": - input_ndim = img.ndim - 1 # spatial ndim + input_ndim = img_.ndim - 1 # spatial ndim output_ndim = len(ensure_tuple(self.spatial_size)) if output_ndim > input_ndim: - input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) - img = img.reshape(input_shape) + input_shape = ensure_tuple_size(img_.shape, output_ndim + 1, 1) + img_ = img_.reshape(input_shape) elif output_ndim < input_ndim: raise ValueError( "len(spatial_size) must be greater or equal to img spatial dimensions, " f"got spatial_size={output_ndim} img={input_ndim}." 
) - spatial_size_ = fall_back_tuple(self.spatial_size, img.shape[1:]) + spatial_size_ = fall_back_tuple(self.spatial_size, img_.shape[1:]) else: # for the "longest" mode - img_size = img.shape[1:] + img_size = img_.shape[1:] if not isinstance(self.spatial_size, int): raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") scale = self.spatial_size / max(img_size) spatial_size_ = tuple(int(round(s * scale)) for s in img_size) resized = torch.nn.functional.interpolate( # type: ignore - input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0), + input=img_.unsqueeze(0), # type: ignore size=spatial_size_, mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, align_corners=self.align_corners if align_corners is None else align_corners, ) - resized = resized.squeeze(0).detach().cpu().numpy() - return np.asarray(resized) + out, *_ = convert_to_dst_type(resized.squeeze(0), img) + return out class Rotate(Transform, ThreadUnsafe): @@ -1094,7 +1096,7 @@ def __call__( else: affine = self.affine - grid, *_ = convert_data_type(grid, torch.Tensor, device=self.device, dtype=float) + grid, *_ = convert_data_type(grid, torch.Tensor, device=_device, dtype=float) affine, *_ = convert_to_dst_type(affine, grid) grid = (affine @ grid.reshape((grid.shape[0], -1))).reshape([-1] + list(grid.shape[1:])) @@ -1216,6 +1218,8 @@ class RandDeformGrid(Randomizable, Transform): Generate random deformation grid. """ + backend = [TransformBackends.TORCH] + def __init__( self, spacing: Union[Sequence[float], float], @@ -1913,8 +1917,8 @@ def __call__( if self._do_transform: if self.rand_offset is None: raise RuntimeError("rand_offset is not initialized.") - gaussian = GaussianFilter(3, self.sigma, 3.0).to(device=self.device) - offset = torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0) + gaussian = GaussianFilter(3, self.sigma, 3.0).to(device=_device) + offset = torch.as_tensor(self.rand_offset, device=_device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.magnitude grid = self.rand_affine_grid(grid=grid) out: NdarrayOrTensor = self.resampler( diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 195a68605c..f36300dea6 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -521,6 +521,8 @@ class Resized(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. 
""" + backend = Resize.backend + def __init__( self, keys: KeysCollection, @@ -535,7 +537,7 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): self.push_transform( @@ -549,7 +551,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) diff --git a/tests/test_resize.py b/tests/test_resize.py index e5ec5dd1a9..f6c4a8b14b 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import Resize -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose TEST_CASE_0 = [{"spatial_size": 15}, (6, 10, 15)] @@ -45,16 +45,22 @@ def test_correct_results(self, spatial_size, mode): _order = 1 if spatial_size == (32, -1): spatial_size = (32, 64) - expected = [] - for channel in self.imt[0]: - expected.append( - skimage.transform.resize( - channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=False - ) + expected = [ + skimage.transform.resize( + channel, + spatial_size, + order=_order, + clip=False, + preserve_range=False, + anti_aliasing=False, ) + for channel in self.imt[0] + ] + expected = np.stack(expected).astype(np.float32) - out = resize(self.imt[0]) - np.testing.assert_allclose(out, expected, atol=0.9) + for p in TEST_NDARRAYS: + out = resize(p(self.imt[0])) + assert_allclose(out, expected, type_test=False, atol=0.9) @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) def test_longest_shape(self, input_param, expected_shape): diff --git a/tests/test_resized.py b/tests/test_resized.py index 930faf00eb..47b8e8a704 100644 --- a/tests/test_resized.py +++ b/tests/test_resized.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import Resized -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 10, 15)] @@ -48,16 +48,22 @@ def test_correct_results(self, spatial_size, mode): _order = 1 if spatial_size == (32, -1): spatial_size = (32, 64) - expected = [] - for channel in self.imt[0]: - expected.append( - skimage.transform.resize( - channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=False - ) + expected = [ + skimage.transform.resize( + channel, + spatial_size, + order=_order, + clip=False, + preserve_range=False, + anti_aliasing=False, ) + for channel in self.imt[0] + ] + expected = np.stack(expected).astype(np.float32) - out = resize({"img": self.imt[0]})["img"] - np.testing.assert_allclose(out, expected, atol=0.9) + for p in TEST_NDARRAYS: + out = resize({"img": p(self.imt[0])})["img"] + assert_allclose(out, expected, type_test=False, atol=0.9) 
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_longest_shape(self, input_param, expected_shape): From b2773cae249e8bb0a34e14e2056314f0105a0974 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 20 Sep 2021 00:57:46 +0100 Subject: [PATCH 17/20] unit test fix Signed-off-by: Wenqi Li --- tests/test_center_spatial_cropd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_center_spatial_cropd.py b/tests/test_center_spatial_cropd.py index ed33e5bf88..8ffcdf4387 100644 --- a/tests/test_center_spatial_cropd.py +++ b/tests/test_center_spatial_cropd.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import CenterSpatialCropd -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_SHAPES = [] for p in TEST_NDARRAYS: @@ -59,7 +59,7 @@ def test_shape(self, input_param, input_data, expected_shape): @parameterized.expand(TEST_CASES) def test_value(self, input_param, input_data, expected_value): result = CenterSpatialCropd(**input_param)(input_data) - np.testing.assert_allclose(result["img"], expected_value) + assert_allclose(result["img"], expected_value, type_test=False) if __name__ == "__main__": From e7d6281674c9a3c296ffd6cfc2d551a7e379dd99 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 20 Sep 2021 01:04:30 +0100 Subject: [PATCH 18/20] fixes codeformat Signed-off-by: Wenqi Li --- monai/data/png_writer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index 2baec3b872..52163e40ac 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -48,7 +48,7 @@ def write_png( """ if not isinstance(data, np.ndarray): - raise AssertionError("input data must be numpy array.") + raise ValueError("input data must be numpy array.") if len(data.shape) == 3 and data.shape[2] == 1: # PIL Image can't save image with 1 channel data = data.squeeze(2) if output_spatial_shape is not None: @@ -59,11 +59,11 @@ def write_png( _min, _max = np.min(data), np.max(data) if len(data.shape) == 3: data = np.moveaxis(data, -1, 0) # to channel first - data = xform(data) + data = xform(data) # type: ignore data = np.moveaxis(data, 0, -1) else: # (H, W) data = np.expand_dims(data, 0) # make a channel - data = xform(data)[0] # first channel + data = xform(data)[0] # type: ignore if mode != InterpolateMode.NEAREST: data = np.clip(data, _min, _max) # type: ignore From 218058be4aa28ee33e25219f0a26ccd1a707ce27 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 20 Sep 2021 07:54:45 +0100 Subject: [PATCH 19/20] resize_with_pad_or_crop Signed-off-by: Wenqi Li --- monai/transforms/croppad/array.py | 5 +++-- monai/transforms/croppad/dictionary.py | 2 ++ tests/test_resize_with_pad_or_crop.py | 17 ++++++++++++----- tests/test_resize_with_pad_or_cropd.py | 14 +++++++++++--- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index cc47972f3c..276ba6104d 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -1128,6 +1128,8 @@ class ResizeWithPadOrCrop(Transform): """ + backend = list(set(SpatialPad.backend) & set(CenterSpatialCrop.backend)) + def __init__( self, spatial_size: Union[Sequence[int], int], @@ -1138,7 +1140,7 @@ def __init__( self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **np_kwargs) self.cropper = CenterSpatialCrop(roi_size=spatial_size) - def __call__(self, 
img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayOrTensor: """ Args: img: data to pad or crop, assuming `img` is channel-first and @@ -1149,7 +1151,6 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N If None, defaults to the ``mode`` in construction. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ - img, *_ = convert_data_type(img, np.ndarray) # type: ignore return self.padder(self.cropper(img), mode=mode) # type: ignore diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index f06be52e85..28df32eb4c 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -1387,6 +1387,8 @@ class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): """ + backend = ResizeWithPadOrCrop.backend + def __init__( self, keys: KeysCollection, diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py index 46f1fc86cc..2162a0bb1b 100644 --- a/tests/test_resize_with_pad_or_crop.py +++ b/tests/test_resize_with_pad_or_crop.py @@ -12,9 +12,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ResizeWithPadOrCrop +from tests.utils import TEST_NDARRAYS TEST_CASES = [ [ @@ -48,11 +50,16 @@ class TestResizeWithPadOrCrop(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_pad_shape(self, input_param, input_shape, expected_shape): - paddcroper = ResizeWithPadOrCrop(**input_param) - result = paddcroper(np.zeros(input_shape)) - np.testing.assert_allclose(result.shape, expected_shape) - result = paddcroper(np.zeros(input_shape), mode="constant") - np.testing.assert_allclose(result.shape, expected_shape) + for p in TEST_NDARRAYS: + if isinstance(p(0), torch.Tensor) and ( + "constant_values" in input_param or input_param["mode"] == "reflect" + ): + continue + paddcroper = ResizeWithPadOrCrop(**input_param) + result = paddcroper(p(np.zeros(input_shape))) + np.testing.assert_allclose(result.shape, expected_shape) + result = paddcroper(p(np.zeros(input_shape)), mode="constant") + np.testing.assert_allclose(result.shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py index 32a62a9e16..58f6c92a8f 100644 --- a/tests/test_resize_with_pad_or_cropd.py +++ b/tests/test_resize_with_pad_or_cropd.py @@ -12,9 +12,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ResizeWithPadOrCropd +from tests.utils import TEST_NDARRAYS TEST_CASES = [ [ @@ -48,9 +50,15 @@ class TestResizeWithPadOrCropd(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_pad_shape(self, input_param, input_data, expected_val): - paddcroper = ResizeWithPadOrCropd(**input_param) - result = paddcroper(input_data) - np.testing.assert_allclose(result["img"].shape, expected_val) + for p in TEST_NDARRAYS: + if isinstance(p(0), torch.Tensor) and ( + "constant_values" in input_param or input_param["mode"] == "reflect" + ): + continue + paddcroper = ResizeWithPadOrCropd(**input_param) + input_data["img"] = p(input_data["img"]) + result = paddcroper(input_data) + np.testing.assert_allclose(result["img"].shape, expected_val) if __name__ == "__main__": From 54c5f0b05ec0c939863addb1f6abf6c5c5404d62 Mon Sep 17 00:00:00 2001 From: 
Wenqi Li Date: Mon, 20 Sep 2021 08:21:44 +0100 Subject: [PATCH 20/20] fixes mypy issue Signed-off-by: Wenqi Li --- monai/transforms/croppad/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 28df32eb4c..2590bf2e77 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -1402,7 +1402,7 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **np_kwargs) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, m in self.key_iterator(d, self.mode): orig_size = d[key].shape[1:]
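
Taken together, these patches let the crop/pad and resize transforms operate on either NumPy arrays or PyTorch tensors and hand back the same container type they were given. A minimal usage sketch of that behaviour, assuming a MONAI build with the changes above applied; the image shapes, the ``bilinear`` mode, and the assertions are illustrative choices, not part of the patch series:

import numpy as np
import torch

from monai.transforms import CenterSpatialCrop, Resize

img_np = np.random.rand(1, 32, 32).astype(np.float32)  # channel-first 2D image
img_pt = torch.as_tensor(img_np)

# CenterSpatialCrop now declares `backend = SpatialCrop.backend` and no longer
# forces its input to np.ndarray, so slicing preserves the input container type.
cropper = CenterSpatialCrop(roi_size=(16, 16))
assert isinstance(cropper(img_np), np.ndarray)    # numpy in -> numpy out
assert isinstance(cropper(img_pt), torch.Tensor)  # tensor in -> tensor out

# Resize goes through torch.nn.functional.interpolate internally and converts the
# result back to the input type via convert_to_dst_type.
resizer = Resize(spatial_size=(24, 24), mode="bilinear", align_corners=False)
out = resizer(img_pt)
assert isinstance(out, torch.Tensor) and tuple(out.shape) == (1, 24, 24)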