From 5fd339b3bc206eb001731b42905360b2ebba91e4 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Fri, 17 Sep 2021 14:04:28 +0100
Subject: [PATCH 1/2] fixes tutorial issue 352

Signed-off-by: Wenqi Li
---
 monai/transforms/spatial/array.py      | 6 ++----
 monai/transforms/spatial/dictionary.py | 5 ++---
 tests/test_rand_affined.py             | 5 +++++
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index 8b1fb854f2..df3d3eb093 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -1065,7 +1065,7 @@ def __call__(
         """
         if grid is None:
             if spatial_size is not None:
-                grid = create_grid(spatial_size)
+                grid = create_grid(spatial_size, dtype=float)
             else:
                 raise ValueError("Incompatible values: grid=None and spatial_size=None.")
 
@@ -1084,9 +1084,7 @@ def __call__(
         else:
             affine = self.affine
 
-        if self.device not in (None, torch.device("cpu"), "cpu"):
-            grid, *_ = convert_data_type(grid, torch.Tensor, device=self.device)
-        grid, *_ = convert_data_type(grid, dtype=float)
+        grid, *_ = convert_data_type(grid, torch.Tensor, device=self.device, dtype=float)
         affine, *_ = convert_to_dst_type(affine, grid)
 
         grid = (affine @ grid.reshape((grid.shape[0], -1))).reshape([-1] + list(grid.shape[1:]))
diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py
index d794e51e80..6ebe80fd14 100644
--- a/monai/transforms/spatial/dictionary.py
+++ b/monai/transforms/spatial/dictionary.py
@@ -795,10 +795,9 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
         sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:])
         # change image size or do random transform
         do_resampling = self._do_transform or (sp_size != ensure_tuple(data[self.keys[0]].shape[1:]))
-
         affine: NdarrayOrTensor = np.eye(len(sp_size) + 1, dtype=np.float64)
-        if device not in (None, torch.device("cpu"), "cpu"):
-            affine, *_ = convert_data_type(affine, torch.Tensor, device=device)
+        # converting affine to tensor because the resampler currently only support torch backend
+        affine, *_ = convert_data_type(affine, torch.Tensor, device=device)
         grid = None
         if do_resampling:  # need to prepare grid
             grid = self.rand_affine.get_identity_grid(sp_size)
diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py
index bec9602d62..0109175b16 100644
--- a/tests/test_rand_affined.py
+++ b/tests/test_rand_affined.py
@@ -211,6 +211,11 @@ def test_rand_affined(self, input_param, input_data, expected_val):
             expected = expected_val[key] if isinstance(expected_val, dict) else expected_val
             assert_allclose(result, expected, rtol=1e-4, atol=1e-4)
 
+        g.set_random_state(4)
+        res = g(input_data)
+        # affine should be tensor because the resampler only supports pytorch backend
+        self.assertTrue(isinstance(res["img_transforms"][0]["extra_info"]["affine"], torch.Tensor))
+
     def test_ill_cache(self):
         with self.assertWarns(UserWarning):
             # spatial size is None

From 6d7d6614ef53b5b71980ccff2876bf862c732de2 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Fri, 17 Sep 2021 14:33:45 +0100
Subject: [PATCH 2/2] simplified

Signed-off-by: Wenqi Li
---
 monai/transforms/spatial/dictionary.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py
index 6ebe80fd14..487225cb60 100644
--- a/monai/transforms/spatial/dictionary.py
+++ b/monai/transforms/spatial/dictionary.py
@@ -795,9 +795,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
         sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:])
         # change image size or do random transform
         do_resampling = self._do_transform or (sp_size != ensure_tuple(data[self.keys[0]].shape[1:]))
-        affine: NdarrayOrTensor = np.eye(len(sp_size) + 1, dtype=np.float64)
+        affine: torch.Tensor = torch.eye(len(sp_size) + 1, dtype=torch.float64, device=device)
         # converting affine to tensor because the resampler currently only support torch backend
-        affine, *_ = convert_data_type(affine, torch.Tensor, device=device)
         grid = None
         if do_resampling:  # need to prepare grid
             grid = self.rand_affine.get_identity_grid(sp_size)
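
A minimal usage sketch of the behaviour the new test asserts: after these two commits, the affine recorded in the transform trace is a torch.Tensor even when everything runs on the CPU. The RandAffined arguments below (keys, spatial_size, prob) and the dummy input are illustrative choices, not taken from the patch; the "img_transforms" / "extra_info" / "affine" lookup mirrors the added test.

# Illustrative sketch (not part of the patch): checks the invariant exercised by
# the new assertion in tests/test_rand_affined.py.
import numpy as np
import torch
from monai.transforms import RandAffined

# hypothetical example arguments; any keys/spatial_size/prob would do
xform = RandAffined(keys=("img",), spatial_size=(32, 32), prob=1.0)
data = {"img": np.zeros((1, 28, 28), dtype=np.float32)}
out = xform(data)

# the applied-transform record keeps the affine as a tensor,
# because the resampler currently only supports the torch backend
affine = out["img_transforms"][0]["extra_info"]["affine"]
assert isinstance(affine, torch.Tensor)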