diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 8b1fb854f2..df3d3eb093 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1065,7 +1065,7 @@ def __call__( """ if grid is None: if spatial_size is not None: - grid = create_grid(spatial_size) + grid = create_grid(spatial_size, dtype=float) else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") @@ -1084,9 +1084,7 @@ def __call__( else: affine = self.affine - if self.device not in (None, torch.device("cpu"), "cpu"): - grid, *_ = convert_data_type(grid, torch.Tensor, device=self.device) - grid, *_ = convert_data_type(grid, dtype=float) + grid, *_ = convert_data_type(grid, torch.Tensor, device=self.device, dtype=float) affine, *_ = convert_to_dst_type(affine, grid) grid = (affine @ grid.reshape((grid.shape[0], -1))).reshape([-1] + list(grid.shape[1:])) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index d794e51e80..487225cb60 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -795,10 +795,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) # change image size or do random transform do_resampling = self._do_transform or (sp_size != ensure_tuple(data[self.keys[0]].shape[1:])) - - affine: NdarrayOrTensor = np.eye(len(sp_size) + 1, dtype=np.float64) - if device not in (None, torch.device("cpu"), "cpu"): - affine, *_ = convert_data_type(affine, torch.Tensor, device=device) + affine: torch.Tensor = torch.eye(len(sp_size) + 1, dtype=torch.float64, device=device) + # converting affine to tensor because the resampler currently only support torch backend grid = None if do_resampling: # need to prepare grid grid = self.rand_affine.get_identity_grid(sp_size) diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py index bec9602d62..0109175b16 100644 --- a/tests/test_rand_affined.py +++ b/tests/test_rand_affined.py @@ -211,6 +211,11 @@ def test_rand_affined(self, input_param, input_data, expected_val): expected = expected_val[key] if isinstance(expected_val, dict) else expected_val assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + g.set_random_state(4) + res = g(input_data) + # affine should be tensor because the resampler only supports pytorch backend + self.assertTrue(isinstance(res["img_transforms"][0]["extra_info"]["affine"], torch.Tensor)) + def test_ill_cache(self): with self.assertWarns(UserWarning): # spatial size is None