From 6efec46af3a51b952acb2acdcab06e49029ec660 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 29 Aug 2023 23:08:05 +0800 Subject: [PATCH 1/4] fix #6911 Signed-off-by: KumoLiu --- monai/transforms/spatial/dictionary.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index e8118ffda0..e87eecdfc6 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -507,6 +507,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No should_match = np.allclose(_init_shape, d[key].peek_pending_shape()) and np.allclose( _pixdim, d[key].pixdim, atol=1e-3 ) + _pixdim = d[key].pixdim d[key] = self.spacing_transform( data_array=d[key], mode=mode, From 03fb454a0a0ab2bbc558c0efe6123e4f03ba6c41 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 29 Aug 2023 23:17:45 +0800 Subject: [PATCH 2/4] add unittests Signed-off-by: KumoLiu --- tests/test_spacingd.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py index 78f6ad454b..36986b2706 100644 --- a/tests/test_spacingd.py +++ b/tests/test_spacingd.py @@ -83,6 +83,20 @@ *device, ) ) + TESTS.append( + ( + "interp sep", + { + "image": MetaTensor(torch.ones((2, 1, 10)), affine=torch.eye(4)), + "seg1": MetaTensor(torch.ones((2, 1, 10)), affine=torch.diag(torch.tensor([2, 2, 2, 1]))), + "seg2": MetaTensor(torch.ones((2, 1, 10)), affine=torch.eye(4)), + }, + dict(keys=("image", "seg1", "seg2"), mode=("bilinear", "nearest", "nearest"), pixdim=(1, 1, 1)), + (2, 1, 10), + torch.as_tensor(np.diag((1, 1, 1, 1))), + *device, + ) + ) TESTS_TORCH = [] for track_meta in (False, True): From 6fbe1a22c2858d5ec23fe82d7638c81c2611fcaf Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 29 Aug 2023 23:48:06 +0800 Subject: [PATCH 3/4] address comments Signed-off-by: KumoLiu --- monai/transforms/spatial/dictionary.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index e87eecdfc6..8a53e1681f 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -501,13 +501,13 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.scale_extent ): if self.ensure_same_shape and isinstance(d[key], MetaTensor): - if _init_shape is None and _pixdim is None: - _init_shape, _pixdim = d[key].peek_pending_shape(), d[key].pixdim + if _init_shape is None: + _init_shape = d[key].peek_pending_shape() else: should_match = np.allclose(_init_shape, d[key].peek_pending_shape()) and np.allclose( _pixdim, d[key].pixdim, atol=1e-3 ) - _pixdim = d[key].pixdim + _pixdim = d[key].pixdim d[key] = self.spacing_transform( data_array=d[key], mode=mode, From 0dfa12b1998300b2b4bdd76e8a3712bec39f54c0 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 30 Aug 2023 08:19:02 +0800 Subject: [PATCH 4/4] address comments Signed-off-by: KumoLiu --- monai/transforms/spatial/dictionary.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 8a53e1681f..01fadcfb69 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -501,13 +501,12 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.scale_extent ): if self.ensure_same_shape and isinstance(d[key], MetaTensor): - if _init_shape is None: - _init_shape = d[key].peek_pending_shape() + if _init_shape is None and _pixdim is None: + _init_shape, _pixdim = d[key].peek_pending_shape(), d[key].pixdim else: should_match = np.allclose(_init_shape, d[key].peek_pending_shape()) and np.allclose( _pixdim, d[key].pixdim, atol=1e-3 ) - _pixdim = d[key].pixdim d[key] = self.spacing_transform( data_array=d[key], mode=mode, @@ -518,7 +517,8 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No output_spatial_shape=output_shape_k if should_match else None, lazy=lazy_, ) - output_shape_k = d[key].peek_pending_shape() if isinstance(d[key], MetaTensor) else d[key].shape[1:] + if output_shape_k is None: + output_shape_k = d[key].peek_pending_shape() if isinstance(d[key], MetaTensor) else d[key].shape[1:] return d def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: