7 changes: 3 additions & 4 deletions monai/data/utils.py
@@ -875,15 +875,14 @@ def compute_shape_offset(
     in_coords = [(-0.5, dim - 0.5) if scale_extent else (0.0, dim - 1.0) for dim in shape]
     corners: np.ndarray = np.asarray(np.meshgrid(*in_coords, indexing="ij")).reshape((len(shape), -1))
     corners = np.concatenate((corners, np.ones_like(corners[:1])))
-    corners = in_affine_ @ corners
     try:
-        inv_mat = np.linalg.inv(out_affine_)
+        corners_out = np.linalg.solve(out_affine_, in_affine_) @ corners
     except np.linalg.LinAlgError as e:
         raise ValueError(f"Affine {out_affine_} is not invertible") from e
-    corners_out = inv_mat @ corners
+    corners = in_affine_ @ corners
+    all_dist = corners_out[:-1].copy()
     corners_out = corners_out[:-1] / corners_out[-1]
     out_shape = np.round(corners_out.ptp(axis=1)) if scale_extent else np.round(corners_out.ptp(axis=1) + 1.0)
-    all_dist = inv_mat[:-1, :-1] @ corners[:-1, :]
     offset = None
     for i in range(corners.shape[1]):
         min_corner = np.min(all_dist - all_dist[:, i : i + 1], 1)
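Note on the change above: np.linalg.solve(out_affine_, in_affine_) computes inv(out_affine_) @ in_affine_ in a single, better-conditioned step, so the explicit inv_mat is no longer needed. A minimal sketch of the equivalence, using made-up affines rather than anything from this PR:

import numpy as np

# Hypothetical 4x4 homogeneous affines, for illustration only.
out_affine = np.diag([2.0, 2.0, 2.0, 1.0])  # target: 2-unit isotropic spacing
in_affine = np.eye(4)
in_affine[:3, 3] = (10.0, -5.0, 3.0)  # source: identity spacing plus a translation

via_inv = np.linalg.inv(out_affine) @ in_affine  # old formulation
via_solve = np.linalg.solve(out_affine, in_affine)  # new: solves out_affine @ x = in_affine
assert np.allclose(via_inv, via_solve)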
2 changes: 2 additions & 0 deletions monai/networks/layers/spatial_transforms.py
@@ -537,6 +537,8 @@ def forward(
             theta = torch.cat([theta, pad_affine], dim=1)
         if tuple(theta.shape[1:]) not in ((3, 3), (4, 4)):
             raise ValueError(f"theta must be Nx3x3 or Nx4x4, got {theta.shape}.")
+        if not torch.is_floating_point(theta):
+            raise ValueError(f"theta must be floating point data, got {theta.dtype}")

         # validate `src`
         if not isinstance(src, torch.Tensor):
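A sketch of what the new dtype guard catches, assuming the public AffineTransform entry point (the example itself is not part of this diff):

import torch
from monai.networks.layers import AffineTransform

src = torch.rand(1, 1, 8, 8)  # NCHW input image
theta = torch.eye(3).unsqueeze(0).to(torch.int64)  # valid Nx3x3 shape, but integer dtype
try:
    AffineTransform(spatial_size=(8, 8))(src, theta)
except ValueError as err:
    print(err)  # theta must be floating point data, got torch.int64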
18 changes: 17 additions & 1 deletion monai/transforms/spatial/dictionary.py
@@ -339,6 +339,7 @@ def __init__(
         recompute_affine: bool = False,
         min_pixdim: Sequence[float] | float | None = None,
         max_pixdim: Sequence[float] | float | None = None,
+        ensure_same_shape: bool = True,
         allow_missing_keys: bool = False,
     ) -> None:
         """
@@ -396,6 +397,8 @@ def __init__(
             max_pixdim: maximal input spacing to be resampled. If provided, input image with a smaller spacing than this
                 value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the
                 value of `pixdim`. Default to `None`.
+            ensure_same_shape: when the inputs have the same spatial shape, and almost the same pixdim,
+                whether to ensure exactly the same output spatial shape. Default to True.
             allow_missing_keys: don't raise exception if key is missing.

         """
@@ -408,6 +411,7 @@ def __init__(
         self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
         self.dtype = ensure_tuple_rep(dtype, len(self.keys))
         self.scale_extent = ensure_tuple_rep(scale_extent, len(self.keys))
+        self.ensure_same_shape = ensure_same_shape

     @LazyTransform.lazy_evaluation.setter  # type: ignore
     def lazy_evaluation(self, val: bool) -> None:
@@ -416,18 +420,30 @@ def lazy_evaluation(self, val: bool) -> None:

     def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
         d: dict = dict(data)
+
+        _init_shape, _pixdim, should_match = None, None, False
+        output_shape_k = None  # tracking output shape
+
         for key, mode, padding_mode, align_corners, dtype, scale_extent in self.key_iterator(
             d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.scale_extent
         ):
-            # resample array of each corresponding key
+            if self.ensure_same_shape and isinstance(d[key], MetaTensor):
+                if _init_shape is None and _pixdim is None:
+                    _init_shape, _pixdim = d[key].peek_pending_shape(), d[key].pixdim
+                else:
+                    should_match = np.allclose(_init_shape, d[key].peek_pending_shape()) and np.allclose(
+                        _pixdim, d[key].pixdim, atol=1e-3
+                    )
             d[key] = self.spacing_transform(
                 data_array=d[key],
                 mode=mode,
                 padding_mode=padding_mode,
                 align_corners=align_corners,
                 dtype=dtype,
                 scale_extent=scale_extent,
+                output_spatial_shape=output_shape_k if should_match else None,
             )
+            output_shape_k = d[key].peek_pending_shape() if isinstance(d[key], MetaTensor) else d[key].shape[1:]
         return d

     def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
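For illustration, a hedged usage sketch of the new flag (keys, shapes, and affines below are made up; only the parameter name and its default come from this diff):

import numpy as np
from monai.data import MetaTensor
from monai.transforms import Spacingd

# Two keys with the same spatial shape whose pixdims differ by less than the
# 1e-3 tolerance, so the second key is resampled to the first key's output shape.
img = MetaTensor(np.zeros((1, 96, 96, 96)), affine=np.diag((1.5, 1.5, 1.5, 1.0)))
seg = MetaTensor(np.zeros((1, 96, 96, 96)), affine=np.diag((1.5001, 1.5, 1.5, 1.0)))
out = Spacingd(keys=("img", "seg"), pixdim=1.0)({"img": img, "seg": seg})  # ensure_same_shape defaults to True
assert out["img"].shape == out["seg"].shape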
46 changes: 23 additions & 23 deletions tests/test_global_mutual_information_loss.py
@@ -26,28 +26,28 @@

 EXPECTED_VALUE = {
     "xyz_translation": [
-        -1.5860259532928467,
-        -0.5957175493240356,
-        -0.3855515122413635,
-        -0.28728482127189636,
-        -0.23416118323802948,
-        -0.19534644484519958,
-        -0.17001715302467346,
-        -0.15043553709983826,
-        -0.1366637945175171,
-        -0.12534910440444946,
+        -1.5860257,
+        -0.62433463,
+        -0.38217825,
+        -0.2905613,
+        -0.23233329,
+        -0.1961407,
+        -0.16905619,
+        -0.15100679,
+        -0.13666219,
+        -0.12635908,
     ],
     "xyz_rotation": [
-        -1.5860259532928467,
-        -0.29977330565452576,
-        -0.18411292135715485,
-        -0.1582011878490448,
-        -0.16107326745986938,
-        -0.165723517537117,
-        -0.1970357596874237,
-        -0.1755618453025818,
-        -0.17100191116333008,
-        -0.17264796793460846,
+        -1.5860257,
+        -0.30265224,
+        -0.18666176,
+        -0.15887907,
+        -0.1625064,
+        -0.16603896,
+        -0.19222091,
+        -0.18158069,
+        -0.167644,
+        -0.16698098,
     ],
 }

@@ -84,7 +84,7 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.
                 numpy array of shape HWD
             """
             transform_list = [
-                transforms.LoadImaged(keys="img"),
+                transforms.LoadImaged(keys="img", image_only=True),
                 transforms.Affined(
                     keys="img", translate_params=translate_params, rotate_params=rotate_params, device=None
                 ),
@@ -94,7 +94,7 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.
             return transformation({"img": FILE_PATH})["img"]

         a1 = transformation()
-        a1 = torch.tensor(a1).unsqueeze(0).unsqueeze(0).to(device)
+        a1 = a1.clone().unsqueeze(0).unsqueeze(0).to(device)

         for mode in transform_params_dict:
             transform_params_list = transform_params_dict[mode]
@@ -104,7 +104,7 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.
                     translate_params=transform_params if "translation" in mode else (0.0, 0.0, 0.0),
                     rotate_params=transform_params if "rotation" in mode else (0.0, 0.0, 0.0),
                 )
-                a2 = torch.tensor(a2).unsqueeze(0).unsqueeze(0).to(device)
+                a2 = a2.clone().unsqueeze(0).unsqueeze(0).to(device)
                 result = loss_fn(a2, a1).detach().cpu().numpy()
                 np.testing.assert_allclose(result, expected_value, rtol=1e-3, atol=5e-3)

24 changes: 24 additions & 0 deletions tests/test_spacingd.py
@@ -134,6 +134,30 @@ def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, devi
         self.assertNotIsInstance(res, MetaTensor)
         self.assertNotEqual(img.shape, res.shape)

+    def test_space_same_shape(self):
+        affine_1 = np.array(
+            [
+                [1.499277e00, 2.699563e-02, 3.805804e-02, -1.948635e02],
+                [-2.685805e-02, 1.499757e00, -2.635604e-12, 4.438188e01],
+                [-3.805194e-02, -5.999028e-04, 1.499517e00, 4.036536e01],
+                [0.000000e00, 0.000000e00, 0.000000e00, 1.000000e00],
+            ]
+        )
+        affine_2 = np.array(
+            [
+                [1.499275e00, 2.692252e-02, 3.805728e-02, -1.948635e02],
+                [-2.693010e-02, 1.499758e00, -4.260525e-05, 4.438188e01],
+                [-3.805190e-02, -6.406730e-04, 1.499517e00, 4.036536e01],
+                [0.000000e00, 0.000000e00, 0.000000e00, 1.000000e00],
+            ]
+        )
+        img_1 = MetaTensor(np.zeros((1, 238, 145, 315)), affine=affine_1)
+        img_2 = MetaTensor(np.zeros((1, 238, 145, 315)), affine=affine_2)
+        out = Spacingd(("img_1", "img_2"), pixdim=1)({"img_1": img_1, "img_2": img_2})
+        self.assertEqual(out["img_1"].shape, out["img_2"].shape)  # ensure_same_shape True
+        out = Spacingd(("img_1", "img_2"), pixdim=1, ensure_same_shape=False)({"img_1": img_1, "img_2": img_2})
+        self.assertNotEqual(out["img_1"].shape, out["img_2"].shape)  # ensure_same_shape False
+

 if __name__ == "__main__":
     unittest.main()