From 0321dc8a736d6ac2bf82596a0ef761bcc82a0732 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 3 Sep 2024 00:38:39 +0800
Subject: [PATCH 01/10] enhance ApplyTransformToPointsd

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/transforms/utility/dictionary.py   | 29 +++++++++++++-----------
 tests/test_apply_transform_to_pointsd.py |  4 ++--
 2 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index 1279ca93ab..ce5c1446e8 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -1758,8 +1758,9 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
     Args:
         keys: keys of the corresponding items to be transformed.
             See also: monai.transforms.MapTransform
-        refer_key: The key of the reference item used for transformation.
-            It can directly refer to an affine or an image from which the affine can be derived.
+        refer_keys: The key of the reference item used for transformation.
+            It can directly refer to an affine or an image from which the affine can be derived. It can also be a
+            sequence of keys.
         dtype: The desired data type for the output.
         affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
             from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
@@ -1782,7 +1783,7 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
     def __init__(
         self,
         keys: KeysCollection,
-        refer_key: str | None = None,
+        refer_keys: KeysCollection | None = None,
         dtype: DtypeLike | torch.dtype = torch.float64,
         affine: torch.Tensor | None = None,
         invert_affine: bool = True,
@@ -1790,23 +1791,25 @@ def __init__(
         allow_missing_keys: bool = False,
     ):
         MapTransform.__init__(self, keys, allow_missing_keys)
-        self.refer_key = refer_key
+        self.refer_keys = ensure_tuple_rep(None, len(self.keys)) if refer_keys is None else ensure_tuple(refer_keys)
+        if len(self.keys) != len(self.refer_keys):
+            raise ValueError("refer_keys should have the same length as keys.")
         self.converter = ApplyTransformToPoints(
             dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
         )
 
     def __call__(self, data: Mapping[Hashable, torch.Tensor]):
         d = dict(data)
-        if self.refer_key is not None:
-            if self.refer_key in d:
-                refer_data = d[self.refer_key]
-            else:
-                raise KeyError(f"The refer_key '{self.refer_key}' is not found in the data.")
-        else:
-            refer_data = None
-        affine = getattr(refer_data, "affine", refer_data)
-        for key in self.key_iterator(d):
+        for key, refer_key in self.key_iterator(d, self.refer_keys):
             coords = d[key]
+            if refer_key is not None:
+                if refer_key in d:
+                    refer_data = d[refer_key]
+                else:
+                    raise KeyError(f"The refer_key '{refer_key}' is not found in the data.")
+            else:
+                refer_data = None
+            affine = getattr(refer_data, "affine", refer_data)
             d[key] = self.converter(coords, affine)
         return d
 
diff --git a/tests/test_apply_transform_to_pointsd.py b/tests/test_apply_transform_to_pointsd.py
index 4cedfa9d66..de465996bb 100644
--- a/tests/test_apply_transform_to_pointsd.py
+++ b/tests/test_apply_transform_to_pointsd.py
@@ -107,10 +107,10 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin
             "point": points,
             "affine": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]),
         }
-        refer_key = "image" if (image is not None and image != "affine") else image
+        refer_keys = "image" if (image is not None and image != "affine") else image
         transform = ApplyTransformToPointsd(
            keys="point",
-            refer_key=refer_key,
+            refer_keys=refer_keys,
             dtype=torch.int64,
             affine=affine,
             invert_affine=invert_affine,

From dc366e8b671549bcffc10e12687bd695b840beb8 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 3 Sep 2024 15:55:55 +0800
Subject: [PATCH 02/10] Update monai/transforms/utility/dictionary.py

Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/transforms/utility/dictionary.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index ce5c1446e8..4b9ee794d4 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -1802,14 +1802,15 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]):
         d = dict(data)
         for key, refer_key in self.key_iterator(d, self.refer_keys):
             coords = d[key]
+            affine = None  # represents using affine given in constructor
             if refer_key is not None:
                 if refer_key in d:
                     refer_data = d[refer_key]
                 else:
                     raise KeyError(f"The refer_key '{refer_key}' is not found in the data.")
-            else:
-                refer_data = None
-            affine = getattr(refer_data, "affine", refer_data)
+
+                # use the "affine" member of refer_data, or refer_data itself, as the affine matrix
+                affine = getattr(refer_data, "affine", refer_data)
             d[key] = self.converter(coords, affine)
         return d

From 5713d5a5ed578c800355352c7c3b354d46116a6d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 3 Sep 2024 07:56:34 +0000
Subject: [PATCH 03/10] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 monai/transforms/utility/dictionary.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index 4b9ee794d4..b97b81d5e5 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -1808,7 +1808,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]):
                     refer_data = d[refer_key]
                 else:
                     raise KeyError(f"The refer_key '{refer_key}' is not found in the data.")
-
+
                 # use the "affine" member of refer_data, or refer_data itself, as the affine matrix
                 affine = getattr(refer_data, "affine", refer_data)
             d[key] = self.converter(coords, affine)
From 16e2fd21ab88bd56475c0ec8b6f54836ec13824b Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 3 Sep 2024 15:57:01 +0800
Subject: [PATCH 04/10] Update monai/transforms/utility/dictionary.py

Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/transforms/utility/dictionary.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index b97b81d5e5..d50c064c1b 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -1760,7 +1760,7 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
             See also: monai.transforms.MapTransform
         refer_keys: The key of the reference item used for transformation.
             It can directly refer to an affine or an image from which the affine can be derived. It can also be a
-            sequence of keys.
+            sequence of keys, in which case each refers to the affine applied to the matching points in `keys`.
         dtype: The desired data type for the output.
         affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
             from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary

From 253bec3d0148c4ee2b7492f1b4483c171a06f4d4 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 3 Sep 2024 16:52:05 +0800
Subject: [PATCH 05/10] add more tests

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 tests/test_apply_transform_to_pointsd.py | 127 +++++++++++------------
 1 file changed, 62 insertions(+), 65 deletions(-)

diff --git a/tests/test_apply_transform_to_pointsd.py b/tests/test_apply_transform_to_pointsd.py
index de465996bb..28d0a4898c 100644
--- a/tests/test_apply_transform_to_pointsd.py
+++ b/tests/test_apply_transform_to_pointsd.py
@@ -30,72 +30,35 @@
 POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]])
 POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]])
 POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]])
+AFFINE_1 = torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
+AFFINE_2 = torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])
 
 TEST_CASES = [
-    [
-        MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
-        POINT_2D_WORLD,
-        None,
-        True,
-        False,
-        POINT_2D_IMAGE,
-    ],
-    [
-        None,
-        MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
-        None,
-        False,
-        False,
-        POINT_2D_WORLD,
-    ],
-    [
-        None,
-        MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
-        torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]),
-        False,
-        False,
-        POINT_2D_WORLD,
-    ],
-    [
-        MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
-        POINT_2D_WORLD,
-        None,
-        True,
-        True,
-        POINT_2D_IMAGE_RAS,
-    ],
-    [
-        MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
-        POINT_3D_WORLD,
-        None,
-        True,
-        False,
-        POINT_3D_IMAGE,
-    ],
-    ["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],
-    [
-        MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
-        MetaTensor(POINT_3D_IMAGE, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
-        None,
-        False,
-        False,
-        POINT_3D_WORLD,
-    ],
-    [
-        MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
-        POINT_3D_WORLD,
-        None,
-        True,
-        True,
-        POINT_3D_IMAGE_RAS,
-    ],
+    [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE],  # use image affine
+    [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD],  # use point affine
+    [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD],  # use input affine
+    [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, True, POINT_2D_IMAGE_RAS],  # test affine_lps_to_ras
+    [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],
+    ["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],  # use refer_data itself
+    [MetaTensor(DATA_3D, affine=AFFINE_2), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2), None, False, False, POINT_3D_WORLD],
+    [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],
+    [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],
+]
+TEST_CASES_SEQUENCE = [
+    [(MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)), [POINT_2D_WORLD, POINT_3D_WORLD], None, True, False, ["image_1", "image_2"], [POINT_2D_IMAGE, POINT_3D_IMAGE]],  # use image affine
+    [(MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)), [POINT_2D_WORLD, POINT_3D_WORLD], None, True, True, ["image_1", "image_2"], [POINT_2D_IMAGE_RAS, POINT_3D_IMAGE_RAS]],  # test affine_lps_to_ras
+    [(None, None), [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)], None, False, False, None, [POINT_2D_WORLD, POINT_3D_WORLD]],  # use point affine
+    [(MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)), [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)], None, False, False, ["image_1", "image_2"], [POINT_2D_WORLD, POINT_3D_WORLD]],
 ]
 
 TEST_CASES_WRONG = [
-    [POINT_2D_WORLD, True, None],
-    [POINT_2D_WORLD.unsqueeze(0), False, None],
-    [POINT_3D_WORLD[..., 0:1], False, None],
-    [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])],
+    [POINT_2D_WORLD, True, None, None],
+    [POINT_2D_WORLD.unsqueeze(0), False, None, None],
+    [POINT_3D_WORLD[..., 0:1], False, None, None],
+    [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]]), None],
+    [POINT_3D_WORLD, False, None, "image"],
+    [POINT_3D_WORLD, False, None, []],
 ]
@@ -122,11 +85,45 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin
         invert_out = transform.inverse(output)
         self.assertTrue(torch.allclose(invert_out["point"], points))
 
+    @parameterized.expand(TEST_CASES_SEQUENCE)
+    def test_transform_coordinates_sequences(self, image, points, affine, invert_affine, affine_lps_to_ras, refer_keys, expected_output):
+        data = {
+            "image_1": image[0],
+            "image_2": image[1],
+            "point_1": points[0],
+            "point_2": points[1],
+            "affine": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]),
+        }
+        keys = ["point_1", "point_2"]
+        transform = ApplyTransformToPointsd(
+            keys=keys,
+            refer_keys=refer_keys,
+            dtype=torch.int64,
+            affine=affine,
+            invert_affine=invert_affine,
+            affine_lps_to_ras=affine_lps_to_ras,
+        )
+        output = transform(data)
+
+        self.assertTrue(torch.allclose(output["point_1"], expected_output[0]))
+        self.assertTrue(torch.allclose(output["point_2"], expected_output[1]))
+        invert_out = transform.inverse(output)
+        self.assertTrue(torch.allclose(invert_out["point_1"], points[0]))
+
     @parameterized.expand(TEST_CASES_WRONG)
-    def test_wrong_input(self, input, invert_affine, affine):
-        transform = ApplyTransformToPointsd(keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine)
-        with self.assertRaises(ValueError):
-            transform({"point": input})
+    def test_wrong_input(self, input, invert_affine, affine, refer_keys):
+        if refer_keys == []:
+            with self.assertRaises(ValueError):
+                ApplyTransformToPointsd(keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys)
+        else:
+            transform = ApplyTransformToPointsd(keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys)
+            data = {"point": input}
+            if refer_keys == "image":
+                with self.assertRaises(KeyError):
+                    transform(data)
+            else:
+                with self.assertRaises(ValueError):
+                    transform(data)
 
 
 if __name__ == "__main__":
affine_lps_to_ras + [ + (None, None), + [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)], + None, + False, + False, + None, + [POINT_2D_WORLD, POINT_3D_WORLD], + ], # use point affine + [ + (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)), + [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)], + None, + False, + False, + ["image_1", "image_2"], + [POINT_2D_WORLD, POINT_3D_WORLD], + ], ] TEST_CASES_WRONG = [ @@ -86,7 +131,9 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin self.assertTrue(torch.allclose(invert_out["point"], points)) @parameterized.expand(TEST_CASES_SEQUENCE) - def test_transform_coordinates_sequences(self, image, points, affine, invert_affine, affine_lps_to_ras, refer_keys, expected_output): + def test_transform_coordinates_sequences( + self, image, points, affine, invert_affine, affine_lps_to_ras, refer_keys, expected_output + ): data = { "image_1": image[0], "image_2": image[1], @@ -114,14 +161,18 @@ def test_transform_coordinates_sequences(self, image, points, affine, invert_aff def test_wrong_input(self, input, invert_affine, affine, refer_keys): if refer_keys == []: with self.assertRaises(ValueError): - ApplyTransformToPointsd(keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys) + ApplyTransformToPointsd( + keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys + ) else: - transform = ApplyTransformToPointsd(keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys) + transform = ApplyTransformToPointsd( + keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys + ) data = {"point": input} if refer_keys == "image": with self.assertRaises(KeyError): transform(data) - else: + else: with self.assertRaises(ValueError): transform(data) From 59089b084b735037304e8aa1bf07f8484793c5c6 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 3 Sep 2024 20:11:54 +0800 Subject: [PATCH 07/10] address comments Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/test_apply_transform_to_pointsd.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test_apply_transform_to_pointsd.py b/tests/test_apply_transform_to_pointsd.py index 89057d6022..8fd20454c9 100644 --- a/tests/test_apply_transform_to_pointsd.py +++ b/tests/test_apply_transform_to_pointsd.py @@ -37,6 +37,7 @@ [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE], # use image affine [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD], # use point affine [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD], # use input affine + [None, POINT_2D_WORLD, AFFINE_1, True, False, POINT_2D_IMAGE], # use input affine [ MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, @@ -86,6 +87,15 @@ None, [POINT_2D_WORLD, POINT_3D_WORLD], ], # use point affine + [ + (None, None), + [POINT_2D_WORLD, POINT_2D_WORLD], + AFFINE_1, + True, + False, + None, + [POINT_2D_IMAGE, POINT_2D_IMAGE], + ], # use input affine [ (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)), [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)], @@ -139,7 +149,6 @@ def 
From 59089b084b735037304e8aa1bf07f8484793c5c6 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 3 Sep 2024 20:11:54 +0800
Subject: [PATCH 07/10] address comments

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 tests/test_apply_transform_to_pointsd.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/tests/test_apply_transform_to_pointsd.py b/tests/test_apply_transform_to_pointsd.py
index 89057d6022..8fd20454c9 100644
--- a/tests/test_apply_transform_to_pointsd.py
+++ b/tests/test_apply_transform_to_pointsd.py
@@ -37,6 +37,7 @@
     [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE],  # use image affine
     [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD],  # use point affine
     [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD],  # use input affine
+    [None, POINT_2D_WORLD, AFFINE_1, True, False, POINT_2D_IMAGE],  # use input affine
     [
         MetaTensor(DATA_2D, affine=AFFINE_1),
         POINT_2D_WORLD,
@@ -84,6 +85,15 @@
         None,
         [POINT_2D_WORLD, POINT_3D_WORLD],
     ],  # use point affine
+    [
+        (None, None),
+        [POINT_2D_WORLD, POINT_2D_WORLD],
+        AFFINE_1,
+        True,
+        False,
+        None,
+        [POINT_2D_IMAGE, POINT_2D_IMAGE],
+    ],  # use input affine
     [
         (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
         [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)],
@@ -146,10 +156,9 @@ def test_transform_coordinates_sequences(
         self, image, points, affine, invert_affine, affine_lps_to_ras, refer_keys, expected_output
     ):
         data = {
             "image_1": image[0],
             "image_2": image[1],
             "point_1": points[0],
             "point_2": points[1],
-            "affine": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]),
         }
         keys = ["point_1", "point_2"]

From 1a6f9a2625fc9e3258af55f32d79b0562b8e75b4 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 3 Sep 2024 20:12:54 +0800
Subject: [PATCH 08/10] fix format

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 tests/test_apply_transform_to_pointsd.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/tests/test_apply_transform_to_pointsd.py b/tests/test_apply_transform_to_pointsd.py
index 8fd20454c9..978113931c 100644
--- a/tests/test_apply_transform_to_pointsd.py
+++ b/tests/test_apply_transform_to_pointsd.py
@@ -144,12 +144,7 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin
     def test_transform_coordinates_sequences(
         self, image, points, affine, invert_affine, affine_lps_to_ras, refer_keys, expected_output
     ):
-        data = {
-            "image_1": image[0],
-            "image_2": image[1],
-            "point_1": points[0],
-            "point_2": points[1],
-        }
+        data = {"image_1": image[0], "image_2": image[1], "point_1": points[0], "point_2": points[1]}
         keys = ["point_1", "point_2"]
         transform = ApplyTransformToPointsd(
             keys=keys,
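The next patch reworks how the inverse transform is rebuilt from the traced metadata; the behaviour it has to preserve is the round trip the tests assert. A minimal sketch, with hypothetical data:

    import torch
    from monai.data import MetaTensor
    from monai.transforms import ApplyTransformToPointsd

    affine = torch.tensor([[1.0, 0, 0, 10], [0, 1.0, 0, -4], [0, 0, 1.0, 0], [0, 0, 0, 1.0]])
    data = {
        "image": MetaTensor(torch.zeros(1, 2, 2, 2), affine=affine),
        "point": torch.tensor([[[2.0, 4.0, 6.0]]], dtype=torch.float64),
    }
    xform = ApplyTransformToPointsd(keys="point", refer_keys="image")
    out = xform(data)              # world -> image space; records extra_info for inversion
    restored = xform.inverse(out)  # image -> world, rebuilt from the traced metadata
    assert torch.allclose(restored["point"], data["point"])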
""" data = convert_to_tensor(data, track_meta=get_track_meta()) - # applied_affine is the affine transformation matrix that has already been applied to the point data - applied_affine = getattr(data, "affine", None) - if affine is None and self.invert_affine: raise ValueError("affine must be provided when invert_affine is True.") - + # applied_affine is the affine transformation matrix that has already been applied to the point data + applied_affine = getattr(data, "affine", None) affine = applied_affine if affine is None else affine - affine = convert_data_type(affine, dtype=torch.float64)[0] # always convert to float64 for affine - original_affine: torch.Tensor = affine - if self.affine_lps_to_ras: - affine = orientation_ras_lps(affine) - # the final affine transformation matrix that will be applied to the point data - _affine: torch.Tensor = affine - if self.invert_affine: - _affine = linalg_inv(affine) - if applied_affine is not None: - # consider the affine transformation already applied to the data in the world space - # and compute delta affine - _affine = _affine @ linalg_inv(applied_affine) - out = apply_affine_to_points(data, _affine, dtype=self.dtype) + final_affine = self._compute_final_affine(data, affine, applied_affine) + out = apply_affine_to_points(data, final_affine, dtype=self.dtype) extra_info = { "invert_affine": self.invert_affine, "dtype": get_dtype_string(self.dtype), - "image_affine": original_affine, # record for inverse operation + "image_affine": affine, "affine_lps_to_ras": self.affine_lps_to_ras, } - xform: torch.Tensor = original_affine if self.invert_affine else linalg_inv(original_affine) + + xform = orientation_ras_lps(linalg_inv(final_affine)) if self.affine_lps_to_ras else linalg_inv(final_affine) meta_info = TraceableTransform.track_transform_meta( data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info() ) @@ -1834,16 +1848,12 @@ def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None): def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) - # Create inverse transform - dtype = transform[TraceKeys.EXTRA_INFO]["dtype"] - invert_affine = not transform[TraceKeys.EXTRA_INFO]["invert_affine"] - affine = transform[TraceKeys.EXTRA_INFO]["image_affine"] - affine_lps_to_ras = transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"] inverse_transform = ApplyTransformToPoints( - dtype=dtype, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras + dtype=transform[TraceKeys.EXTRA_INFO]["dtype"], + invert_affine=not transform[TraceKeys.EXTRA_INFO]["invert_affine"], + affine_lps_to_ras=transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"], ) - # Apply inverse with inverse_transform.trace_transform(False): - data = inverse_transform(data, affine) + data = inverse_transform(data, transform[TraceKeys.EXTRA_INFO]["image_affine"]) return data diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index d50c064c1b..db5f19c0de 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1791,9 +1791,7 @@ def __init__( allow_missing_keys: bool = False, ): MapTransform.__init__(self, keys, allow_missing_keys) - self.refer_keys = ensure_tuple_rep(None, len(self.keys)) if refer_keys is None else ensure_tuple(refer_keys) - if len(self.keys) != len(self.refer_keys): - raise ValueError("refer_keys should have the same length as keys.") + self.refer_keys = ensure_tuple_rep(refer_keys, len(self.keys)) self.converter = 
ApplyTransformToPoints( dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras ) From 850c082c3d267d5b95cdd3123d834c6a8e20ac47 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 4 Sep 2024 14:20:38 +0800 Subject: [PATCH 10/10] fix mypy Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/utility/array.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 43d1b2a755..bfd2f506c2 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1764,9 +1764,7 @@ def __init__( self.invert_affine = invert_affine self.affine_lps_to_ras = affine_lps_to_ras - def _compute_final_affine( - self, data: torch.Tensor, affine: torch.Tensor, applied_affine: torch.Tensor - ) -> torch.Tensor: + def _compute_final_affine(self, affine: torch.Tensor, applied_affine: torch.Tensor | None = None) -> torch.Tensor: """ Compute the final affine transformation matrix to apply to the point data. @@ -1809,10 +1807,12 @@ def transform_coordinates( if affine is None and self.invert_affine: raise ValueError("affine must be provided when invert_affine is True.") # applied_affine is the affine transformation matrix that has already been applied to the point data - applied_affine = getattr(data, "affine", None) + applied_affine: torch.Tensor | None = getattr(data, "affine", None) affine = applied_affine if affine is None else affine + if affine is None: + raise ValueError("affine must be provided if data does not have an affine matrix.") - final_affine = self._compute_final_affine(data, affine, applied_affine) + final_affine = self._compute_final_affine(affine, applied_affine) out = apply_affine_to_points(data, final_affine, dtype=self.dtype) extra_info = {
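Taken together, after the final patch the affine used by the array transform resolves as: the explicit affine argument if given, otherwise the input tensor's own .affine metadata, otherwise an error. A quick sketch of the guarded failure mode (hypothetical data):

    import torch
    from monai.transforms import ApplyTransformToPoints

    points = torch.tensor([[[2.0, 4.0, 6.0]]])  # plain tensor: carries no affine metadata
    xform = ApplyTransformToPoints(invert_affine=True)
    try:
        xform(points)  # no affine argument and no .affine attribute on the input
    except ValueError as err:
        print(err)  # "affine must be provided when invert_affine is True."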