From 57b8233f4b1ebb9785655b2995996e9d8be698c5 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 10 Mar 2023 09:15:56 +0000 Subject: [PATCH 1/4] udpate Signed-off-by: Wenqi Li --- monai/data/utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 6501122e2a..c126a634c3 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -875,15 +875,14 @@ def compute_shape_offset( in_coords = [(-0.5, dim - 0.5) if scale_extent else (0.0, dim - 1.0) for dim in shape] corners: np.ndarray = np.asarray(np.meshgrid(*in_coords, indexing="ij")).reshape((len(shape), -1)) corners = np.concatenate((corners, np.ones_like(corners[:1]))) - corners = in_affine_ @ corners try: - inv_mat = np.linalg.inv(out_affine_) + corners_out = np.linalg.solve(out_affine_, in_affine_) @ corners except np.linalg.LinAlgError as e: raise ValueError(f"Affine {out_affine_} is not invertible") from e - corners_out = inv_mat @ corners + corners = in_affine_ @ corners + all_dist = corners_out[:-1].copy() corners_out = corners_out[:-1] / corners_out[-1] out_shape = np.round(corners_out.ptp(axis=1)) if scale_extent else np.round(corners_out.ptp(axis=1) + 1.0) - all_dist = inv_mat[:-1, :-1] @ corners[:-1, :] offset = None for i in range(corners.shape[1]): min_corner = np.min(all_dist - all_dist[:, i : i + 1], 1) From 958bd7b3f369ffe25be60e58abf45e71f4a3d555 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 14 Mar 2023 22:57:28 +0000 Subject: [PATCH 2/4] enhance spacing Signed-off-by: Wenqi Li --- monai/networks/layers/spatial_transforms.py | 2 ++ monai/transforms/spatial/dictionary.py | 18 +++++++++++++++- tests/test_spacingd.py | 24 +++++++++++++++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index ff5b0a3b89..5656e379c5 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -537,6 +537,8 @@ def forward( theta = torch.cat([theta, pad_affine], dim=1) if tuple(theta.shape[1:]) not in ((3, 3), (4, 4)): raise ValueError(f"theta must be Nx3x3 or Nx4x4, got {theta.shape}.") + if not torch.is_floating_point(theta): + raise ValueError(f"theta must be floating point data, got {theta.dtype}") # validate `src` if not isinstance(src, torch.Tensor): diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 4c1fe4f268..ea734543bc 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -333,6 +333,7 @@ def __init__( recompute_affine: bool = False, min_pixdim: Sequence[float] | float | None = None, max_pixdim: Sequence[float] | float | None = None, + ensure_same_shape: bool = True, allow_missing_keys: bool = False, ) -> None: """ @@ -390,6 +391,8 @@ def __init__( max_pixdim: maximal input spacing to be resampled. If provided, input image with a smaller spacing than this value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the value of `pixdim`. Default to `None`. + ensure_same_shape: when the inputs have the same spatial shape, and almost the same pixdim, + whether to ensure exactly the same output spatial shape. Default to True. allow_missing_keys: don't raise exception if key is missing. """ @@ -402,13 +405,24 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.scale_extent = ensure_tuple_rep(scale_extent, len(self.keys)) + self.ensure_same_shape = ensure_same_shape def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d: dict = dict(data) + + _init_shape, _pixdim, should_match = None, None, False + output_shape_k = None # tracking output shape + for key, mode, padding_mode, align_corners, dtype, scale_extent in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.scale_extent ): - # resample array of each corresponding key + if self.ensure_same_shape and isinstance(d[key], MetaTensor): + if _init_shape is None and _pixdim is None: + _init_shape, _pixdim = d[key].peek_pending_shape(), d[key].pixdim + else: + should_match = np.allclose(_init_shape, d[key].peek_pending_shape()) and np.allclose( + _pixdim, d[key].pixdim, atol=1e-3 + ) d[key] = self.spacing_transform( data_array=d[key], mode=mode, @@ -416,7 +430,9 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc align_corners=align_corners, dtype=dtype, scale_extent=scale_extent, + output_spatial_shape=output_shape_k if should_match else None, ) + output_shape_k = d[key].peek_pending_shape() if isinstance(d[key], MetaTensor) else d[key].shape[1:] return d def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py index a77c3636fa..ea8f7755c7 100644 --- a/tests/test_spacingd.py +++ b/tests/test_spacingd.py @@ -118,6 +118,30 @@ def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, devi self.assertNotIsInstance(res, MetaTensor) self.assertNotEqual(img.shape, res.shape) + def test_space_same_shape(self): + affine_1 = np.array( + [ + [1.499277e00, 2.699563e-02, 3.805804e-02, -1.948635e02], + [-2.685805e-02, 1.499757e00, -2.635604e-12, 4.438188e01], + [-3.805194e-02, -5.999028e-04, 1.499517e00, 4.036536e01], + [0.000000e00, 0.000000e00, 0.000000e00, 1.000000e00], + ] + ) + affine_2 = np.array( + [ + [1.499275e00, 2.692252e-02, 3.805728e-02, -1.948635e02], + [-2.693010e-02, 1.499758e00, -4.260525e-05, 4.438188e01], + [-3.805190e-02, -6.406730e-04, 1.499517e00, 4.036536e01], + [0.000000e00, 0.000000e00, 0.000000e00, 1.000000e00], + ] + ) + img_1 = MetaTensor(np.zeros((1, 238, 145, 315)), affine=affine_1) + img_2 = MetaTensor(np.zeros((1, 238, 145, 315)), affine=affine_2) + out = Spacingd(("img_1", "img_2"), pixdim=1)({"img_1": img_1, "img_2": img_2}) + self.assertEqual(out["img_1"].shape, out["img_2"].shape) # ensure_same_shape True + out = Spacingd(("img_1", "img_2"), pixdim=1, ensure_same_shape=False)({"img_1": img_1, "img_2": img_2}) + self.assertNotEqual(out["img_1"].shape, out["img_2"].shape) # ensure_same_shape False + if __name__ == "__main__": unittest.main() From 77d9bd278abe72c143bd874bf2b59a478e02d961 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 15 Mar 2023 01:06:16 +0000 Subject: [PATCH 3/4] update tests Signed-off-by: Wenqi Li --- tests/test_global_mutual_information_loss.py | 48 ++++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py index af66de46b2..5847db1388 100644 --- a/tests/test_global_mutual_information_loss.py +++ b/tests/test_global_mutual_information_loss.py @@ -26,28 +26,28 @@ EXPECTED_VALUE = { "xyz_translation": [ - -1.5860259532928467, - -0.5957175493240356, - -0.3855515122413635, - -0.28728482127189636, - -0.23416118323802948, - -0.19534644484519958, - -0.17001715302467346, - -0.15043553709983826, - -0.1366637945175171, - -0.12534910440444946, + -1.5860257, + -0.62433463, + -0.38217825, + -0.2905613, + -0.23233329, + -0.1961407, + -0.16905619, + -0.15100679, + -0.13666219, + -0.12635908, ], "xyz_rotation": [ - -1.5860259532928467, - -0.29977330565452576, - -0.18411292135715485, - -0.1582011878490448, - -0.16107326745986938, - -0.165723517537117, - -0.1970357596874237, - -0.1755618453025818, - -0.17100191116333008, - -0.17264796793460846, + -1.5860257, + -0.30265224, + -0.18666176, + -0.15887907, + -0.1625064, + -0.16603896, + -0.19222091, + -0.18158069, + -0.167644, + -0.16698098, ], } @@ -84,9 +84,9 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0. numpy array of shape HWD """ transform_list = [ - transforms.LoadImaged(keys="img"), + transforms.LoadImaged(keys="img", image_only=True), transforms.Affined( - keys="img", translate_params=translate_params, rotate_params=rotate_params, device=None + keys="img", translate_params=translate_params, rotate_params=rotate_params, device=None, ), transforms.NormalizeIntensityd(keys=["img"]), ] @@ -94,7 +94,7 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0. return transformation({"img": FILE_PATH})["img"] a1 = transformation() - a1 = torch.tensor(a1).unsqueeze(0).unsqueeze(0).to(device) + a1 = a1.clone().unsqueeze(0).unsqueeze(0).to(device) for mode in transform_params_dict: transform_params_list = transform_params_dict[mode] @@ -104,7 +104,7 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0. translate_params=transform_params if "translation" in mode else (0.0, 0.0, 0.0), rotate_params=transform_params if "rotation" in mode else (0.0, 0.0, 0.0), ) - a2 = torch.tensor(a2).unsqueeze(0).unsqueeze(0).to(device) + a2 = a2.clone().unsqueeze(0).unsqueeze(0).to(device) result = loss_fn(a2, a1).detach().cpu().numpy() np.testing.assert_allclose(result, expected_value, rtol=1e-3, atol=5e-3) From 8974c6c6bdb0ced5f4c8835e773a0b64bf93ac0d Mon Sep 17 00:00:00 2001 From: monai-bot Date: Wed, 15 Mar 2023 01:47:18 +0000 Subject: [PATCH 4/4] [MONAI] code formatting Signed-off-by: monai-bot --- tests/test_global_mutual_information_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py index 5847db1388..88d0a78716 100644 --- a/tests/test_global_mutual_information_loss.py +++ b/tests/test_global_mutual_information_loss.py @@ -86,7 +86,7 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0. transform_list = [ transforms.LoadImaged(keys="img", image_only=True), transforms.Affined( - keys="img", translate_params=translate_params, rotate_params=rotate_params, device=None, + keys="img", translate_params=translate_params, rotate_params=rotate_params, device=None ), transforms.NormalizeIntensityd(keys=["img"]), ]