diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py
index e8118ffda0..01fadcfb69 100644
--- a/monai/transforms/spatial/dictionary.py
+++ b/monai/transforms/spatial/dictionary.py
@@ -517,7 +517,8 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No
                 output_spatial_shape=output_shape_k if should_match else None,
                 lazy=lazy_,
             )
-            output_shape_k = d[key].peek_pending_shape() if isinstance(d[key], MetaTensor) else d[key].shape[1:]
+            if output_shape_k is None:
+                output_shape_k = d[key].peek_pending_shape() if isinstance(d[key], MetaTensor) else d[key].shape[1:]
         return d

     def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py
index 78f6ad454b..36986b2706 100644
--- a/tests/test_spacingd.py
+++ b/tests/test_spacingd.py
@@ -83,6 +83,20 @@
             *device,
         )
     )
+    TESTS.append(
+        (
+            "interp sep",
+            {
+                "image": MetaTensor(torch.ones((2, 1, 10)), affine=torch.eye(4)),
+                "seg1": MetaTensor(torch.ones((2, 1, 10)), affine=torch.diag(torch.tensor([2, 2, 2, 1]))),
+                "seg2": MetaTensor(torch.ones((2, 1, 10)), affine=torch.eye(4)),
+            },
+            dict(keys=("image", "seg1", "seg2"), mode=("bilinear", "nearest", "nearest"), pixdim=(1, 1, 1)),
+            (2, 1, 10),
+            torch.as_tensor(np.diag((1, 1, 1, 1))),
+            *device,
+        )
+    )

 TESTS_TORCH = []
 for track_meta in (False, True):