Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,7 @@ def __call__(
"""
if grid is None:
if spatial_size is not None:
grid = create_grid(spatial_size)
grid = create_grid(spatial_size, dtype=float)
else:
raise ValueError("Incompatible values: grid=None and spatial_size=None.")

Expand All @@ -1084,9 +1084,7 @@ def __call__(
else:
affine = self.affine

if self.device not in (None, torch.device("cpu"), "cpu"):
grid, *_ = convert_data_type(grid, torch.Tensor, device=self.device)
grid, *_ = convert_data_type(grid, dtype=float)
grid, *_ = convert_data_type(grid, torch.Tensor, device=self.device, dtype=float)
affine, *_ = convert_to_dst_type(affine, grid)

grid = (affine @ grid.reshape((grid.shape[0], -1))).reshape([-1] + list(grid.shape[1:]))
Expand Down
6 changes: 2 additions & 4 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,10 +795,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:])
# change image size or do random transform
do_resampling = self._do_transform or (sp_size != ensure_tuple(data[self.keys[0]].shape[1:]))

affine: NdarrayOrTensor = np.eye(len(sp_size) + 1, dtype=np.float64)
if device not in (None, torch.device("cpu"), "cpu"):
affine, *_ = convert_data_type(affine, torch.Tensor, device=device)
affine: torch.Tensor = torch.eye(len(sp_size) + 1, dtype=torch.float64, device=device)
# converting affine to tensor because the resampler currently only support torch backend
grid = None
if do_resampling: # need to prepare grid
grid = self.rand_affine.get_identity_grid(sp_size)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_rand_affined.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ def test_rand_affined(self, input_param, input_data, expected_val):
expected = expected_val[key] if isinstance(expected_val, dict) else expected_val
assert_allclose(result, expected, rtol=1e-4, atol=1e-4)

g.set_random_state(4)
res = g(input_data)
# affine should be tensor because the resampler only supports pytorch backend
self.assertTrue(isinstance(res["img_transforms"][0]["extra_info"]["affine"], torch.Tensor))

def test_ill_cache(self):
with self.assertWarns(UserWarning):
# spatial size is None
Expand Down