diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py
index 5061efc1ce..6aab05dc94 100644
--- a/monai/data/meta_obj.py
+++ b/monai/data/meta_obj.py
@@ -82,6 +82,7 @@ class MetaObj:
     def __init__(self):
         self._meta: dict = MetaObj.get_default_meta()
         self._applied_operations: list = MetaObj.get_default_applied_operations()
+        self._pending_operations: list = MetaObj.get_default_applied_operations()  # the same default as applied_ops
         self._is_batch: bool = False
 
     @staticmethod
@@ -199,6 +200,19 @@ def push_applied_operation(self, t: Any) -> None:
     def pop_applied_operation(self) -> Any:
         return self._applied_operations.pop()
 
+    @property
+    def pending_operations(self) -> list[dict]:
+        """Get the pending operations. Defaults to ``[]``."""
+        if hasattr(self, "_pending_operations"):
+            return self._pending_operations
+        return MetaObj.get_default_applied_operations()  # the same default as applied_ops
+
+    def push_pending_operation(self, t: Any) -> None:
+        self._pending_operations.append(t)
+
+    def pop_pending_operation(self) -> Any:
+        return self._pending_operations.pop()
+
     @property
     def is_batch(self) -> bool:
         """Return whether object is part of batch or not."""
diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py
index 5a7d81ad8e..493aef848b 100644
--- a/monai/data/meta_tensor.py
+++ b/monai/data/meta_tensor.py
@@ -23,8 +23,8 @@
 from monai.data.meta_obj import MetaObj, get_track_meta
 from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
 from monai.utils import look_up_option
-from monai.utils.enums import MetaKeys, PostFix, SpaceKeys
-from monai.utils.type_conversion import convert_data_type, convert_to_tensor
+from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
+from monai.utils.type_conversion import convert_data_type, convert_to_numpy, convert_to_tensor
 
 __all__ = ["MetaTensor"]
 
@@ -445,6 +445,20 @@ def pixdim(self):
             return [affine_to_spacing(a) for a in self.affine]
         return affine_to_spacing(self.affine)
 
+    def peek_pending_shape(self):
+        """Get the currently expected spatial shape as if all the pending operations are executed."""
+        res = None
+        if self.pending_operations:
+            res = self.pending_operations[-1].get(LazyAttr.SHAPE, None)
+        # default to spatial shape (assuming channel-first input)
+        return tuple(convert_to_numpy(self.shape, wrap_sequence=True).tolist()[1:]) if res is None else res
+
+    def peek_pending_affine(self):
+        res = None
+        if self.pending_operations:
+            res = self.pending_operations[-1].get(LazyAttr.AFFINE, None)
+        return self.affine if res is None else res
+
     def new_empty(self, size, dtype=None, device=None, requires_grad=False):
         """
         must be defined for deepcopy to work
diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py
index c5419cb9af..21d3621090 100644
--- a/monai/utils/__init__.py
+++ b/monai/utils/__init__.py
@@ -34,6 +34,7 @@
     InterpolateMode,
     InverseKeys,
     JITMetadataKeys,
+    LazyAttr,
     LossReduction,
     MetaKeys,
     Method,
diff --git a/monai/utils/enums.py b/monai/utils/enums.py
index 79edbd7451..4fd9bea557 100644
--- a/monai/utils/enums.py
+++ b/monai/utils/enums.py
@@ -54,6 +54,7 @@
     "AlgoEnsembleKeys",
     "HoVerNetMode",
     "HoVerNetBranch",
+    "LazyAttr",
 ]
 
 
@@ -616,3 +617,16 @@ class HoVerNetBranch(StrEnum):
     HV = "horizontal_vertical"
     NP = "nucleus_prediction"
     NC = "type_prediction"
+
+
+class LazyAttr(StrEnum):
+    """
+    MetaTensor with pending operations requires some key attributes tracked especially when the primary array
+    is not up-to-date due to lazy evaluation.
+    This class specifies the set of key attributes to be tracked for each MetaTensor.
+    """
+
+    SHAPE = "lazy_shape"  # spatial shape
+    AFFINE = "lazy_affine"
+    PADDING_MODE = "lazy_padding_mode"
+    INTERP_MODE = "lazy_interpolation_mode"
diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py
index b46905f3c1..20d25ef61c 100644
--- a/tests/test_meta_tensor.py
+++ b/tests/test_meta_tensor.py
@@ -495,6 +495,15 @@ def test_construct_with_pre_applied_transforms(self):
         m = MetaTensor(im, applied_operations=data["im"].applied_operations)
         self.assertEqual(len(m.applied_operations), len(tr.transforms))
 
+    def test_pending_ops(self):
+        m, _ = self.get_im()
+        self.assertEqual(m.pending_operations, [])
+        self.assertEqual(m.peek_pending_shape(), (10, 8))
+        self.assertIsInstance(m.peek_pending_affine(), torch.Tensor)
+        m.push_pending_operation({})
+        self.assertEqual(m.peek_pending_shape(), (10, 8))
+        self.assertIsInstance(m.peek_pending_affine(), torch.Tensor)
+
     @parameterized.expand(TESTS)
     def test_multiprocessing(self, device=None, dtype=None):
         """multiprocessing sharing with 'device' and 'dtype'"""