diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py
index 5a7d81ad8e..feb78ec375 100644
--- a/monai/data/meta_tensor.py
+++ b/monai/data/meta_tensor.py
@@ -24,7 +24,7 @@
 from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
 from monai.utils import look_up_option
 from monai.utils.enums import MetaKeys, PostFix, SpaceKeys
-from monai.utils.type_conversion import convert_data_type, convert_to_tensor
+from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor

 __all__ = ["MetaTensor"]

@@ -125,7 +125,7 @@ def __init__(
         super().__init__()
         # set meta
         if meta is not None:
-            self.meta = meta
+            self.meta = dict(meta)
         elif isinstance(x, MetaObj):
             self.__dict__ = deepcopy(x.__dict__)
         # set the affine
@@ -150,6 +150,62 @@ def __init__(
         if MetaKeys.SPACE not in self.meta:
             self.meta[MetaKeys.SPACE] = SpaceKeys.RAS  # defaulting to the right-anterior-superior space
+        if MetaKeys.EVALUATED not in self.meta:
+            self.meta[MetaKeys.EVALUATED] = True
+        if MetaKeys.ORIGINAL_CHANNEL_DIM not in self.meta:
+            self.meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = "no_channel"  # defaulting to no channel dimension
+
+    @property
+    def evaluated(self) -> bool:
+        """a flag indicating whether the array content is up-to-date with the affine/spatial_shape properties."""
+        if MetaKeys.EVALUATED not in self.meta:
+            self.meta[MetaKeys.EVALUATED] = True
+        return bool(self.meta[MetaKeys.EVALUATED])
+
+    @evaluated.setter
+    def evaluated(self, value: bool):
+        """when setting an evaluated MetaTensor to a lazy status, the original affine will be stored."""
+        if not value and (MetaKeys.SPATIAL_SHAPE not in self.meta or MetaKeys.AFFINE not in self.meta):
+            warnings.warn("Setting MetaTensor to lazy evaluation requires spatial_shape and affine.")
+        if self.evaluated and not value:
+            self.meta[MetaKeys.ORIGINAL_AFFINE] = self.affine  # switch to lazy evaluation, store current affine
+            self.meta[MetaKeys.SPATIAL_SHAPE] = self.spatial_shape
+        self.meta[MetaKeys.EVALUATED] = value
+
+    def evaluate(self, mode="bilinear", padding_mode="border"):
+        if self.evaluated:
+            self.spatial_shape = self.array.shape[1:]
+            return
+        # TODO: how to ensure the input is channel-first here?
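+        # note: the resampling below treats the array as channel-first; `SpatialResample`
+        # maps the array from the stored ORIGINAL_AFFINE to the accumulated `self.affine`,
+        # so all pending spatial operations are applied in a single resampling pass.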
+        resampler = monai.transforms.SpatialResample(mode=mode, padding_mode=padding_mode)
+        dst_affine, self.affine = self.affine, self.meta[MetaKeys.ORIGINAL_AFFINE]
+        with resampler.trace_transform(False):
+            output = resampler(self, dst_affine=dst_affine, spatial_size=self.spatial_shape, align_corners=True)
+        self.array = output.array
+        self.spatial_shape = self.array.shape[1:]
+        self.affine = dst_affine
+        self.evaluated = True
+        return
+
+    @property
+    def spatial_shape(self):
+        """if spatial shape is undefined, it infers the shape from array shape and original channel dim."""
+        if MetaKeys.SPATIAL_SHAPE not in self.meta:
+            _shape = list(self.array.shape)
+            channel_dim = self.meta.get(MetaKeys.ORIGINAL_CHANNEL_DIM, 0)
+            if _shape and channel_dim != "no_channel":
+                _shape.pop(int(channel_dim))
+        else:
+            _shape = self.meta.get(MetaKeys.SPATIAL_SHAPE)
+        if not isinstance(_shape, torch.Tensor):
+            self.meta[MetaKeys.SPATIAL_SHAPE] = convert_to_tensor(
+                _shape, device=torch.device("cpu"), wrap_sequence=True, track_meta=False
+            )
+        return self.meta[MetaKeys.SPATIAL_SHAPE]
+
+    @spatial_shape.setter
+    def spatial_shape(self, value):
+        self.meta[MetaKeys.SPATIAL_SHAPE] = convert_to_dst_type(value, self.spatial_shape, wrap_sequence=True)[0]

     @staticmethod
     def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py
index 7b55a993a1..66110fe24d 100644
--- a/monai/transforms/compose.py
+++ b/monai/transforms/compose.py
@@ -22,6 +22,7 @@

 # For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform)
 from monai.transforms.transform import (  # noqa: F401
+    LazyTransform,
     MapTransform,
     Randomizable,
     RandomizableTransform,
@@ -29,11 +30,36 @@
     apply_transform,
 )
 from monai.utils import MAX_SEED, ensure_tuple, get_seed
-from monai.utils.enums import TraceKeys
+from monai.utils.enums import GridSampleMode, GridSamplePadMode, TraceKeys

 __all__ = ["Compose", "OneOf"]


+def eval_lazy_stack(
+    data, upcoming, lazy_resample: bool = False, mode=GridSampleMode.BILINEAR, padding_mode=GridSamplePadMode.BORDER
+):
+    """
+    Given the upcoming transform ``upcoming``, if lazy_resample is True, go through the MetaTensors and
+    evaluate the lazily applied operations. The returned `data` will then be ready for the ``upcoming`` transform.
+    """
+    if not lazy_resample:
+        return data  # eager evaluation
+    if isinstance(data, monai.data.MetaTensor):
+        if lazy_resample and not isinstance(upcoming, LazyTransform):
+            data.evaluate(mode=mode, padding_mode=padding_mode)
+        return data
+    if isinstance(data, Mapping):
+        if isinstance(upcoming, MapTransform):
+            return {
+                k: eval_lazy_stack(v, upcoming, lazy_resample, mode, padding_mode) if k in upcoming.keys else v
+                for k, v in data.items()
+            }
+        return {k: eval_lazy_stack(v, upcoming, lazy_resample, mode, padding_mode) for k, v in data.items()}
+    if isinstance(data, (list, tuple)):
+        return [eval_lazy_stack(v, upcoming, lazy_resample, mode, padding_mode) for v in data]
+    return data
+
+
 class Compose(Randomizable, InvertibleTransform):
     """
     ``Compose`` provides the ability to chain a series of callables together in
@@ -111,6 +137,16 @@ class Compose(Randomizable, InvertibleTransform):
         log_stats: whether to log the detailed information of data and applied transform when error happened,
             for NumPy array and PyTorch Tensor, log the data shape and value range,
             for other metadata, log the values directly. default to `False`.
+        lazy_resample: whether to defer the resampling of consecutive spatial transforms and apply it
+            lazily in a single step. Defaults to False.
+        mode: {``"bilinear"``, ``"nearest"``}
+            Interpolation mode when ``lazy_resample=True``. Defaults to ``"bilinear"``.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+            When `USE_COMPILED` is `True`, this argument uses
+            ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
+            See also: https://docs.monai.io/en/stable/networks.html#grid-pull
+        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
+            Padding mode for outside grid values when ``lazy_resample=True``. Defaults to ``"border"``.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
     """

@@ -120,6 +156,9 @@ def __init__(
         self,
         map_items: bool = True,
         unpack_items: bool = False,
         log_stats: bool = False,
+        lazy_resample: bool = False,
+        mode=GridSampleMode.BILINEAR,
+        padding_mode=GridSamplePadMode.BORDER,
     ) -> None:
         if transforms is None:
             transforms = []
@@ -127,8 +166,16 @@
         self.map_items = map_items
         self.unpack_items = unpack_items
         self.log_stats = log_stats
+        self.lazy_resample = lazy_resample
+        self.mode = mode
+        self.padding_mode = padding_mode
         self.set_random_state(seed=get_seed())

+        if self.lazy_resample:
+            for t in self.flatten().transforms:  # TODO: test Compose of Compose/OneOf
+                if isinstance(t, LazyTransform):
+                    t.set_eager_mode(False)
+
     def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose":
         super().set_random_state(seed=seed, state=state)
         for _transform in self.transforms:
@@ -171,7 +218,17 @@ def __len__(self):

     def __call__(self, input_):
         for _transform in self.transforms:
+            input_ = eval_lazy_stack(
+                input_,
+                upcoming=_transform,
+                lazy_resample=self.lazy_resample,
+                mode=self.mode,
+                padding_mode=self.padding_mode,
+            )
             input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items, self.log_stats)
+        input_ = eval_lazy_stack(
+            input_, upcoming=None, lazy_resample=self.lazy_resample, mode=self.mode, padding_mode=self.padding_mode
+        )
         return input_

     def inverse(self, data):
diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py
index 9a773c4369..1c44d28a83 100644
--- a/monai/transforms/croppad/array.py
+++ b/monai/transforms/croppad/array.py
@@ -27,7 +27,7 @@
 from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import get_random_patch, get_valid_patch_size
 from monai.transforms.inverse import InvertibleTransform, TraceableTransform
-from monai.transforms.transform import Randomizable, Transform
+from monai.transforms.transform import LazyTransform, Randomizable, Transform
 from monai.transforms.utils import (
     compute_divisible_spatial_size,
     convert_pad_mode,
@@ -48,6 +48,7 @@
     TransformBackends,
     convert_data_type,
     convert_to_dst_type,
+    convert_to_numpy,
     convert_to_tensor,
     ensure_tuple,
     ensure_tuple_rep,
@@ -77,7 +78,7 @@
 ]


-class Pad(InvertibleTransform):
+class Pad(InvertibleTransform, LazyTransform):
     """
     Perform padding for a given an amount of padding in each dimension.
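A usage sketch of the `Compose` changes above, assuming `lazy_resample` behaves as documented: consecutive lazy spatial transforms only update metadata, and the trailing `eval_lazy_stack` call triggers a single resampling at the end of the chain. The transforms are real MONAI classes; the sizes and spacings are illustrative.

# a minimal sketch, assuming the `lazy_resample` flag added to Compose above
import torch
from monai.data import MetaTensor
from monai.transforms import Compose, Flip, Orientation, Spacing

pipeline = Compose(
    [Orientation(axcodes="RAS"), Spacing(pixdim=(1.0, 1.0, 1.0)), Flip(spatial_axis=0)],
    lazy_resample=True,  # pending spatial operations are fused and resampled once
)
# mark dim 0 as the channel so `spatial_shape` can be inferred while lazy
img = MetaTensor(torch.rand(1, 32, 32, 16), meta={"original_channel_dim": 0})
out = pipeline(img)   # the final eval_lazy_stack call evaluates the pending ops
print(out.evaluated)  # True: the array is up to date with the metadata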
@@ -137,6 +138,14 @@ def _pt_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor: # torch.pad expects `[B, C, H, W, [D]]` shape return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0) + def lazy_call(self, img: torch.Tensor, to_pad) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + self.update_meta(img, to_pad=to_pad) + self.push_transform(img, orig_size=img.spatial_shape, extra_info={"padded": to_pad}) + img.spatial_shape = [d + s + e for d, (s, e) in zip(img.spatial_shape, to_pad[1:])] + return img + def __call__( # type: ignore self, img: torch.Tensor, to_pad: Optional[List[Tuple[int, int]]] = None, mode: Optional[str] = None, **kwargs ) -> torch.Tensor: @@ -157,19 +166,25 @@ def __call__( # type: ignore """ to_pad_ = self.to_pad if to_pad is None else to_pad if to_pad_ is None: - to_pad_ = self.compute_pad_width(img.shape[1:]) + spatial_shape = convert_to_numpy( + img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:], + wrap_sequence=True, + ) + to_pad_ = self.compute_pad_width(spatial_shape) mode_ = self.mode if mode is None else mode kwargs_ = dict(self.kwargs) kwargs_.update(kwargs) img_t = convert_to_tensor(data=img, track_meta=get_track_meta()) - _orig_size = img_t.shape[1:] + _orig_size = img_t.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img_t.shape[1:] # all zeros, skip padding if np.asarray(to_pad_).any(): to_pad_ = list(to_pad_) if len(to_pad_) < len(img_t.shape): to_pad_ = list(to_pad_) + [(0, 0)] * (len(img_t.shape) - len(to_pad_)) + if not self.eager_mode: + return self.lazy_call(img_t, to_pad=to_pad_) if mode_ in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}: out = self._np_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_) else: @@ -361,7 +376,7 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int return spatial_pad.compute_pad_width(spatial_shape) -class Crop(InvertibleTransform): +class Crop(InvertibleTransform, LazyTransform): """ Perform crop operations on the input image. @@ -421,29 +436,36 @@ def compute_slices( else: return [slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())] + def lazy_call(self, img: torch.Tensor, slices, cropped) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + self.update_meta(img, slices=slices) + self.push_transform(img, orig_size=img.spatial_shape, extra_info={"cropped": cropped}) + img.spatial_shape = [s.indices(o)[1] - s.indices(o)[0] for s, o in zip(slices[1:], img.spatial_shape)] + return img + def __call__(self, img: torch.Tensor, slices: Tuple[slice, ...]) -> torch.Tensor: # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. 
""" - orig_size = img.shape[1:] + orig_size = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] slices_ = list(slices) - sd = len(img.shape[1:]) # spatial dims + sd = len(orig_size) # spatial dims if len(slices_) < sd: slices_ += [slice(None)] * (sd - len(slices_)) # Add in the channel (no cropping) slices = tuple([slice(None)] + slices_[:sd]) - + cropped_np = np.asarray([[s.indices(o)[0], o - s.indices(o)[1]] for s, o in zip(slices[1:], orig_size)]) + cropped = cropped_np.flatten().tolist() img_t: MetaTensor = convert_to_tensor(data=img, track_meta=get_track_meta()) - _orig_size = img_t.shape[1:] + if not self.eager_mode: + return self.lazy_call(img_t, slices, cropped) img_t = img_t[slices] # type: ignore if get_track_meta(): self.update_meta(tensor=img_t, slices=slices) - cropped_from_start = np.asarray([s.indices(o)[0] for s, o in zip(slices[1:], orig_size)]) - cropped_from_end = np.asarray(orig_size) - img_t.shape[1:] - cropped_from_start - cropped = list(chain(*zip(cropped_from_start.tolist(), cropped_from_end.tolist()))) - self.push_transform(img_t, orig_size=_orig_size, extra_info={"cropped": cropped}) + self.push_transform(img_t, orig_size=orig_size, extra_info={"cropped": cropped}) return img_t def update_meta(self, tensor: MetaTensor, slices: Tuple[slice, ...]): @@ -526,6 +548,7 @@ def __init__(self, roi_size: Union[Sequence[int], int]) -> None: self.roi_size = roi_size def compute_slices(self, spatial_size: Sequence[int]): # type: ignore + spatial_size = convert_to_numpy(spatial_size, wrap_sequence=True) roi_size = fall_back_tuple(self.roi_size, spatial_size) roi_center = [i // 2 for i in spatial_size] return super().compute_slices(roi_center=roi_center, roi_size=roi_size) @@ -536,7 +559,12 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore slicing doesn't apply to the channel dim. 
""" - return super().__call__(img=img, slices=self.compute_slices(img.shape[1:])) + return super().__call__( + img=img, + slices=self.compute_slices( + img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] + ), + ) class CenterScaleCrop(Crop): @@ -553,11 +581,16 @@ def __init__(self, roi_scale: Union[Sequence[float], float]): self.roi_scale = roi_scale def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore - img_size = img.shape[1:] + img_size = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] cropper = CenterSpatialCrop(roi_size=roi_size) - return super().__call__(img=img, slices=cropper.compute_slices(img.shape[1:])) + return super().__call__( + img=img, + slices=cropper.compute_slices( + img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] + ), + ) class RandSpatialCrop(Randomizable, Crop): @@ -616,13 +649,18 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: """ if randomize: - self.randomize(img.shape[1:]) + self.randomize(img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:]) if self._size is None: raise RuntimeError("self._size not specified.") if self.random_center: return super().__call__(img=img, slices=self._slices) cropper = CenterSpatialCrop(self._size) - return super().__call__(img=img, slices=cropper.compute_slices(img.shape[1:])) + return super().__call__( + img=img, + slices=cropper.compute_slices( + img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] + ), + ) class RandScaleCrop(RandSpatialCrop): @@ -675,7 +713,7 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: slicing doesn't apply to the channel dim. """ - self.get_max_roi_size(img.shape[1:]) + self.get_max_roi_size(img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:]) return super().__call__(img=img, randomize=randomize) @@ -824,6 +862,10 @@ def __init__( self.k_divisible = k_divisible self.padder = Pad(mode=mode, **pad_kwargs) + def set_eager_mode(self, value): + super().set_eager_mode(True) + self.padder.set_eager_mode(True) + def compute_bounding_box(self, img: torch.Tensor): """ Compute the start points and end points of bounding box to crop. @@ -1264,7 +1306,7 @@ def __call__( return results -class ResizeWithPadOrCrop(InvertibleTransform): +class ResizeWithPadOrCrop(InvertibleTransform, LazyTransform): """ Resize an image to a target spatial size by either centrally cropping the image or padding it evenly with a user-specified mode. @@ -1299,6 +1341,11 @@ def __init__( self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **pad_kwargs) self.cropper = CenterSpatialCrop(roi_size=spatial_size) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.padder.set_eager_mode(value) + self.cropper.set_eager_mode(value) + def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs) -> torch.Tensor: # type: ignore """ Args: @@ -1314,7 +1361,7 @@ def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs) note that `np.pad` treats channel dimension as the first dimension. 
""" - orig_size = img.shape[1:] + orig_size = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] ret = self.padder(self.cropper(img), mode=mode, **pad_kwargs) # remove the individual info and combine if get_track_meta(): diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index bae6705c22..e7a427c83b 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -44,7 +44,7 @@ SpatialPad, ) from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import MapTransform, Randomizable +from monai.transforms.transform import LazyTransform, MapTransform, Randomizable from monai.transforms.utils import is_positive from monai.utils import MAX_SEED, Method, PytorchPadMode, ensure_tuple_rep from monai.utils.deprecate_utils import deprecated_arg @@ -107,7 +107,7 @@ ] -class Padd(MapTransform, InvertibleTransform): +class Padd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Pad`. @@ -141,6 +141,11 @@ def __init__( self.padder = padder self.mode = ensure_tuple_rep(mode, len(self.keys)) + def set_eager_mode(self, value): + super().set_eager_mode(value) + if isinstance(self.padder, LazyTransform): + self.padder.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key, m in self.key_iterator(d, self.mode): @@ -288,7 +293,7 @@ def __init__( super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) -class Cropd(MapTransform, InvertibleTransform): +class Cropd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of abstract class :py:class:`monai.transforms.Crop`. 
@@ -306,6 +311,11 @@ def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.cropper = cropper + def set_eager_mode(self, value): + super().set_eager_mode(value) + if isinstance(self.cropper, LazyTransform): + self.cropper.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): @@ -351,7 +361,12 @@ def randomize(self, img_size: Sequence[int]) -> None: def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) # the first key must exist to execute random operations - self.randomize(d[self.first_key(d)].shape[1:]) + first_item = d[self.first_key(d)] + self.randomize( + first_item.spatial_shape + if isinstance(first_item, MetaTensor) and not first_item.evaluated + else first_item.shape[1:] + ) for key in self.key_iterator(d): kwargs = {"randomize": False} if isinstance(self.cropper, Randomizable) else {} d[key] = self.cropper(d[key], **kwargs) # type: ignore diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index dcddefce3a..07d277c98d 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -12,6 +12,7 @@ A collection of "vanilla" transforms for spatial operations https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ +import math import warnings from copy import deepcopy from enum import Enum @@ -31,7 +32,7 @@ from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCrop from monai.transforms.intensity.array import GaussianSmooth from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import Randomizable, RandomizableTransform, Transform +from monai.transforms.transform import LazyTransform, Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( convert_pad_mode, create_control_grid, @@ -107,7 +108,7 @@ RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] -class SpatialResample(InvertibleTransform): +class SpatialResample(InvertibleTransform, LazyTransform): """ Resample input image from the orientation/spacing defined by ``src_affine`` affine matrix into the ones specified by ``dst_affine`` affine matrix. @@ -185,6 +186,13 @@ def _post_process( def update_meta(self, img, dst_affine): img.affine = dst_affine + def lazy_call(self, img: torch.Tensor, output_shape, *args) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + img.spatial_shape = output_shape + return self._post_process(img, *args) + return img + @deprecated_arg( name="src_affine", since="0.9", msg_suffix="img should be `MetaTensor`, so affine can be extracted directly." ) @@ -261,6 +269,11 @@ def __call__( spatial_size, _ = compute_shape_offset(in_spatial_size, src_affine_, dst_affine) # type: ignore spatial_size = torch.tensor(fall_back_tuple(ensure_tuple(spatial_size)[:spatial_rank], in_spatial_size)) + if not self.eager_mode: + return self.lazy_call( + img, spatial_size, src_affine_, dst_affine, mode, padding_mode, align_corners, original_spatial_shape + ) + if ( allclose(src_affine_, dst_affine, atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size) @@ -420,7 +433,7 @@ def __call__( return img -class Spacing(InvertibleTransform): +class Spacing(InvertibleTransform, LazyTransform): """ Resample input image into the specified `pixdim`. 
""" @@ -510,6 +523,10 @@ def __init__( mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype ) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.sp_resample.set_eager_mode(value) + @deprecated_arg(name="affine", since="0.9", msg_suffix="Not needed, input should be `MetaTensor`.") def __call__( self, @@ -620,7 +637,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return self.sp_resample.inverse(data) -class Orientation(InvertibleTransform): +class Orientation(InvertibleTransform, LazyTransform): """ Change the input image's orientation into the specified based on `axcodes`. """ @@ -661,6 +678,14 @@ def __init__( self.as_closest_canonical = as_closest_canonical self.labels = labels + def lazy_call(self, img: torch.Tensor, new_affine, original_affine, ordering) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + self.update_meta(img, new_affine) + self.push_transform(img, extra_info={"original_affine": original_affine}) + img.spatial_shape = img.spatial_shape[[i - 1 for i in ordering if i != 0]] + return img + def __call__(self, data_array: torch.Tensor) -> torch.Tensor: """ If input type is `MetaTensor`, original affine is extracted with `data_array.affine`. @@ -720,15 +745,16 @@ def __call__(self, data_array: torch.Tensor) -> torch.Tensor: spatial_ornt[:, 0] += 1 # skip channel dim spatial_ornt = np.concatenate([np.array([[0, 1]]), spatial_ornt]) axes = [ax for ax, flip in enumerate(spatial_ornt[:, 1]) if flip == -1] - if axes: - data_array = torch.flip(data_array, dims=axes) full_transpose = np.arange(len(data_array.shape)) full_transpose[: len(spatial_ornt)] = np.argsort(spatial_ornt[:, 0]) - if not np.all(full_transpose == np.arange(len(data_array.shape))): - data_array = data_array.permute(full_transpose.tolist()) - new_affine = to_affine_nd(affine_np, new_affine) new_affine, *_ = convert_data_type(new_affine, torch.Tensor, dtype=torch.float32, device=data_array.device) + if not self.eager_mode: + return self.lazy_call(data_array, new_affine, affine_np, full_transpose) + if axes: + data_array = torch.flip(data_array, dims=axes) + if not np.all(full_transpose == np.arange(len(data_array.shape))): + data_array = data_array.permute(full_transpose.tolist()) if get_track_meta(): self.update_meta(data_array, new_affine) @@ -751,7 +777,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return data -class Flip(InvertibleTransform): +class Flip(InvertibleTransform, LazyTransform): """ Reverses the order of elements along the given spatial axis. Preserves shape. 
See `torch.flip` documentation for additional details: @@ -783,6 +809,15 @@ def update_meta(self, img, shape, axes): def forward_image(self, img, axes) -> torch.Tensor: return torch.flip(img, axes) + def lazy_call(self, img: torch.Tensor, axes) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + spatial_chn_shape = [1, *convert_to_numpy(img.spatial_shape).tolist()] + self.update_meta(img, spatial_chn_shape, axes) + self.push_transform(img) + img.spatial_shape = img.spatial_shape + return img + def __call__(self, img: torch.Tensor) -> torch.Tensor: """ Args: @@ -790,6 +825,8 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) axes = map_spatial_axes(img.ndim, self.spatial_axis) + if not self.eager_mode: + return self.lazy_call(img, axes) out = self.forward_image(img, axes) if get_track_meta(): self.update_meta(out, out.shape, axes) @@ -803,7 +840,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return flipper(data) -class Resize(InvertibleTransform): +class Resize(InvertibleTransform, LazyTransform): """ Resize the input image to given spatial size (with scaling, not cropping/padding). Implemented using :py:class:`torch.nn.functional.interpolate`. @@ -899,20 +936,24 @@ def __call__( "len(spatial_size) must be greater or equal to img spatial dimensions, " f"got spatial_size={output_ndim} img={input_ndim}." ) - spatial_size_ = fall_back_tuple(self.spatial_size, img.shape[1:]) + img_size = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] + spatial_size_ = fall_back_tuple(self.spatial_size, img_size) else: # for the "longest" mode - img_size = img.shape[1:] + img_size = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] if not isinstance(self.spatial_size, int): raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") scale = self.spatial_size / max(img_size) spatial_size_ = tuple(int(round(s * scale)) for s in img_size) - original_sp_size = img.shape[1:] _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode) _align_corners = self.align_corners if align_corners is None else align_corners - if tuple(img.shape[1:]) == spatial_size_: # spatial shape is already the desired - img = convert_to_tensor(img, track_meta=get_track_meta()) - + img = convert_to_tensor(img, track_meta=get_track_meta()) # type: ignore + if not self.eager_mode: + if anti_aliasing: + raise ValueError("anti-aliasing is not compatible with lazy evaluation.") + return self.lazy_call(img, spatial_size_, _mode, _align_corners, input_ndim) + original_sp_size = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] + if tuple(convert_to_numpy(original_sp_size)) == spatial_size_: # spatial shape is already the desired return self._post_process(img, original_sp_size, spatial_size_, _mode, _align_corners, input_ndim) img_ = convert_to_tensor(img, dtype=torch.float, track_meta=False) @@ -929,7 +970,6 @@ def __call__( anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma) img_ = convert_to_tensor(anti_aliasing_filter(img_), track_meta=False) - img = convert_to_tensor(img, track_meta=get_track_meta()) resized = torch.nn.functional.interpolate( input=img_.unsqueeze(0), size=spatial_size_, mode=_mode, align_corners=_align_corners ) @@ -950,6 +990,13 @@ def _post_process(self, img: torch.Tensor, orig_size, sp_size, mode, align_corne ) return img + def 
lazy_call(self, img: torch.Tensor, sp_size, *args) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + img = self._post_process(img, img.spatial_shape, sp_size, *args) + img.spatial_shape = sp_size # type: ignore + return img + def update_meta(self, img, spatial_size, new_spatial_size): affine = convert_to_tensor(img.affine, track_meta=False) img.affine = scale_affine(affine, spatial_size, new_spatial_size) @@ -972,7 +1019,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: return data -class Rotate(InvertibleTransform): +class Rotate(InvertibleTransform, LazyTransform): """ Rotates an input image by given angle using :py:class:`monai.networks.layers.AffineTransform`. @@ -1044,7 +1091,7 @@ def __call__( img = convert_to_tensor(img, track_meta=get_track_meta()) _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) - im_shape = np.asarray(img.shape[1:]) # spatial dimensions + im_shape = np.asarray(img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:]) input_ndim = len(im_shape) if input_ndim not in (2, 3): raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") @@ -1067,6 +1114,8 @@ def __call__( _mode = look_up_option(mode or self.mode, GridSampleMode) _padding_mode = look_up_option(padding_mode or self.padding_mode, GridSamplePadMode) _align_corners = self.align_corners if align_corners is None else align_corners + if not self.eager_mode: + return self.lazy_call(img, output_shape, transform_t, _mode, _padding_mode, _align_corners, _dtype) xform = AffineTransform( normalized=False, mode=_mode, @@ -1076,20 +1125,32 @@ def __call__( ) output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=output_shape).float().squeeze(0) out, *_ = convert_to_dst_type(output, dst=img, dtype=output.dtype) + return self._post_process(out, im_shape, transform_t, _mode, _padding_mode, _align_corners, _dtype) + + def _post_process( + self, img: torch.Tensor, orig_size, transform_t, mode, padding_mode, align_corners, dtype + ) -> torch.Tensor: if get_track_meta(): - self.update_meta(out, transform_t) + self.update_meta(img, transform_t) self.push_transform( - out, - orig_size=img_t.shape[1:], + img, + orig_size=orig_size, extra_info={ - "rot_mat": transform, - "mode": _mode, - "padding_mode": _padding_mode, - "align_corners": _align_corners if _align_corners is not None else TraceKeys.NONE, - "dtype": str(_dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 + "rot_mat": transform_t, + "mode": mode, + "padding_mode": padding_mode, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, + "dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 }, ) - return out + return img + + def lazy_call(self, img: torch.Tensor, output_shape, transform_t, *args) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + img = self._post_process(img, img.spatial_shape, transform_t, *args) + img.spatial_shape = output_shape # type: ignore + return img def update_meta(self, img, rotate_mat): affine = convert_to_tensor(img.affine, track_meta=False) @@ -1125,7 +1186,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: return out -class Zoom(InvertibleTransform): +class Zoom(InvertibleTransform, LazyTransform): """ Zooms an ND image using :py:class:`torch.nn.functional.interpolate`. 
For details, please see https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html. @@ -1208,6 +1269,13 @@ def __call__( _align_corners = self.align_corners if align_corners is None else align_corners _padding_mode = padding_mode or self.padding_mode + if not self.eager_mode and isinstance(img, MetaTensor): + if self.keep_size: + raise NotImplementedError("keep_size=True is currently not compatible with lazy evaluation") + else: + output_size = [int(math.floor(float(i) * z)) for i, z in zip(img.spatial_shape, _zoom)] + return self.lazy_call(img, output_size, _mode, _align_corners) + zoomed: NdarrayOrTensor = torch.nn.functional.interpolate( recompute_scale_factor=True, input=img_t.unsqueeze(0), @@ -1239,6 +1307,23 @@ def __call__( ) return out + def lazy_call(self, img: torch.Tensor, zoom_size, mode, align_corners) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + self.update_meta(img, img.spatial_shape, zoom_size) + self.push_transform( + img, + orig_size=img.spatial_shape, + extra_info={ + "mode": mode, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, + "do_padcrop": False, + "padcrop": {}, + }, + ) + img.spatial_shape = zoom_size # type: ignore + return img + def update_meta(self, img, spatial_size, new_spatial_size): affine = convert_to_tensor(img.affine, track_meta=False) img.affine = scale_affine(affine, spatial_size, new_spatial_size) @@ -1268,7 +1353,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: return out -class Rotate90(InvertibleTransform): +class Rotate90(InvertibleTransform, LazyTransform): """ Rotate an array by 90 degrees in the plane specified by `axes`. See `torch.rot90` for additional details: @@ -1286,7 +1371,7 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: Default: (0, 1), this is the first two axis in spatial dimensions. If axis is negative it counts from the last to the first axis. 
""" - self.k = k + self.k = (4 + (k % 4)) % 4 # 0, 1, 2, 3 spatial_axes_: Tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore if len(spatial_axes_) != 2: raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") @@ -1299,7 +1384,9 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) axes = map_spatial_axes(img.ndim, self.spatial_axes) - ori_shape = img.shape[1:] + ori_shape = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] + if not self.eager_mode: + return self.lazy_call(img, axes, self.k) out: NdarrayOrTensor = torch.rot90(img, self.k, axes) out = convert_to_dst_type(out, img)[0] if get_track_meta(): @@ -1307,6 +1394,19 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: self.push_transform(out, extra_info={"axes": [d - 1 for d in axes], "k": self.k}) # compensate spatial dim return out + def lazy_call(self, img: torch.Tensor, axes, k) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + ori_shape = img.spatial_shape.cpu().tolist() + output_shape = img.spatial_shape.cpu().tolist() + if k in (1, 3): + a_0, a_1 = axes[0] - 1, axes[1] - 1 + output_shape[a_0], output_shape[a_1] = ori_shape[a_1], ori_shape[a_0] + self.update_meta(img, ori_shape, output_shape, axes, k) + self.push_transform(img, extra_info={"axes": [d - 1 for d in axes], "k": k}) + img.spatial_shape = output_shape + return img + def update_meta(self, img, spatial_size, new_spatial_size, axes, k): affine = convert_data_type(img.affine, torch.Tensor)[0] r, sp_r = len(affine) - 1, len(spatial_size) @@ -1337,7 +1437,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: return xform(data) -class RandRotate90(RandomizableTransform, InvertibleTransform): +class RandRotate90(RandomizableTransform, InvertibleTransform, LazyTransform): """ With probability `prob`, input arrays are rotated by 90 degrees in the plane specified by `spatial_axes`. @@ -1376,7 +1476,9 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: self.randomize() if self._do_transform: - out = Rotate90(self._rand_k, self.spatial_axes)(img) + xform = Rotate90(self._rand_k, self.spatial_axes) + xform.set_eager_mode(self.eager_mode) + out = xform(img) else: out = convert_to_tensor(img, track_meta=get_track_meta()) @@ -1393,7 +1495,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return Rotate90().inverse_transform(data, rotate_xform) -class RandRotate(RandomizableTransform, InvertibleTransform): +class RandRotate(RandomizableTransform, InvertibleTransform, LazyTransform): """ Randomly rotate the input arrays. @@ -1503,6 +1605,7 @@ def __call__( align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype or img.dtype, ) + rotator.set_eager_mode(self.eager_mode) out = rotator(img) else: out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32) @@ -1518,7 +1621,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return Rotate(0).inverse_transform(data, xform_info[TraceKeys.EXTRA_INFO]) -class RandFlip(RandomizableTransform, InvertibleTransform): +class RandFlip(RandomizableTransform, InvertibleTransform, LazyTransform): """ Randomly flips the image along axes. Preserves shape. See numpy.flip for additional details. 
@@ -1535,6 +1638,10 @@ def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int] RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.flipper.set_eager_mode(value) + def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: """ Args: @@ -1558,7 +1665,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return self.flipper.inverse(data) -class RandAxisFlip(RandomizableTransform, InvertibleTransform): +class RandAxisFlip(RandomizableTransform, InvertibleTransform, LazyTransform): """ Randomly select a spatial axis and flip along it. See numpy.flip for additional details. @@ -1576,6 +1683,10 @@ def __init__(self, prob: float = 0.1) -> None: self._axis: Optional[int] = None self.flipper = Flip(spatial_axis=self._axis) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.flipper.set_eager_mode(value) + def randomize(self, data: NdarrayOrTensor) -> None: super().randomize(None) if not self._do_transform: @@ -1611,7 +1722,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return flipper(data) -class RandZoom(RandomizableTransform, InvertibleTransform): +class RandZoom(RandomizableTransform, InvertibleTransform, LazyTransform): """ Randomly zooms input arrays with given probability within given zoom range. @@ -1718,14 +1829,16 @@ def __call__( if not self._do_transform: out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32) else: - out = Zoom( + xform = Zoom( self._zoom, keep_size=self.keep_size, mode=look_up_option(mode or self.mode, InterpolateMode), padding_mode=padding_mode or self.padding_mode, align_corners=self.align_corners if align_corners is None else align_corners, **self.kwargs, - )(img) + ) + xform.set_eager_mode(self.eager_mode) + out = xform(img) if get_track_meta(): z_info = self.pop_transform(out, check=False) if self._do_transform else {} self.push_transform(out, extra_info=z_info) @@ -1738,7 +1851,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return Zoom(self._zoom).inverse_transform(data, xform_info[TraceKeys.EXTRA_INFO]) -class AffineGrid(Transform): +class AffineGrid(Transform, LazyTransform): """ Affine transforms on the coordinates. @@ -1763,8 +1876,7 @@ class AffineGrid(Transform): If ``None``, use the data type of input data (if `grid` is provided). device: device on which the tensor will be allocated, if a new grid is generated. affine: If applied, ignore the params (`rotate_params`, etc.) and use the - supplied matrix. Should be square with each side = num of image spatial - dimensions + 1. + supplied matrix. Should be square with each side = num of image spatial dimensions + 1. """ @@ -1790,7 +1902,7 @@ def __init__( def __call__( self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: """ The grid can be initialized with a `spatial_size` parameter, or provided directly as `grid`. Therefore, either `spatial_size` or `grid` must be provided. @@ -1804,19 +1916,23 @@ def __call__( ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values. 
""" - if grid is None: # create grid from spatial_size - if spatial_size is None: - raise ValueError("Incompatible values: grid=None and spatial_size=None.") - grid_ = create_grid(spatial_size, device=self.device, backend="torch", dtype=self.dtype) + if self.eager_mode: + if grid is None: # create grid from spatial_size + if spatial_size is None: + raise ValueError("Incompatible values: grid=None and spatial_size=None.") + grid_ = create_grid(spatial_size, device=self.device, backend="torch", dtype=self.dtype) + else: + grid_ = grid + _dtype = self.dtype or grid_.dtype + grid_: torch.Tensor = convert_to_tensor(grid_, dtype=_dtype, track_meta=get_track_meta()) # type: ignore + _device = grid_.device # type: ignore + spatial_dims = len(grid_.shape) - 1 else: - grid_ = grid - _dtype = self.dtype or grid_.dtype - grid_: torch.Tensor = convert_to_tensor(grid_, dtype=_dtype, track_meta=get_track_meta()) # type: ignore + _device = self.device + spatial_dims = len(spatial_size) # type: ignore _b = TransformBackends.TORCH - _device = grid_.device # type: ignore affine: NdarrayOrTensor if self.affine is None: - spatial_dims = len(grid_.shape) - 1 affine = torch.eye(spatial_dims + 1, device=_device) if self.rotate_params: affine = affine @ create_rotate(spatial_dims, self.rotate_params, device=_device, backend=_b) @@ -1828,6 +1944,8 @@ def __call__( affine = affine @ create_scale(spatial_dims, self.scale_params, device=_device, backend=_b) else: affine = self.affine + if not self.eager_mode: + return None, affine # type: ignore affine = to_affine_nd(len(grid_) - 1, affine) affine = convert_to_tensor(affine, device=grid_.device, dtype=grid_.dtype, track_meta=False) # type: ignore @@ -1835,7 +1953,7 @@ def __call__( return grid_, affine # type: ignore -class RandAffineGrid(Randomizable, Transform): +class RandAffineGrid(Randomizable, Transform, LazyTransform): """ Generate randomised affine grid. @@ -1850,6 +1968,7 @@ def __init__( translate_range: RandRange = None, scale_range: RandRange = None, device: Optional[torch.device] = None, + dtype: DtypeLike = np.float32, ) -> None: """ Args: @@ -1876,6 +1995,8 @@ def __init__( the scale factor to translate for every spatial dims. A value of 1.0 is added to the result. This allows 0 to correspond to no change (i.e., a scaling of 1.0). device: device to store the output grid data. + dtype: data type for the grid computation. Defaults to ``np.float32``. + If ``None``, use the data type of input data (if `grid` is provided). See also: - :py:meth:`monai.transforms.utils.create_rotate` @@ -1895,6 +2016,7 @@ def __init__( self.scale_params: Optional[List[float]] = None self.device = device + self.dtype = dtype self.affine: Optional[torch.Tensor] = torch.eye(4, dtype=torch.float64) def _get_rand_param(self, param_range, add_scalar: float = 0.0): @@ -1937,7 +2059,11 @@ def __call__( translate_params=self.translate_params, scale_params=self.scale_params, device=self.device, + dtype=self.dtype, ) + affine_grid.set_eager_mode(self.eager_mode) + if not self.eager_mode: # return the affine only, don't construct the grid + return affine_grid(spatial_size, grid)[1] # type: ignore _grid: torch.Tensor _grid, self.affine = affine_grid(spatial_size, grid) # type: ignore return _grid @@ -2154,7 +2280,7 @@ def __call__( return out_val -class Affine(InvertibleTransform): +class Affine(InvertibleTransform, LazyTransform): """ Transform ``img`` given the affine parameters. 
A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb. @@ -2253,6 +2379,10 @@ def __init__( self.mode = mode self.padding_mode: str = padding_mode + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.affine_grid.set_eager_mode(value) + def __call__( self, img: torch.Tensor, @@ -2311,6 +2441,17 @@ def update_meta(self, img, mat, img_size, sp_size): affine = convert_data_type(img.affine, torch.Tensor)[0] img.affine = Affine.compute_w_affine(affine, mat, img_size, sp_size) + def lazy_call(self, img: torch.Tensor, affine, output_size, mode, padding_mode) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + orig_size = img.spatial_shape + self.update_meta(img, affine, orig_size, output_size) + self.push_transform( + img, orig_size=orig_size, extra_info={"affine": affine, "mode": mode, "padding_mode": padding_mode} + ) + img.spatial_shape = output_size + return img + def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) orig_size = transform[TraceKeys.ORIG_SIZE] @@ -2332,7 +2473,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return out -class RandAffine(RandomizableTransform, InvertibleTransform): +class RandAffine(RandomizableTransform, InvertibleTransform, LazyTransform): """ Random affine transform. A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb. @@ -2425,10 +2566,16 @@ def __init__( self.mode = mode self.padding_mode: str = padding_mode + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.rand_affine_grid.set_eager_mode(value) + def _init_identity_cache(self): """ Create cache of the identity grid if cache_grid=True and spatial_size is known. 
""" + if not self.eager_mode: + return None if self.spatial_size is None: if self.cache_grid: warnings.warn( @@ -2454,6 +2601,8 @@ def get_identity_grid(self, spatial_size: Sequence[int]): Args: spatial_size: non-dynamic spatial size """ + if not self.eager_mode: + return None ndim = len(spatial_size) if spatial_size != fall_back_tuple(spatial_size, [1] * ndim) or spatial_size != fall_back_tuple( spatial_size, [2] * ndim @@ -2520,6 +2669,12 @@ def __call__( _mode = mode if mode is not None else self.mode _padding_mode = padding_mode if padding_mode is not None else self.padding_mode img = convert_to_tensor(img, track_meta=get_track_meta()) + if not self.eager_mode: + if self._do_transform: + affine = self.rand_affine_grid(sp_size, grid=grid, randomize=randomize) + else: + affine = convert_to_dst_type(torch.eye(len(sp_size) + 1), img, dtype=self.rand_affine_grid.dtype)[0] + return self.lazy_call(img, affine, sp_size, _mode, _padding_mode, do_resampling) if not do_resampling: out: torch.Tensor = convert_data_type(img, dtype=torch.float32, device=self.resampler.device)[0] else: @@ -2548,6 +2703,24 @@ def update_meta(self, img, mat, img_size, sp_size): affine = convert_data_type(img.affine, torch.Tensor)[0] img.affine = Affine.compute_w_affine(affine, mat, img_size, sp_size) + def lazy_call(self, img: torch.Tensor, affine, output_size, mode, padding_mode, do_resampling) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + orig_size = img.spatial_shape + self.update_meta(img, affine, orig_size, output_size) + self.push_transform( + img, + orig_size=orig_size, + extra_info={ + "affine": affine, + "mode": mode, + "padding_mode": padding_mode, + "do_resampling": do_resampling, + }, + ) + img.spatial_shape = output_size + return img + def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) # if transform was not performed nothing to do. diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 706e8d7f8b..7928830740 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -50,7 +50,7 @@ SpatialResample, Zoom, ) -from monai.transforms.transform import MapTransform, RandomizableTransform +from monai.transforms.transform import LazyTransform, MapTransform, RandomizableTransform from monai.transforms.utils import create_grid from monai.utils import ( GridSampleMode, @@ -139,7 +139,7 @@ ] -class SpatialResampled(MapTransform, InvertibleTransform): +class SpatialResampled(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialResample`. 
@@ -207,6 +207,10 @@ def __init__( self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.dst_keys = ensure_tuple_rep(dst_keys, len(self.keys)) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.sp_transform.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d: Dict = dict(data) for (key, mode, padding_mode, align_corners, dtype, dst_key) in self.key_iterator( @@ -230,7 +234,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class ResampleToMatchd(MapTransform, InvertibleTransform): +class ResampleToMatchd(MapTransform, InvertibleTransform, LazyTransform): """Dictionary-based wrapper of :py:class:`monai.transforms.ResampleToMatch`.""" backend = ResampleToMatch.backend @@ -282,6 +286,10 @@ def __init__( self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.resampler = ResampleToMatch() + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.resampler.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for (key, mode, padding_mode, align_corners, dtype) in self.key_iterator( @@ -304,7 +312,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class Spacingd(MapTransform, InvertibleTransform): +class Spacingd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Spacing`. @@ -407,6 +415,10 @@ def __init__( self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.scale_extent = ensure_tuple_rep(scale_extent, len(self.keys)) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.spacing_transform.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d: Dict = dict(data) for key, mode, padding_mode, align_corners, dtype, scale_extent in self.key_iterator( @@ -430,7 +442,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd return d -class Orientationd(MapTransform, InvertibleTransform): +class Orientationd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Orientation`. @@ -473,6 +485,10 @@ def __init__( super().__init__(keys, allow_missing_keys) self.ornt_transform = Orientation(axcodes=axcodes, as_closest_canonical=as_closest_canonical, labels=labels) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.ornt_transform.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d: Dict = dict(data) for key in self.key_iterator(d): @@ -486,7 +502,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class Rotate90d(MapTransform, InvertibleTransform): +class Rotate90d(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. 
""" @@ -506,6 +522,10 @@ def __init__( super().__init__(keys, allow_missing_keys) self.rotator = Rotate90(k, spatial_axes) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.rotator.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): @@ -519,7 +539,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform): +class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate90`. With probability `prob`, input arrays are rotated by 90 degrees @@ -567,6 +587,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, t # FIXME: here we didn't use array version `RandRotate90` transform as others, because we need # to be compatible with the random status of some previous integration tests rotator = Rotate90(self._rand_k, self.spatial_axes) + rotator.set_eager_mode(self.eager_mode) for key in self.key_iterator(d): d[key] = rotator(d[key]) if self._do_transform else convert_to_tensor(d[key], track_meta=get_track_meta()) if get_track_meta(): @@ -585,7 +606,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class Resized(MapTransform, InvertibleTransform): +class Resized(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Resize`. @@ -641,6 +662,10 @@ def __init__( self.anti_aliasing_sigma = ensure_tuple_rep(anti_aliasing_sigma, len(self.keys)) self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.resizer.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key, mode, align_corners, anti_aliasing, anti_aliasing_sigma in self.key_iterator( @@ -662,7 +687,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class Affined(MapTransform, InvertibleTransform): +class Affined(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Affine`. """ @@ -751,6 +776,10 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.affine.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): @@ -764,7 +793,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform): +class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`. 
""" @@ -859,6 +888,10 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.rand_affine.set_eager_mode(value) + def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None ) -> "RandAffined": @@ -877,8 +910,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N # all the keys share the same random Affine factor self.rand_affine.randomize() - spatial_size = d[first_key].shape[1:] - + item = d[first_key] + spatial_size = item.spatial_shape if isinstance(item, MetaTensor) and not item.evaluated else item.shape[1:] # type: ignore sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size) # change image size or do random transform do_resampling = self._do_transform or (sp_size != ensure_tuple(spatial_size)) @@ -887,7 +920,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if do_resampling: # need to prepare grid grid = self.rand_affine.get_identity_grid(sp_size) if self._do_transform: # add some random factors - grid = self.rand_affine.rand_affine_grid(grid=grid) + grid = self.rand_affine.rand_affine_grid(sp_size, grid=grid) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): # do the transform @@ -1188,7 +1221,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d -class Flipd(MapTransform, InvertibleTransform): +class Flipd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Flip`. @@ -1212,6 +1245,10 @@ def __init__( super().__init__(keys, allow_missing_keys) self.flipper = Flip(spatial_axis=spatial_axis) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.flipper.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): @@ -1225,7 +1262,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform): +class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based version :py:class:`monai.transforms.RandFlip`. @@ -1252,6 +1289,10 @@ def __init__( RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.flipper.set_eager_mode(value) + def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None ) -> "RandFlipd": @@ -1283,7 +1324,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform): +class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based version :py:class:`monai.transforms.RandAxisFlip`. 
@@ -1304,6 +1345,10 @@ def __init__(self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys:
         RandomizableTransform.__init__(self, prob)
         self.flipper = RandAxisFlip(prob=1.0)

+    def set_eager_mode(self, value):
+        super().set_eager_mode(value)
+        self.flipper.set_eager_mode(value)
+
     def set_random_state(
         self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
     ) -> "RandAxisFlipd":
@@ -1342,7 +1387,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch
         return d


-class Rotated(MapTransform, InvertibleTransform):
+class Rotated(MapTransform, InvertibleTransform, LazyTransform):
     """
     Dictionary-based wrapper of :py:class:`monai.transforms.Rotate`.

@@ -1391,6 +1436,10 @@ def __init__(
         self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
         self.dtype = ensure_tuple_rep(dtype, len(self.keys))

+    def set_eager_mode(self, value):
+        super().set_eager_mode(value)
+        self.rotator.set_eager_mode(value)
+
     def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
         d = dict(data)
         for key, mode, padding_mode, align_corners, dtype in self.key_iterator(
@@ -1408,7 +1457,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch
         return d


-class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform):
+class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform):
     """
     Dictionary-based version :py:class:`monai.transforms.RandRotate`
     Randomly rotates the input arrays.
@@ -1467,6 +1516,10 @@ def __init__(
         self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
         self.dtype = ensure_tuple_rep(dtype, len(self.keys))

+    def set_eager_mode(self, value):
+        super().set_eager_mode(value)
+        self.rand_rotate.set_eager_mode(value)
+
     def set_random_state(
         self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
     ) -> "RandRotated":
@@ -1509,7 +1562,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch
         return d


-class Zoomd(MapTransform, InvertibleTransform):
+class Zoomd(MapTransform, InvertibleTransform, LazyTransform):
     """
     Dictionary-based wrapper of :py:class:`monai.transforms.Zoom`.

@@ -1559,6 +1612,10 @@ def __init__(
         self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
         self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs)

+    def set_eager_mode(self, value):
+        super().set_eager_mode(value)
+        self.zoomer.set_eager_mode(value)
+
     def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
         d = dict(data)
         for key, mode, padding_mode, align_corners in self.key_iterator(
@@ -1574,7 +1631,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch
         return d


-class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform):
+class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform):
     """
     Dict-based version :py:class:`monai.transforms.RandZoom`.
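
All of the flip/rotate/zoom wrappers above follow the same recipe, so a chain of them can stay lazy until a single resample at the end. A usage sketch, assuming the array-level transforms defer as intended when ``eager_mode`` is False (illustrative, not a test from this patch):

    import torch
    from monai.data import MetaTensor
    from monai.transforms import Flipd, Zoomd

    pipeline = [Flipd(keys="img", spatial_axis=0), Zoomd(keys="img", zoom=2.0)]
    for t in pipeline:
        t.set_eager_mode(False)  # record metadata only, defer the resampling

    data = {"img": MetaTensor(torch.rand(1, 16, 16, 16))}
    for t in pipeline:
        data = t(data)

    img = data["img"]
    # one fused resample instead of one per transform:
    img.evaluate(mode="bilinear", padding_mode="border")
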
@@ -1635,6 +1692,10 @@ def __init__(
         self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))
         self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))

+    def set_eager_mode(self, value):
+        super().set_eager_mode(value)
+        self.rand_zoom.set_eager_mode(value)
+
     def set_random_state(
         self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
     ) -> "RandZoomd":
diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py
index 21d057f5d3..28e0f6ed75 100644
--- a/monai/transforms/transform.py
+++ b/monai/transforms/transform.py
@@ -26,7 +26,15 @@
 from monai.utils.enums import TransformBackends
 from monai.utils.misc import MONAIEnvVars

-__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"]
+__all__ = [
+    "ThreadUnsafe",
+    "apply_transform",
+    "Randomizable",
+    "RandomizableTransform",
+    "Transform",
+    "MapTransform",
+    "LazyTransform",
+]

 ReturnType = TypeVar("ReturnType")

@@ -131,6 +139,19 @@ class ThreadUnsafe:
     pass


+class LazyTransform:
+    """An interface for transforms that accept lazy metatensors (``metatensor.evaluated`` is False) and can be evaluated lazily."""
+
+    eager_mode = True
+
+    def set_eager_mode(self, value: bool):
+        """
+        When ``eager_mode`` is True, the transform should return the transformed array with up-to-date metadata.
+        When it is False, the transform may return updated metadata without running the actual data array transform.
+        """
+        self.eager_mode = value
+
+
 class Randomizable(ThreadUnsafe):
     """
     An interface for handling random state locally, currently based on a class
diff --git a/monai/utils/enums.py b/monai/utils/enums.py
index 12e82cd378..06ab12672e 100644
--- a/monai/utils/enums.py
+++ b/monai/utils/enums.py
@@ -513,6 +513,7 @@ class MetaKeys(StrEnum):
     SPATIAL_SHAPE = "spatial_shape"  # optional key for the length in each spatial dimension
     SPACE = "space"  # possible values of space type are defined in `SpaceKeys`
     ORIGINAL_CHANNEL_DIM = "original_channel_dim"  # an integer or "no_channel"
+    EVALUATED = "evaluated"  # whether the array is up-to-date with the applied_operations (lazy evaluation)


 class ColorOrder(StrEnum):
diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py
index e671edc0a7..6b872c4ad6 100644
--- a/tests/test_ensure_channel_first.py
+++ b/tests/test_ensure_channel_first.py
@@ -82,8 +82,10 @@ def test_check(self):
         with self.assertRaises(ValueError):  # not MetaTensor
             EnsureChannelFirst(channel_dim=None)(im)

+        test_case = MetaTensor(im)
+        test_case.meta.pop("original_channel_dim")
         with self.assertRaises(ValueError):  # no meta
-            EnsureChannelFirst(channel_dim=None)(MetaTensor(im))
+            EnsureChannelFirst(channel_dim=None)(test_case)

         with self.assertRaises(ValueError):  # no meta channel
             EnsureChannelFirst()(im_nodim)
diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py
index 2f873c2d73..64d604b65b 100644
--- a/tests/test_meta_tensor.py
+++ b/tests/test_meta_tensor.py
@@ -74,8 +74,8 @@ def check_meta(self, a: MetaTensor, b: MetaTensor) -> None:
         aff_a = meta_a.get("affine", None)
         aff_b = meta_b.get("affine", None)
         assert_allclose(aff_a, aff_b)
-        meta_a = {k: v for k, v in meta_a.items() if k != "affine"}
-        meta_b = {k: v for k, v in meta_b.items() if k != "affine"}
+        meta_a = {k: v for k, v in meta_a.items() if k not in ("affine", "original_channel_dim", "evaluated")}
+        meta_b = {k: v for k, v in meta_b.items() if k not in ("affine", "original_channel_dim", "evaluated")}
         self.assertEqual(meta_a, meta_b)

     def check(
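
The ``check_meta`` update above is needed because every ``MetaTensor`` now carries two additional default keys, and comparing them strictly would make unrelated tests brittle. A small sketch of the new defaults (grounded in the ``MetaTensor.__init__`` changes in this patch):

    import torch
    from monai.data import MetaTensor

    m = MetaTensor(torch.zeros(1, 8, 8))
    print(m.meta["evaluated"])             # True: array matches its metadata
    print(m.meta["original_channel_dim"])  # "no_channel" unless a loader sets it
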
@@ -122,7 +122,7 @@ def test_as_tensor(self, device, dtype):
     def test_as_dict(self):
         m, _ = self.get_im()
         m_dict = m.as_dict("im")
-        im, meta = m_dict["im"], m_dict[PostFix.meta("im")]
+        im, meta = m_dict["im"], deepcopy(m_dict[PostFix.meta("im")])
         affine = meta.pop("affine")
         m2 = MetaTensor(im, affine, meta)
         self.check(m2, m, check_ids=False)
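
Taken together, the contract for a lazy-capable transform is: in eager mode, transform the array and keep metadata in sync; in lazy mode, update the metadata and mark the tensor as not evaluated, leaving the resampling to ``MetaTensor.evaluate`` (or ``eval_lazy_stack``). A minimal custom example honoring that contract (``PadToSquare`` is hypothetical and only meant to illustrate the interface):

    import torch
    from monai.data import MetaTensor
    from monai.transforms.transform import LazyTransform, Transform

    class PadToSquare(Transform, LazyTransform):
        """Pad the last spatial dim to match the second-to-last (illustrative)."""

        def __call__(self, img: MetaTensor) -> MetaTensor:
            target = int(max(img.shape[-2:]))
            if self.eager_mode:
                # eager: run the array op now; metadata stays consistent
                pad = (0, target - img.shape[-1], 0, target - img.shape[-2])
                return torch.nn.functional.pad(img, pad)
            # lazy: record the pending shape only and flag the array as stale;
            # MetaTensor.evaluate() reconciles array and metadata on demand
            img.evaluated = False
            shape = [int(s) for s in img.spatial_shape]
            shape[-1] = shape[-2] = target
            img.spatial_shape = shape
            return img
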