From 373ed816d401d00e0193046ed080db992549e28d Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Fri, 26 Mar 2021 07:46:55 +0000
Subject: [PATCH 1/2] fixes #1857, SpatialCrop is compatible with tensors

Signed-off-by: Wenqi Li
---
 monai/transforms/croppad/array.py | 4 ++--
 monai/transforms/spatial/array.py | 2 +-
 tests/test_center_spatial_crop.py | 9 ++++++++-
 tests/test_rand_elastic_2d.py     | 2 +-
 tests/test_rand_elastic_3d.py     | 2 +-
 tests/test_spatial_crop.py        | 7 +++++++
 6 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py
index 6174378e3b..159fa1a5f4 100644
--- a/monai/transforms/croppad/array.py
+++ b/monai/transforms/croppad/array.py
@@ -246,14 +246,14 @@ def __init__(
         self.roi_start = self.roi_start if isinstance(self.roi_start, np.ndarray) else np.array([self.roi_start])
         self.roi_end = self.roi_end if isinstance(self.roi_end, np.ndarray) else np.array([self.roi_end])
 
-    def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
+    def __call__(self, img: Union[np.ndarray, torch.Tensor]):
         """
         Apply the transform to `img`, assuming `img` is channel-first and
         slicing doesn't apply to the channel dim.
         """
         sd = min(self.roi_start.size, self.roi_end.size, len(img.shape[1:]))  # spatial dims
         slices = [slice(None)] + [slice(s, e) for s, e in zip(self.roi_start[:sd], self.roi_end[:sd])]
-        return np.asarray(img[tuple(slices)])
+        return img[tuple(slices)]
 
 
 class CenterSpatialCrop(Transform):
diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index 471b171312..1c096ba743 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -1563,7 +1563,7 @@ def __call__(
                 mode=InterpolateMode.BICUBIC.value,
                 align_corners=False,
             )
-            grid = CenterSpatialCrop(roi_size=sp_size)(np.asarray(grid[0]))
+            grid = CenterSpatialCrop(roi_size=sp_size)(grid[0])
         else:
             grid = create_grid(spatial_size=sp_size)
         return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode)
diff --git a/tests/test_center_spatial_crop.py b/tests/test_center_spatial_crop.py
index c03ec24e18..3e828176a5 100644
--- a/tests/test_center_spatial_crop.py
+++ b/tests/test_center_spatial_crop.py
@@ -12,6 +12,7 @@
 import unittest
 
 import numpy as np
+import torch
 from parameterized import parameterized
 
 from monai.transforms import CenterSpatialCrop
@@ -26,9 +27,15 @@
     np.array([[[1, 2], [2, 3]]]),
 ]
 
+TEST_CASE_3 = [
+    {"roi_size": [2, 2, 2]},
+    torch.randint(0, 2, size=[3, 3, 3, 3], device="cuda" if torch.cuda.is_available() else "cpu"),
+    (3, 2, 2, 2),
+]
+
 
 class TestCenterSpatialCrop(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_0, TEST_CASE_1])
+    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3])
     def test_shape(self, input_param, input_data, expected_shape):
         result = CenterSpatialCrop(**input_param)(input_data)
         np.testing.assert_allclose(result.shape, expected_shape)
diff --git a/tests/test_rand_elastic_2d.py b/tests/test_rand_elastic_2d.py
index aa408f0fdc..fbfb7d5761 100644
--- a/tests/test_rand_elastic_2d.py
+++ b/tests/test_rand_elastic_2d.py
@@ -74,7 +74,7 @@
             "scale_range": [0.01, 0.02],
             "prob": 0.9,
             "as_tensor_output": False,
-            "device": None,
+            "device": "cuda" if torch.cuda.is_available() else "cpu",
             "spatial_size": (2, 2),
         },
         {"img": torch.arange(27).reshape((3, 3, 3))},
diff --git a/tests/test_rand_elastic_3d.py b/tests/test_rand_elastic_3d.py
index 8cd74c6be7..c63282d571 100644
--- a/tests/test_rand_elastic_3d.py
+++ b/tests/test_rand_elastic_3d.py
@@ -59,7 +59,7 @@
             "prob": 0.9,
             "rotate_range": [1, 1, 1],
             "as_tensor_output": False,
-            "device": None,
+            "device": "cuda" if torch.cuda.is_available() else "cpu",
             "spatial_size": (2, 2, 2),
         },
         {"img": torch.arange(27).reshape((1, 3, 3, 3)), "mode": "bilinear"},
diff --git a/tests/test_spatial_crop.py b/tests/test_spatial_crop.py
index f3c904889f..4c56929686 100644
--- a/tests/test_spatial_crop.py
+++ b/tests/test_spatial_crop.py
@@ -12,6 +12,7 @@
 import unittest
 
 import numpy as np
+import torch
 from parameterized import parameterized
 
 from monai.transforms import SpatialCrop
@@ -49,6 +50,12 @@ def test_shape(self, input_param, input_shape, expected_shape):
         result = SpatialCrop(**input_param)(input_data)
         self.assertTupleEqual(result.shape, expected_shape)
 
+    @parameterized.expand(TEST_CASES)
+    def test_tensor_shape(self, input_param, input_shape, expected_shape):
+        input_data = torch.randint(0, 2, size=input_shape, device="cuda" if torch.cuda.is_available() else "cpu")
+        result = SpatialCrop(**input_param)(input_data)
+        self.assertTupleEqual(result.shape, expected_shape)
+
 
 if __name__ == "__main__":
     unittest.main()

From 306ed32c08087d38d385ae19cccda557df850b81 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Fri, 26 Mar 2021 08:16:55 +0000
Subject: [PATCH 2/2] update val comparisons

Signed-off-by: Wenqi Li
---
 tests/test_inverse.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tests/test_inverse.py b/tests/test_inverse.py
index d54855d7c1..ccc4f366c2 100644
--- a/tests/test_inverse.py
+++ b/tests/test_inverse.py
@@ -491,7 +491,10 @@ def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_
             unmodified = unmodified_d[key]
             if isinstance(orig, np.ndarray):
                 mean_diff = np.mean(np.abs(orig - fwd_bck))
-                unmodded_diff = np.mean(np.abs(orig - ResizeWithPadOrCrop(orig.shape[1:])(unmodified)))
+                resized = ResizeWithPadOrCrop(orig.shape[1:])(unmodified)
+                if isinstance(resized, torch.Tensor):
+                    resized = resized.detach().cpu().numpy()
+                unmodded_diff = np.mean(np.abs(orig - resized))
                 try:
                     self.assertLessEqual(mean_diff, acceptable_diff)
                 except AssertionError:
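
Note (not part of the patches above): a minimal usage sketch of the behaviour this series enables, assuming a MONAI build that includes these changes and an installed PyTorch. With np.asarray removed from SpatialCrop.__call__, cropping a torch.Tensor now returns a torch.Tensor rather than silently converting it to a numpy array:

    import torch
    from monai.transforms import CenterSpatialCrop, SpatialCrop

    # channel-first input: 3 channels, 3x3x3 spatial volume
    img = torch.randint(0, 2, size=(3, 3, 3, 3))

    # SpatialCrop preserves the input type: tensor in, tensor out
    cropped = SpatialCrop(roi_start=[0, 0, 0], roi_end=[2, 2, 2])(img)
    print(type(cropped), tuple(cropped.shape))  # <class 'torch.Tensor'> (3, 2, 2, 2)

    # CenterSpatialCrop (which crops via SpatialCrop internally) behaves the same way
    center = CenterSpatialCrop(roi_size=[2, 2, 2])(img)
    print(type(center), tuple(center.shape))  # <class 'torch.Tensor'> (3, 2, 2, 2)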