diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 2aa8fbf8a1..4bb90cb440 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -239,7 +239,7 @@ from .lazy.array import ApplyPending from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict from .lazy.functional import apply_pending -from .lazy.utils import combine_transforms, resample +from .lazy.utils import combine_transforms, resample_image from .meta_utility.dictionary import ( FromMetaTensord, FromMetaTensorD, diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 6b95027832..c71ecedadb 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -11,20 +11,24 @@ from __future__ import annotations -from typing import Any, Mapping, Sequence +from typing import Any, Mapping, Sequence, Tuple + +import copy import torch from monai.apps.utils import get_logger from monai.config import NdarrayOrTensor from monai.data.meta_tensor import MetaTensor +from monai.data.meta_obj import get_track_meta from monai.data.utils import to_affine_nd from monai.transforms.lazy.utils import ( affine_from_pending, combine_transforms, is_compatible_apply_kwargs, kwargs_from_pending, - resample, + resample_image, + resample_points, ) from monai.transforms.traits import LazyTrait from monai.transforms.transform import MapTransform @@ -80,6 +84,63 @@ def _log_applied_info(data: Any, key=None, logger_name: bool | str = False): logger.info(f"Pending transforms applied: {key_str}applied_operations: {len(data.applied_operations)}") +def lazily_apply_op( + tensor, + op, + lazy_evaluation, + track_meta=True +) -> MetaTensor | Tuple[torch.Tensor, dict | None]: + """ + This function is intended for use only by developers of spatial functional transforms that + can be lazily executed. + + This function will immediately apply the op to the given tensor if `lazy_evaluation` is set to + False. Its precise behaviour depends on whether it is passed a Tensor or MetaTensor: + + If passed a Tensor, `lazily_apply_op` returns a tuple of Tensor and operation description: + - if `lazy_evaluation` is False, the transformed tensor and op is returned + - if `lazy_evaluation` is True, the tensor and op is returned + + If passed a MetaTensor, only the tensor itself is returned: + - if `lazy_evaluation` is False, the transformed tensor is returned, with the op added to + the applied operations + - if `lazy_evaluation` is True, the untransformed tensor is returned, with the op added to + the pending operations + + Args: + tensor: the tensor to have the operation lazily applied to + op: the operation description containing the transform and metadata + lazy_evaluation: a boolean flag indicating whether to apply the operation lazily + """ + if isinstance(tensor, MetaTensor): + tensor.push_pending_operation(op) + if lazy_evaluation is False: + response = apply_pending(tensor, track_meta=track_meta) + result, pending = response if isinstance(response, tuple) else (response, None) + # result, pending = apply_transforms(tensor, track_meta=track_meta) + return result + else: + return tensor + else: + if lazy_evaluation is False: + response = apply_pending(tensor, [op], track_meta=track_meta) + result, pending = response if isinstance(response, tuple) else (response, None) + # result, pending = apply_transforms(tensor, [op], track_meta=track_meta) + return (result, op) if get_track_meta() is True else result + else: + return (tensor, op) if get_track_meta() is True else tensor + + +def invert( + data: torch.tensor | MetaTensor, + lazy_evaluation=True +): + metadata = data.applied_operations.pop() + inv_metadata = copy.deepcopy(metadata) + inv_metadata.invert() + return lazily_apply_op(data, inv_metadata, lazy_evaluation, False) + + def apply_pending_transforms( data: NdarrayOrTensor | Sequence[Any | NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor], keys: tuple | None, @@ -279,7 +340,7 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None, # carry out an intermediate resample here due to incompatibility between arguments _cur_kwargs = cur_kwargs.copy() _cur_kwargs.update(override_kwargs) - data = resample(data.to(device), cumulative_xform, _cur_kwargs) + data = resample_image(data.to(device), cumulative_xform, _cur_kwargs) next_matrix = affine_from_pending(p) if next_matrix.shape[0] == 3: @@ -288,7 +349,10 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None, cumulative_xform = combine_transforms(cumulative_xform, next_matrix) cur_kwargs.update(new_kwargs) cur_kwargs.update(override_kwargs) - data = resample(data.to(device), cumulative_xform, cur_kwargs) + if data.kind() == 'pixel': + data = resample_image(data.to(device), cumulative_xform, cur_kwargs) + elif data.kind() == 'point': + data = resample_points(data.to(device), cumulative_xform, cur_kwargs) if isinstance(data, MetaTensor): for p in pending: data.push_applied_operation(p) diff --git a/monai/transforms/lazy/utils.py b/monai/transforms/lazy/utils.py index 359559e319..3c27567a24 100644 --- a/monai/transforms/lazy/utils.py +++ b/monai/transforms/lazy/utils.py @@ -19,65 +19,11 @@ import monai from monai.config import NdarrayOrTensor from monai.data.utils import AFFINE_TOL +from monai.transforms.utils import Affine from monai.transforms.utils_pytorch_numpy_unification import allclose from monai.utils import LazyAttr, convert_to_numpy, convert_to_tensor, look_up_option -__all__ = ["resample", "combine_transforms"] - - -class Affine: - """A class to represent an affine transform matrix.""" - - __slots__ = ("data",) - - def __init__(self, data): - self.data = data - - @staticmethod - def is_affine_shaped(data): - """Check if the data is an affine matrix.""" - if isinstance(data, Affine): - return True - if isinstance(data, DisplacementField): - return False - if not hasattr(data, "shape") or len(data.shape) < 2: - return False - return data.shape[-1] in (3, 4) and data.shape[-1] == data.shape[-2] - - -class DisplacementField: - """A class to represent a dense displacement field.""" - - __slots__ = ("data",) - - def __init__(self, data): - self.data = data - - @staticmethod - def is_ddf_shaped(data): - """Check if the data is a DDF.""" - if isinstance(data, DisplacementField): - return True - if isinstance(data, Affine): - return False - if not hasattr(data, "shape") or len(data.shape) < 3: - return False - return not Affine.is_affine_shaped(data) - - -def combine_transforms(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor: - """Given transforms A and B to be applied to x, return the combined transform (AB), so that A(B(x)) becomes AB(x)""" - if Affine.is_affine_shaped(left) and Affine.is_affine_shaped(right): # linear transforms - left = convert_to_tensor(left.data if isinstance(left, Affine) else left, wrap_sequence=True) - right = convert_to_tensor(right.data if isinstance(right, Affine) else right, wrap_sequence=True) - return torch.matmul(left, right) - if DisplacementField.is_ddf_shaped(left) and DisplacementField.is_ddf_shaped( - right - ): # adds DDFs, do we need metadata if metatensor input? - left = convert_to_tensor(left.data if isinstance(left, DisplacementField) else left, wrap_sequence=True) - right = convert_to_tensor(right.data if isinstance(right, DisplacementField) else right, wrap_sequence=True) - return left + right - raise NotImplementedError +__all__ = ["resample_image", "combine_transforms"] def affine_from_pending(pending_item): @@ -145,7 +91,7 @@ def requires_interp(matrix, atol=AFFINE_TOL): __override_lazy_keywords = {*list(LazyAttr), "atol"} -def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = None): +def resample_image(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = None): """ Resample `data` using the affine transformation defined by ``matrix``. @@ -227,3 +173,8 @@ def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = resampler.lazy = False # resampler is a lazytransform with resampler.trace_transform(False): # don't track this transform in `img` return resampler(img=img, **call_kwargs) + + +def resample_points(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = None): + # Handle all point resampling here + raise NotImplementedError() diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 9d55aa013b..b92327bbd7 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -32,14 +32,17 @@ from monai.networks.utils import meshgrid_ij from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCrop from monai.transforms.inverse import InvertibleTransform +from monai.transforms.lazy import invert from monai.transforms.spatial.functional import ( affine_func, flip, + identity, orientation, resize, rotate, rotate90, spatial_resample, + transform_like, zoom, ) from monai.transforms.traits import MultiSampleTrait @@ -136,7 +139,7 @@ def __init__( padding_mode: str = GridSamplePadMode.BORDER, align_corners: bool = False, dtype: DtypeLike = np.float64, - lazy: bool = False, + lazy_evaluation: bool = False, ): """ Args: @@ -155,10 +158,10 @@ def __init__( dtype: data type for resampling computation. Defaults to ``float64`` for best precision. If ``None``, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation=lazy_evaluation) self.mode = mode self.padding_mode = padding_mode self.align_corners = align_corners @@ -173,7 +176,7 @@ def __call__( padding_mode: str | None = None, align_corners: bool | None = None, dtype: DtypeLike = None, - lazy: bool | None = None, + lazy_evaluation: bool | None = None, ) -> torch.Tensor: """ Args: @@ -205,8 +208,8 @@ def __call__( dtype: data type for resampling computation. Defaults to ``self.dtype`` or ``np.float64`` (for best precision). If ``None``, use the data type of input data. To be compatible with other modules, the output data type is always `float32`. - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. The spatial rank is determined by the smallest among ``img.ndim -1``, ``len(src_affine) - 1``, and ``3``. @@ -219,7 +222,7 @@ def __call__( align_corners = align_corners if align_corners is not None else self.align_corners mode = mode if mode is not None else self.mode padding_mode = padding_mode if padding_mode is not None else self.padding_mode - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation return spatial_resample( img, dst_affine, @@ -228,7 +231,7 @@ def __call__( padding_mode, align_corners, dtype_pt, - lazy=lazy_, + lazy_evaluation=lazy_, transform_info=self.get_transform_info(), ) @@ -267,7 +270,7 @@ def __call__( # type: ignore padding_mode: str | None = None, align_corners: bool | None = None, dtype: DtypeLike = None, - lazy: bool | None = None, + lazy_evaluation: bool | None = None, ) -> torch.Tensor: """ Args: @@ -291,8 +294,8 @@ def __call__( # type: ignore dtype: data type for resampling computation. Defaults to ``self.dtype`` or ``np.float64`` (for best precision). If ``None``, use the data type of input data. To be compatible with other modules, the output data type is always `float32`. - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Raises: @@ -303,7 +306,7 @@ def __call__( # type: ignore if img_dst is None: raise RuntimeError("`img_dst` is missing.") dst_affine = img_dst.peek_pending_affine() if isinstance(img_dst, MetaTensor) else torch.eye(4) - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation img = super().__call__( img=img, dst_affine=dst_affine, @@ -312,7 +315,7 @@ def __call__( # type: ignore padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, - lazy=lazy_, + lazy_evaluation=lazy_, ) if not lazy_: if isinstance(img, MetaTensor): @@ -325,7 +328,7 @@ def __call__( # type: ignore if isinstance(img, MetaTensor) and isinstance(img_dst, MetaTensor): original_fname = img.meta.get(Key.FILENAME_OR_OBJ, "resample_to_match_source") meta_dict = deepcopy(img_dst.meta) - for k in ("affine", "spatial_shape"): # keys that don't copy from img_dst in lazy evaluation + for k in ("affine", "spatial_shape"): # keys that don't copy from img_dst in lazy_evaluation evaluation meta_dict.pop(k, None) img.meta.update(meta_dict) img.meta[Key.FILENAME_OR_OBJ] = original_fname # keep the original name, the others are overwritten @@ -354,7 +357,7 @@ def __init__( recompute_affine: bool = False, min_pixdim: Sequence[float] | float | np.ndarray | None = None, max_pixdim: Sequence[float] | float | np.ndarray | None = None, - lazy: bool = False, + lazy_evaluation: bool = False, ) -> None: """ Args: @@ -407,10 +410,10 @@ def __init__( max_pixdim: maximal input spacing to be resampled. If provided, input image with a smaller spacing than this value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the value of `pixdim`. Default to `None`. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation=lazy_evaluation) self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64) self.min_pixdim = np.array(ensure_tuple(min_pixdim), dtype=np.float64) self.max_pixdim = np.array(ensure_tuple(max_pixdim), dtype=np.float64) @@ -423,13 +426,13 @@ def __init__( raise ValueError(f"min_pixdim {self.min_pixdim} must be positive, smaller than max {self.max_pixdim}.") self.sp_resample = SpatialResample( - mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, lazy=lazy + mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, lazy_evaluation=lazy_evaluation ) - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool) -> None: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: self._lazy = val - self.sp_resample.lazy = val + self.sp_resample.lazy_evaluation = val def __call__( self, @@ -440,7 +443,7 @@ def __call__( dtype: DtypeLike = None, scale_extent: bool | None = None, output_spatial_shape: Sequence[int] | np.ndarray | int | None = None, - lazy: bool | None = None, + lazy_evaluation: bool | None = None, ) -> torch.Tensor: """ Args: @@ -470,8 +473,8 @@ def __call__( output_spatial_shape: specify the shape of the output data_array. This is typically useful for the inverse of `Spacingd` where sometimes we could not compute the exact shape due to the quantization error with the affine. - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Raises: @@ -521,7 +524,7 @@ def __call__( new_affine[:sr, -1] = offset[:sr] actual_shape = list(output_shape) if output_spatial_shape is None else output_spatial_shape - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation data_array = self.sp_resample( data_array, dst_affine=torch.as_tensor(new_affine), @@ -530,7 +533,7 @@ def __call__( padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, - lazy=lazy_, + lazy_evaluation=lazy_, ) if self.recompute_affine and isinstance(data_array, MetaTensor): if lazy_: @@ -558,7 +561,7 @@ def __init__( axcodes: str | None = None, as_closest_canonical: bool = False, labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")), - lazy: bool = False, + lazy_evaluation: bool = False, ) -> None: """ Args: @@ -571,7 +574,7 @@ def __init__( labels: optional, None or sequence of (2,) sequences (2,) sequences are labels for (beginning, end) of output axis. Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False Raises: @@ -580,7 +583,7 @@ def __init__( See Also: `nibabel.orientations.ornt2axcodes`. """ - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation=lazy_evaluation) if axcodes is None and not as_closest_canonical: raise ValueError("Incompatible values: axcodes=None and as_closest_canonical=True.") if axcodes is not None and as_closest_canonical: @@ -589,15 +592,15 @@ def __init__( self.as_closest_canonical = as_closest_canonical self.labels = labels - def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: + def __call__(self, data_array: torch.Tensor, lazy_evaluation: bool | None = None) -> torch.Tensor: """ If input type is `MetaTensor`, original affine is extracted with `data_array.affine`. If input type is `torch.Tensor`, original affine is assumed to be identity. Args: data_array: in shape (num_channels, H[, W, ...]). - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Raises: @@ -643,8 +646,8 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch. f"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D" ) spatial_ornt = nib.orientations.ornt_transform(src, dst) - lazy_ = self.lazy if lazy is None else lazy - return orientation(data_array, affine_np, spatial_ornt, lazy=lazy_, transform_info=self.get_transform_info()) + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation + return orientation(data_array, affine_np, spatial_ornt, lazy_evaluation=lazy_, transform_info=self.get_transform_info()) def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) @@ -659,107 +662,101 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return data -class Flip(InvertibleTransform, LazyTransform): +class Flip(LazyTransform, InvertibleTransform): """ Reverses the order of elements along the given spatial axis. Preserves shape. See `torch.flip` documentation for additional details: https://pytorch.org/docs/stable/generated/torch.flip.html - This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` - for more information. - Args: spatial_axis: spatial axes along which to flip over. Default is None. The default `axis=None` will flip over all of the axes of the input array. If axis is negative it counts from the last to the first axis. If axis is a tuple of ints, flipping is performed on all of the axes specified in the tuple. - lazy: a flag to indicate whether this transform should execute lazily or not. - Defaults to False """ backend = [TransformBackends.TORCH] - def __init__(self, spatial_axis: Sequence[int] | int | None = None, lazy: bool = False) -> None: - LazyTransform.__init__(self, lazy=lazy) + def __init__( + self, + spatial_axis: Sequence[int] | int | None = None, + lazy_evaluation: bool = False, + ) -> None: + LazyTransform.__init__(self, lazy_evaluation) self.spatial_axis = spatial_axis - def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: - """ - Args: - img: channel first array, must have shape: (num_channels, H[, W, ..., ]) - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set - during initialization for this call. Defaults to None. - """ - img = convert_to_tensor(img, track_meta=get_track_meta()) - lazy_ = self.lazy if lazy is None else lazy - return flip(img, self.spatial_axis, lazy=lazy_, transform_info=self.get_transform_info()) # type: ignore + def __call__( + self, + img: NdarrayOrTensor, + spatial_axis: Optional[Union[Sequence[int], int]] = None, + shape_override: Optional[Sequence] = None, + lazy_evaluation: bool | None = None, + ): + spatial_axis_ = spatial_axis or self.spatial_axis + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation - def inverse(self, data: torch.Tensor) -> torch.Tensor: - self.pop_transform(data) - flipper = Flip(spatial_axis=self.spatial_axis) - with flipper.trace_transform(False): - return flipper(data) + img_t = flip(img, spatial_axis_, lazy_evaluation=lazy_) + return img_t -class Resize(InvertibleTransform, LazyTransform): - """ - Resize the input image to given spatial size (with scaling, not cropping/padding). - Implemented using :py:class:`torch.nn.functional.interpolate`. + def inverse(self, data): + return invert(data, self.lazy_evaluation) - This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` - for more information. - Args: - spatial_size: expected shape of spatial dimensions after resize operation. - if some components of the `spatial_size` are non-positive values, the transform will use the - corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted - to `(32, 64)` if the second spatial dimension size of img is `64`. - size_mode: should be "all" or "longest", if "all", will use `spatial_size` for all the spatial dims, - if "longest", rescale the image so that only the longest side is equal to specified `spatial_size`, - which must be an int number in this case, keeping the aspect ratio of the initial image, refer to: - https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/ - #albumentations.augmentations.geometric.resize.LongestMaxSize. - mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - align_corners: This only has an effect when mode is - 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - anti_aliasing: bool - Whether to apply a Gaussian filter to smooth the image prior - to downsampling. It is crucial to filter when downsampling - the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` - anti_aliasing_sigma: {float, tuple of floats}, optional - Standard deviation for Gaussian filtering used when anti-aliasing. - By default, this value is chosen as (s - 1) / 2 where s is the - downsampling factor, where s > 1. For the up-size case, s < 1, no - anti-aliasing is performed prior to rescaling. - dtype: data type for resampling computation. Defaults to ``float32``. - If None, use the data type of input data. - lazy: a flag to indicate whether this transform should execute lazily or not. - Defaults to False +class Resize(LazyTransform, InvertibleTransform): + """ + TODO: update for unified resampling parameters on general resample/Resample """ backend = [TransformBackends.TORCH] def __init__( - self, - spatial_size: Sequence[int] | int, - size_mode: str = "all", - mode: str = InterpolateMode.AREA, - align_corners: bool | None = None, - anti_aliasing: bool = False, - anti_aliasing_sigma: Sequence[float] | float | None = None, - dtype: DtypeLike | torch.dtype = torch.float32, - lazy: bool = False, + self, + spatial_size: Sequence[int] | int, + size_mode: str = "all", + mode: str = InterpolateMode.AREA, + align_corners: bool | None = None, + anti_aliasing: bool = False, + anti_aliasing_sigma: Sequence[float] | float | None = None, + dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32, + lazy_evaluation: Optional[bool] = True, ) -> None: - LazyTransform.__init__(self, lazy=lazy) - self.size_mode = look_up_option(size_mode, ["all", "longest"]) + """ + Resize the input image to given spatial size (with scaling, not cropping/padding). + Implemented using :py:class:`torch.nn.functional.interpolate`. + + Args: + spatial_size: expected shape of spatial dimensions after resize operation. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + size_mode: should be "all" or "longest", if "all", will use `spatial_size` for all the spatial dims, + if "longest", rescale the image so that only the longest side is equal to specified `spatial_size`, + which must be an int number in this case, keeping the aspect ratio of the initial image, refer to: + https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/ + #albumentations.augmentations.geometric.resize.LongestMaxSize. + mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. Defaults to ``"area"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + align_corners: This only has an effect when mode is + 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + anti_aliasing: bool + Whether to apply a Gaussian filter to smooth the image prior + to downsampling. It is crucial to filter when downsampling + the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` + anti_aliasing_sigma: {float, tuple of floats}, optional + Standard deviation for Gaussian filtering used when anti-aliasing. + By default, this value is chosen as (s - 1) / 2 where s is the + downsampling factor, where s > 1. For the up-size case, s < 1, no + anti-aliasing is performed prior to rescaling. + """ + LazyTransform.__init__(self, lazy_evaluation) self.spatial_size = spatial_size - self.mode = mode + self.size_mode = look_up_option(size_mode, ["all", "longest"]) + self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) self.align_corners = align_corners self.anti_aliasing = anti_aliasing self.anti_aliasing_sigma = anti_aliasing_sigma @@ -767,13 +764,13 @@ def __init__( def __call__( self, - img: torch.Tensor, - mode: str | None = None, - align_corners: bool | None = None, - anti_aliasing: bool | None = None, - anti_aliasing_sigma: Sequence[float] | float | None = None, - dtype: DtypeLike | torch.dtype = None, - lazy: bool | None = None, + img: torch.Tensor, + mode: str | None = None, + align_corners: bool | None = None, + anti_aliasing: bool | None = None, + anti_aliasing_sigma: Sequence[float] | float | None = None, + shape_override: Sequence = None, + lazy_evaluation: bool | None = None, ) -> torch.Tensor: """ Args: @@ -794,104 +791,27 @@ def __call__( By default, this value is chosen as (s - 1) / 2 where s is the downsampling factor, where s > 1. For the up-size case, s < 1, no anti-aliasing is performed prior to rescaling. - dtype: data type for resampling computation. Defaults to ``self.dtype``. - If None, use the data type of input data. - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set - during initialization for this call. Defaults to None. + Raises: ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. """ - anti_aliasing = self.anti_aliasing if anti_aliasing is None else anti_aliasing - anti_aliasing_sigma = self.anti_aliasing_sigma if anti_aliasing_sigma is None else anti_aliasing_sigma + mode_ = mode or self.mode + align_corners_ = align_corners or self.align_corners + anti_aliasing_ = anti_aliasing or self.anti_aliasing + anti_aliasing_sigma_ = anti_aliasing_sigma or self.anti_aliasing_sigma + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation - input_ndim = img.ndim - 1 # spatial ndim - if self.size_mode == "all": - output_ndim = len(ensure_tuple(self.spatial_size)) - if output_ndim > input_ndim: - input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) - img = img.reshape(input_shape) - elif output_ndim < input_ndim: - raise ValueError( - "len(spatial_size) must be greater or equal to img spatial dimensions, " - f"got spatial_size={output_ndim} img={input_ndim}." - ) - _sp = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] - sp_size = fall_back_tuple(self.spatial_size, _sp) - else: # for the "longest" mode - img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] - if not isinstance(self.spatial_size, int): - raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") - scale = self.spatial_size / max(img_size) - sp_size = tuple(int(round(s * scale)) for s in img_size) - - _mode = self.mode if mode is None else mode - _align_corners = self.align_corners if align_corners is None else align_corners - _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) - lazy_ = self.lazy if lazy is None else lazy - return resize( # type: ignore - img, - tuple(int(_s) for _s in sp_size), - _mode, - _align_corners, - _dtype, - input_ndim, - anti_aliasing, - anti_aliasing_sigma, - lazy_, - self.get_transform_info(), - ) + img_t = resize(img, self.spatial_size, self.size_mode, mode_, align_corners_, anti_aliasing_, + anti_aliasing_sigma_, self.dtype, lazy_evaluation=lazy_) - def inverse(self, data: torch.Tensor) -> torch.Tensor: - transform = self.pop_transform(data) - return self.inverse_transform(data, transform) + return img_t - def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: - orig_size = transform[TraceKeys.ORIG_SIZE] - mode = transform[TraceKeys.EXTRA_INFO]["mode"] - align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] - dtype = transform[TraceKeys.EXTRA_INFO]["dtype"] - xform = Resize( - spatial_size=orig_size, - mode=mode, - align_corners=None if align_corners == TraceKeys.NONE else align_corners, - dtype=dtype, - ) - with xform.trace_transform(False): - data = xform(data) - for _ in range(transform[TraceKeys.EXTRA_INFO]["new_dim"]): - data = data.squeeze(-1) # remove the additional dims - return data + def inverse(self, data): + return invert(data, self.lazy_evaluation) class Rotate(InvertibleTransform, LazyTransform): - """ - Rotates an input image by given angle using :py:class:`monai.networks.layers.AffineTransform`. - - This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` - for more information. - - Args: - angle: Rotation angle(s) in radians. should a float for 2D, three floats for 3D. - keep_size: If it is True, the output shape is kept the same as the input. - If it is False, the output shape is adapted so that the - input array is contained completely in the output. Default is True. - mode: {``"bilinear"``, ``"nearest"``} - Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} - Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - align_corners: Defaults to False. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - dtype: data type for resampling computation. Defaults to ``float32``. - If None, use the data type of input data. To be compatible with other modules, - the output data type is always ``float32``. - lazy: a flag to indicate whether this transform should execute lazily or not. - Defaults to False - """ - backend = [TransformBackends.TORCH] def __init__( @@ -902,24 +822,45 @@ def __init__( padding_mode: str = GridSamplePadMode.BORDER, align_corners: bool = False, dtype: DtypeLike | torch.dtype = torch.float32, - lazy: bool = False, + lazy_evaluation: bool = False, ) -> None: - LazyTransform.__init__(self, lazy=lazy) + """ + Rotates an input image by given angle using :py:class:`monai.networks.layers.AffineTransform`. + + Args: + angle: Rotation angle(s) in radians. should a float for 2D, three floats for 3D. + keep_size: If it is True, the output shape is kept the same as the input. + If it is False, the output shape is adapted so that the + input array is contained completely in the output. Default is True. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: Defaults to False. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. Defaults to ``float32``. + If None, use the data type of input data. To be compatible with other modules, + the output data type is always ``float32``. + """ + LazyTransform.__init__(self, lazy_evaluation) self.angle = angle self.keep_size = keep_size - self.mode: str = mode - self.padding_mode: str = padding_mode + self.mode: str = look_up_option(mode, GridSampleMode) + self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode) self.align_corners = align_corners self.dtype = dtype def __call__( - self, - img: torch.Tensor, - mode: str | None = None, - padding_mode: str | None = None, - align_corners: bool | None = None, - dtype: DtypeLike | torch.dtype = None, - lazy: bool | None = None, + self, + img: torch.Tensor, + mode: str | None = None, + padding_mode: str | None = None, + align_corners: bool | None = None, + dtype: DtypeLike | torch.dtype = None, + shape_override: Sequence = None, + lazy_evaluation: bool | None = None, ) -> torch.Tensor: """ Args: @@ -937,102 +878,27 @@ def __call__( dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set - during initialization for this call. Defaults to None. Raises: ValueError: When ``img`` spatially is not one of [2D, 3D]. """ - img = convert_to_tensor(img, track_meta=get_track_meta()) - _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) - _mode = mode or self.mode - _padding_mode = padding_mode or self.padding_mode - _align_corners = self.align_corners if align_corners is None else align_corners - im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] - output_shape = im_shape if self.keep_size else None - lazy_ = self.lazy if lazy is None else lazy - return rotate( # type: ignore - img, - self.angle, - output_shape, - _mode, - _padding_mode, - _align_corners, - _dtype, - lazy=lazy_, - transform_info=self.get_transform_info(), - ) + mode_ = mode or self.mode + padding_mode_ = padding_mode or self.padding_mode + align_corners_ = align_corners or self.align_corners + keep_size = self.keep_size + dtype_ = self.dtype + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation - def inverse(self, data: torch.Tensor) -> torch.Tensor: - transform = self.pop_transform(data) - return self.inverse_transform(data, transform) + img_t = rotate(img, self.angle, keep_size, mode_, padding_mode_, align_corners_, dtype_, lazy_evaluation=lazy_) - def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: - fwd_rot_mat = transform[TraceKeys.EXTRA_INFO]["rot_mat"] - mode = transform[TraceKeys.EXTRA_INFO]["mode"] - padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] - dtype = transform[TraceKeys.EXTRA_INFO]["dtype"] - inv_rot_mat = linalg_inv(convert_to_numpy(fwd_rot_mat)) - - _, _m, _p, _ = resolves_modes(mode, padding_mode) - xform = AffineTransform( - normalized=False, - mode=_m, - padding_mode=_p, - align_corners=False if align_corners == TraceKeys.NONE else align_corners, - reverse_indexing=True, - ) - img_t: torch.Tensor = convert_data_type(data, MetaTensor, dtype=dtype)[0] - transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) - sp_size = transform[TraceKeys.ORIG_SIZE] - out: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=sp_size).float().squeeze(0) - out = convert_to_dst_type(out, dst=data, dtype=out.dtype)[0] - if isinstance(out, MetaTensor): - affine = convert_to_tensor(out.peek_pending_affine(), track_meta=False) - mat = to_affine_nd(len(affine) - 1, transform_t) - out.affine @= convert_to_dst_type(mat, affine)[0] - return out + return img_t + def inverse(self, data): + return invert(data, self.lazy_evaluation) -class Zoom(InvertibleTransform, LazyTransform): - """ - Zooms an ND image using :py:class:`torch.nn.functional.interpolate`. - For details, please see https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html. - Different from :py:class:`monai.transforms.resize`, this transform takes scaling factors - as input, and provides an option of preserving the input spatial size. - - This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` - for more information. - - Args: - zoom: The zoom factor along the spatial axes. - If a float, zoom is the same for each spatial axis. - If a sequence, zoom should contain one value for each spatial axis. - mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, - ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. - One of the listed string values or a user supplied function. Defaults to ``"edge"``. - The mode to pad data after zooming. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html - align_corners: This only has an effect when mode is - 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - dtype: data type for resampling computation. Defaults to ``float32``. - If None, use the data type of input data. - keep_size: Should keep original size (padding/slicing if needed), default is True. - lazy: a flag to indicate whether this transform should execute lazily or not. - Defaults to False - kwargs: other arguments for the `np.pad` or `torch.pad` function. - note that `np.pad` treats channel dimension as the first dimension. - """ +class Zoom(LazyTransform, InvertibleTransform): backend = [TransformBackends.TORCH] @@ -1042,28 +908,54 @@ def __init__( mode: str = InterpolateMode.AREA, padding_mode: str = NumpyPadMode.EDGE, align_corners: bool | None = None, - dtype: DtypeLike | torch.dtype = torch.float32, keep_size: bool = True, - lazy: bool = False, + lazy_evaluation: Optional[bool] = True, **kwargs, ) -> None: - LazyTransform.__init__(self, lazy=lazy) + """ + Zooms an ND image using :py:class:`torch.nn.functional.interpolate`. + For details, please see https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html. + + Different from :py:class:`monai.transforms.resize`, this transform takes scaling factors + as input, and provides an option of preserving the input spatial size. + + Args: + zoom: The zoom factor along the spatial axes. + If a float, zoom is the same for each spatial axis. + If a sequence, zoom should contain one value for each spatial axis. + mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. Defaults to ``"area"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"edge"``. + The mode to pad data after zooming. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + align_corners: This only has an effect when mode is + 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + keep_size: Should keep original size (padding/slicing if needed), default is True. + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. + """ + LazyTransform.__init__(self, lazy_evaluation) self.zoom = zoom - self.mode = mode + self.mode: InterpolateMode = InterpolateMode(mode) self.padding_mode = padding_mode self.align_corners = align_corners - self.dtype = dtype self.keep_size = keep_size self.kwargs = kwargs def __call__( - self, - img: torch.Tensor, - mode: str | None = None, - padding_mode: str | None = None, - align_corners: bool | None = None, - dtype: DtypeLike | torch.dtype = None, - lazy: bool | None = None, + self, + img: NdarrayOrTensor, + mode: str | None = None, + padding_mode: str | None = None, + align_corners: bool | None = None, + shape_override: Sequence = None, + lazy_evaluation: bool | None = None, ) -> torch.Tensor: """ Args: @@ -1082,55 +974,19 @@ def __call__( align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - dtype: data type for resampling computation. Defaults to ``self.dtype``. - If None, use the data type of input data. - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set - during initialization for this call. Defaults to None. + """ - img = convert_to_tensor(img, track_meta=get_track_meta()) - _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim - _mode = self.mode if mode is None else mode - _padding_mode = padding_mode or self.padding_mode - _align_corners = self.align_corners if align_corners is None else align_corners - _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) - lazy_ = self.lazy if lazy is None else lazy - return zoom( # type: ignore - img, - _zoom, - self.keep_size, - _mode, - _padding_mode, - _align_corners, - _dtype, - lazy=lazy_, - transform_info=self.get_transform_info(), - ) + mode_ = self.mode if mode is None else mode + padding_mode_ = self.padding_mode if padding_mode is None else padding_mode + align_corners_ = self.align_corners if align_corners is None else align_corners + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation - def inverse(self, data: torch.Tensor) -> torch.Tensor: - transform = self.pop_transform(data) - return self.inverse_transform(data, transform) - - def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: - if transform[TraceKeys.EXTRA_INFO]["do_padcrop"]: - orig_size = transform[TraceKeys.ORIG_SIZE] - pad_or_crop = ResizeWithPadOrCrop(spatial_size=orig_size, mode="edge") - padcrop_xform = transform[TraceKeys.EXTRA_INFO]["padcrop"] - padcrop_xform[TraceKeys.EXTRA_INFO]["pad_info"][TraceKeys.ID] = TraceKeys.NONE - padcrop_xform[TraceKeys.EXTRA_INFO]["crop_info"][TraceKeys.ID] = TraceKeys.NONE - # this uses inverse because spatial_size // 2 in the forward pass of center crop may cause issues - data = pad_or_crop.inverse_transform(data, padcrop_xform) # type: ignore - # Create inverse transform - mode = transform[TraceKeys.EXTRA_INFO]["mode"] - align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] - dtype = transform[TraceKeys.EXTRA_INFO]["dtype"] - inverse_transform = Resize(spatial_size=transform[TraceKeys.ORIG_SIZE]) - # Apply inverse - with inverse_transform.trace_transform(False): - out = inverse_transform( - data, mode=mode, align_corners=None if align_corners == TraceKeys.NONE else align_corners, dtype=dtype - ) - return out + img_t = zoom(img, self.zoom, mode_, padding_mode_, align_corners_, self.keep_size, img.dtype, lazy_evaluation=lazy_) + + return img_t + + def inverse(self, data): + return invert(data, self.lazy_evaluation) class Rotate90(InvertibleTransform, LazyTransform): @@ -1139,68 +995,65 @@ class Rotate90(InvertibleTransform, LazyTransform): See `torch.rot90` for additional details: https://pytorch.org/docs/stable/generated/torch.rot90.html#torch-rot90. - This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` - for more information. """ backend = [TransformBackends.TORCH] - def __init__(self, k: int = 1, spatial_axes: tuple[int, int] = (0, 1), lazy: bool = False) -> None: + def __init__( + self, + k: int = 1, + spatial_axes: tuple[int, int] = (0, 1), + lazy_evaluation: bool = False, + ) -> None: """ Args: k: number of times to rotate by 90 degrees. spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. If axis is negative it counts from the last to the first axis. - lazy: a flag to indicate whether this transform should execute lazily or not. - Defaults to False """ - LazyTransform.__init__(self, lazy=lazy) - self.k = (4 + (k % 4)) % 4 # 0, 1, 2, 3 + LazyTransform.__init__(self, lazy_evaluation) + self.k = k spatial_axes_: tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore if len(spatial_axes_) != 2: - raise ValueError(f"spatial_axes must be 2 numbers to define the plane to rotate, got {spatial_axes_}.") + raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") self.spatial_axes = spatial_axes_ - def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: + def __call__( + self, + img: torch.Tensor, + shape_override: Sequence = None, + lazy_evaluation: bool | None = None, + ) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set - during initialization for this call. Defaults to None. """ - img = convert_to_tensor(img, track_meta=get_track_meta()) - axes = map_spatial_axes(img.ndim, self.spatial_axes) - lazy_ = self.lazy if lazy is None else lazy - return rotate90(img, axes, self.k, lazy=lazy_, transform_info=self.get_transform_info()) # type: ignore + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation + img_t = rotate90(img, self.k, self.spatial_axes, lazy_evaluation=lazy_) - def inverse(self, data: torch.Tensor) -> torch.Tensor: - transform = self.pop_transform(data) - return self.inverse_transform(data, transform) + return img_t - def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: - axes = transform[TraceKeys.EXTRA_INFO]["axes"] - k = transform[TraceKeys.EXTRA_INFO]["k"] - inv_k = 4 - k % 4 - xform = Rotate90(k=inv_k, spatial_axes=axes) - with xform.trace_transform(False): - return xform(data) + def inverse(self, data): + return invert(data, self.lazy_evaluation) class RandRotate90(RandomizableTransform, InvertibleTransform, LazyTransform): """ With probability `prob`, input arrays are rotated by 90 degrees in the plane specified by `spatial_axes`. - - This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` - for more information. """ backend = Rotate90.backend def __init__( - self, prob: float = 0.1, max_k: int = 3, spatial_axes: tuple[int, int] = (0, 1), lazy: bool = False + self, + prob: float = 0.1, + max_k: int = 3, + spatial_axes: tuple[int, int] = (0, 1), + lazy_evaluation: bool = False, + seed: int | None = None, + state: np.random.RandomState | None = None, ) -> None: """ Args: @@ -1209,14 +1062,14 @@ def __init__( max_k: number of rotations will be sampled from `np.random.randint(max_k) + 1`, (Default 3). spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ RandomizableTransform.__init__(self, prob) - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation) + self.max_k = max_k self.spatial_axes = spatial_axes - self._rand_k = 0 def randomize(self, data: Any | None = None) -> None: @@ -1225,35 +1078,32 @@ def randomize(self, data: Any | None = None) -> None: return None self._rand_k = self.R.randint(self.max_k) + 1 - def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor: + def __call__( + self, + img: torch.Tensor, + randomize: bool = True, + shape_override: Sequence = None, + lazy_evaluation: bool | None = None, + ) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), randomize: whether to execute `randomize()` function first, default to True. - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. """ - if randomize: self.randomize() - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation if self._do_transform: - xform = Rotate90(self._rand_k, self.spatial_axes, lazy=lazy_) - out = xform(img) + return rotate90(img, self._rand_k, self.spatial_axes, lazy_evaluation=lazy_) else: - out = convert_to_tensor(img, track_meta=get_track_meta()) + return identity(img, None, None, lazy_evaluation=lazy_) - self.push_transform(out, replace=True, lazy=lazy_) - return out - - def inverse(self, data: torch.Tensor) -> torch.Tensor: - xform_info = self.pop_transform(data) - if not xform_info[TraceKeys.DO_TRANSFORM]: - return data - rotate_xform = xform_info[TraceKeys.EXTRA_INFO] - return Rotate90().inverse_transform(data, rotate_xform) + def inverse(self, data): + return invert(data, self.lazy_evaluation) class RandRotate(RandomizableTransform, InvertibleTransform, LazyTransform): @@ -1285,65 +1135,57 @@ class RandRotate(RandomizableTransform, InvertibleTransform, LazyTransform): dtype: data type for resampling computation. Defaults to ``float32``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ backend = Rotate.backend def __init__( - self, - range_x: tuple[float, float] | float = 0.0, - range_y: tuple[float, float] | float = 0.0, - range_z: tuple[float, float] | float = 0.0, - prob: float = 0.1, - keep_size: bool = True, - mode: str = GridSampleMode.BILINEAR, - padding_mode: str = GridSamplePadMode.BORDER, - align_corners: bool = False, - dtype: DtypeLike | torch.dtype = np.float32, - lazy: bool = False, + self, + range_x: tuple[float, float] | float = 0.0, + range_y: tuple[float, float] | float = 0.0, + range_z: tuple[float, float] | float = 0.0, + prob: float = 0.1, + keep_size: bool = True, + mode: str = GridSampleMode.BILINEAR, + padding_mode: str = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: DtypeLike | torch.dtype = np.float32, + lazy_evaluation: bool = False, + seed: int | None = None, + state: np.random.RandomState | None = None, ) -> None: RandomizableTransform.__init__(self, prob) - LazyTransform.__init__(self, lazy=lazy) - self.range_x = ensure_tuple(range_x) - if len(self.range_x) == 1: - self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]])) - self.range_y = ensure_tuple(range_y) - if len(self.range_y) == 1: - self.range_y = tuple(sorted([-self.range_y[0], self.range_y[0]])) - self.range_z = ensure_tuple(range_z) - if len(self.range_z) == 1: - self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]])) - + LazyTransform.__init__(self, lazy_evaluation) self.keep_size = keep_size - self.mode: str = mode - self.padding_mode: str = padding_mode + self.mode: str = look_up_option(mode, GridSampleMode) + self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode) self.align_corners = align_corners self.dtype = dtype - - self.x = 0.0 - self.y = 0.0 - self.z = 0.0 + self.range_x, self.range_y, self.range_z = range_x, range_y, range_z + self._rand_x, self._rand_y, self._rand_z = 0, 0, 0 def randomize(self, data: Any | None = None) -> None: super().randomize(None) if not self._do_transform: return None - self.x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) - self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) - self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) + self._rand_x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) + self._rand_y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) + self._rand_z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) def __call__( - self, - img: torch.Tensor, - mode: str | None = None, - padding_mode: str | None = None, - align_corners: bool | None = None, - dtype: DtypeLike | torch.dtype = None, - randomize: bool = True, - lazy: bool | None = None, - ): + self, + img: torch.Tensor, + mode: str | None = None, + padding_mode: str | None = None, + align_corners: bool | None = None, + dtype: DtypeLike | torch.dtype = None, + randomize: bool = True, + get_matrix: bool = False, + shape_override: Sequence | None = None, + lazy_evaluation: bool | None = None, + ) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D). @@ -1359,36 +1201,30 @@ def __call__( If None, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. randomize: whether to execute `randomize()` function first, default to True. - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. """ if randomize: self.randomize() - lazy_ = self.lazy if lazy is None else lazy + mode_ = mode or self.mode + padding_mode_ = padding_mode or self.padding_mode + align_corners_ = align_corners or self.align_corners + dtype_ = dtype or self.dtype + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation + + # TODO: ideally, the rotate function should be told that it was called by the RandRotate class for + # pending / applied op descriptions if self._do_transform: ndim = len(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]) - rotator = Rotate( - angle=self.x if ndim == 2 else (self.x, self.y, self.z), - keep_size=self.keep_size, - mode=mode or self.mode, - padding_mode=padding_mode or self.padding_mode, - align_corners=self.align_corners if align_corners is None else align_corners, - dtype=dtype or self.dtype or img.dtype, - lazy=lazy_, - ) - out = rotator(img) + angles=self._rand_x if ndim == 2 else (self._rand_x, self._rand_y, self._rand_z), + return rotate(img, angles, self.keep_size, mode_, padding_mode_, align_corners_, dtype_, lazy_evaluation=lazy_) else: - out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32) - self.push_transform(out, replace=True, lazy=lazy_) - return out + return identity(img, None, None, lazy_evaluation=lazy_) - def inverse(self, data: torch.Tensor) -> torch.Tensor: - xform_info = self.pop_transform(data) - if not xform_info[TraceKeys.DO_TRANSFORM]: - return data - return Rotate(0).inverse_transform(data, xform_info[TraceKeys.EXTRA_INFO]) + def inverse(self, data): + return invert(data, self.lazy_evaluation) class RandFlip(RandomizableTransform, InvertibleTransform, LazyTransform): @@ -1403,45 +1239,51 @@ class RandFlip(RandomizableTransform, InvertibleTransform, LazyTransform): Args: prob: Probability of flipping. spatial_axis: Spatial axes along which to flip over. Default is None. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ backend = Flip.backend - def __init__(self, prob: float = 0.1, spatial_axis: Sequence[int] | int | None = None, lazy: bool = False) -> None: + def __init__( + self, + prob: float = 0.1, + spatial_axis: Sequence[int] | int | None = None, + lazy_evaluation: bool = False, + seed: int | None = None, + state: np.random.RandomState | None = None, + ) -> None: RandomizableTransform.__init__(self, prob) - LazyTransform.__init__(self, lazy=lazy) - self.flipper = Flip(spatial_axis=spatial_axis, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation) + self.spatial_axis = spatial_axis - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool): - self.flipper.lazy = val - self._lazy = val - def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor: + def __call__( + self, + img: torch.Tensor, + randomize: bool = True, + shape_override: Sequence = None, + lazy_evaluation: bool | None = None, + ) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), randomize: whether to execute `randomize()` function first, default to True. - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. """ if randomize: self.randomize(None) - lazy_ = self.lazy if lazy is None else lazy - out = self.flipper(img, lazy=lazy_) if self._do_transform else img - out = convert_to_tensor(out, track_meta=get_track_meta()) - self.push_transform(out, replace=True, lazy=lazy_) - return out + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation - def inverse(self, data: torch.Tensor) -> torch.Tensor: - transform = self.pop_transform(data) - if not transform[TraceKeys.DO_TRANSFORM]: - return data - data.applied_operations.append(transform[TraceKeys.EXTRA_INFO]) # type: ignore - return self.flipper.inverse(data) + if self._do_transform: + return flip(img, self.spatial_axis, lazy_evaluation=lazy_) + else: + return identity(img, None, None, lazy_evaluation=lazy_) + + def inverse(self, data): + return invert(data, self.lazy_evaluation) class RandAxisFlip(RandomizableTransform, InvertibleTransform, LazyTransform): @@ -1455,57 +1297,56 @@ class RandAxisFlip(RandomizableTransform, InvertibleTransform, LazyTransform): Args: prob: Probability of flipping. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ backend = Flip.backend - def __init__(self, prob: float = 0.1, lazy: bool = False) -> None: + def __init__( + self, + prob: float = 0.1, + lazy_evaluation: bool = False, + seed: int | None = None, + state: np.random.RandomState | None = None, + ) -> None: RandomizableTransform.__init__(self, prob) - LazyTransform.__init__(self, lazy=lazy) - self._axis: int | None = None - self.flipper = Flip(spatial_axis=self._axis) + LazyTransform.__init__(self, lazy_evaluation) - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool): - self.flipper.lazy = val - self._lazy = val + self._rand_axis = None def randomize(self, data: NdarrayOrTensor) -> None: super().randomize(None) if not self._do_transform: return None - self._axis = self.R.randint(data.ndim - 1) + self._rand_axis = self.R.randint(data.ndim - 1) - def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor: + def __call__( + self, + img: torch.Tensor, + randomize: bool = True, + shape_override: Sequence = None, + lazy_evaluation: bool | None = None, + ) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]) randomize: whether to execute `randomize()` function first, default to True. - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. """ if randomize: self.randomize(data=img) - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation if self._do_transform: - self.flipper.spatial_axis = self._axis - out = self.flipper(img, lazy=lazy_) + return flip(img, self._rand_axis, lazy_evaluation=lazy_) else: - out = convert_to_tensor(img, track_meta=get_track_meta()) - self.push_transform(out, replace=True, lazy=lazy_) - return out + return identity(img, None, None, lazy_evaluation=lazy_) - def inverse(self, data: torch.Tensor) -> torch.Tensor: - transform = self.pop_transform(data) - if not transform[TraceKeys.DO_TRANSFORM]: - return data - flipper = Flip(spatial_axis=transform[TraceKeys.EXTRA_INFO][TraceKeys.EXTRA_INFO]["axes"]) - with flipper.trace_transform(False): - return flipper(data) + def inverse(self, data): + return invert(data, self.lazy_evaluation) class RandZoom(RandomizableTransform, InvertibleTransform, LazyTransform): @@ -1543,66 +1384,63 @@ class RandZoom(RandomizableTransform, InvertibleTransform, LazyTransform): dtype: data type for resampling computation. Defaults to ``float32``. If None, use the data type of input data. keep_size: Should keep original size (pad if needed), default is True. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. - """ - backend = Zoom.backend - def __init__( - self, - prob: float = 0.1, - min_zoom: Sequence[float] | float = 0.9, - max_zoom: Sequence[float] | float = 1.1, - mode: str = InterpolateMode.AREA, - padding_mode: str = NumpyPadMode.EDGE, - align_corners: bool | None = None, - dtype: DtypeLike | torch.dtype = torch.float32, - keep_size: bool = True, - lazy: bool = False, - **kwargs, + self, + prob: float = 0.1, + min_zoom: Sequence[float] | float = 0.9, + max_zoom: Sequence[float] | float = 1.1, + mode: str = InterpolateMode.AREA, + padding_mode: str = NumpyPadMode.EDGE, + align_corners: bool | None = None, + keep_size: bool = True, + lazy_evaluation: bool = False, + seed: int | None = None, + state: np.random.RandomState | None = None, + **kwargs, ) -> None: RandomizableTransform.__init__(self, prob) - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation) self.min_zoom = ensure_tuple(min_zoom) self.max_zoom = ensure_tuple(max_zoom) if len(self.min_zoom) != len(self.max_zoom): raise ValueError( f"min_zoom and max_zoom must have same length, got {len(self.min_zoom)} and {len(self.max_zoom)}." ) - self.mode = mode + self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) self.padding_mode = padding_mode self.align_corners = align_corners - self.dtype = dtype self.keep_size = keep_size + self.lazy_evaluation = lazy_evaluation self.kwargs = kwargs - - self._zoom: Sequence[float] = [1.0] + self._rand_zoom = [1.0] def randomize(self, img: NdarrayOrTensor) -> None: super().randomize(None) if not self._do_transform: return None - self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)] - if len(self._zoom) == 1: + self._rand_zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)] + if len(self.rand_zoom) == 1: # to keep the spatial shape ratio, use same random zoom factor for all dims - self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1) - elif len(self._zoom) == 2 and img.ndim > 3: + self._rand_zoom = ensure_tuple_rep(self._rand_zoom[0], img.ndim - 1) + elif len(self._rand_zoom) == 2 and img.ndim > 3: # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim - self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1]) + self._rand_zoom = ensure_tuple_rep(self._rand_zoom[0], img.ndim - 2) + ensure_tuple(self._rand_zoom[-1]) def __call__( - self, - img: torch.Tensor, - mode: str | None = None, - padding_mode: str | None = None, - align_corners: bool | None = None, - dtype: DtypeLike | torch.dtype = None, - randomize: bool = True, - lazy: bool | None = None, + self, + img: torch.Tensor, + mode: str | None = None, + padding_mode: str | None = None, + align_corners: bool | None = None, + randomize: bool = True, + shape_override: Sequence = None, + lazy_evaluation: bool | None = None, ) -> torch.Tensor: """ Args: @@ -1623,37 +1461,26 @@ def __call__( dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. randomize: whether to execute `randomize()` function first, default to True. - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. """ # match the spatial image dim if randomize: self.randomize(img=img) - lazy_ = self.lazy if lazy is None else lazy - if not self._do_transform: - out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32) + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation + if self._do_transform: + mode_ = mode or self.mode + padding_mode_ = padding_mode or self.padding_mode + align_corners_ = align_corners or self.align_corners + + return zoom(img, self._rand_zoom, mode_, padding_mode_, align_corners_, self.keep_size, lazy_evaluation=lazy_) else: - xform = Zoom( - self._zoom, - keep_size=self.keep_size, - mode=mode or self.mode, - padding_mode=padding_mode or self.padding_mode, - align_corners=self.align_corners if align_corners is None else align_corners, - dtype=dtype or self.dtype, - lazy=lazy_, - **self.kwargs, - ) - out = xform(img) - self.push_transform(out, replace=True, lazy=lazy_) - return out # type: ignore + return identity(img, None, None, lazy_evaluation=lazy_) - def inverse(self, data: torch.Tensor) -> torch.Tensor: - xform_info = self.pop_transform(data) - if not xform_info[TraceKeys.DO_TRANSFORM]: - return data - return Zoom(self._zoom).inverse_transform(data, xform_info[TraceKeys.EXTRA_INFO]) + def inverse(self, data): + return invert(data, self.lazy_evaluation) class AffineGrid(LazyTransform): @@ -1688,7 +1515,7 @@ class AffineGrid(LazyTransform): affine: If applied, ignore the params (`rotate_params`, etc.) and use the supplied matrix. Should be square with each side = num of image spatial dimensions + 1. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ @@ -1704,9 +1531,9 @@ def __init__( dtype: DtypeLike = np.float32, align_corners: bool = False, affine: NdarrayOrTensor | None = None, - lazy: bool = False, + lazy_evaluation: bool = False, ) -> None: - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation=lazy_evaluation) self.rotate_params = rotate_params self.shear_params = shear_params self.translate_params = translate_params @@ -1718,7 +1545,7 @@ def __init__( self.affine = affine def __call__( - self, spatial_size: Sequence[int] | None = None, grid: torch.Tensor | None = None, lazy: bool | None = None + self, spatial_size: Sequence[int] | None = None, grid: torch.Tensor | None = None, lazy_evaluation: bool | None = None ) -> tuple[torch.Tensor | None, torch.Tensor]: """ The grid can be initialized with a `spatial_size` parameter, or provided directly as `grid`. @@ -1728,14 +1555,14 @@ def __call__( Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Raises: ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values. """ - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation if not lazy_: if grid is None: # create grid from spatial_size if spatial_size is None: @@ -1798,7 +1625,7 @@ def __init__( scale_range: RandRange = None, device: torch.device | None = None, dtype: DtypeLike = np.float32, - lazy: bool = False, + lazy_evaluation: bool = False, ) -> None: """ Args: @@ -1827,7 +1654,7 @@ def __init__( device: device to store the output grid data. dtype: data type for the grid computation. Defaults to ``np.float32``. If ``None``, use the data type of input data (if `grid` is provided). - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False See also: @@ -1837,7 +1664,7 @@ def __init__( - :py:meth:`monai.transforms.utils.create_scale` """ - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation=lazy_evaluation) self.rotate_range = ensure_tuple(rotate_range) self.shear_range = ensure_tuple(shear_range) self.translate_range = ensure_tuple(translate_range) @@ -1874,15 +1701,15 @@ def __call__( spatial_size: Sequence[int] | None = None, grid: NdarrayOrTensor | None = None, randomize: bool = True, - lazy: bool | None = None, + lazy_evaluation: bool | None = None, ) -> torch.Tensor: """ Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. randomize: boolean as to whether the grid parameters governing the grid should be randomized. - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: @@ -1890,7 +1717,7 @@ def __call__( """ if randomize: self.randomize() - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation affine_grid = AffineGrid( rotate_params=self.rotate_params, shear_params=self.shear_params, @@ -1898,7 +1725,7 @@ def __call__( scale_params=self.scale_params, device=self.device, dtype=self.dtype, - lazy=lazy_, + lazy_evaluation=lazy_, ) if lazy_: # return the affine only, don't construct the grid self.affine = affine_grid(spatial_size, grid)[1] # type: ignore @@ -2140,7 +1967,7 @@ def __init__( dtype: DtypeLike = np.float32, align_corners: bool = False, image_only: bool = False, - lazy: bool = False, + lazy_evaluation: bool = False, ) -> None: """ The affine transformations are applied in rotate, shear, translate, scale order. @@ -2195,10 +2022,10 @@ def __init__( align_corners: Defaults to False. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html image_only: if True return only the image volume, otherwise return (image, affine). - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation=lazy_evaluation) self.affine_grid = AffineGrid( rotate_params=rotate_params, shear_params=shear_params, @@ -2208,7 +2035,7 @@ def __init__( dtype=dtype, align_corners=align_corners, device=device, - lazy=lazy, + lazy_evaluation=lazy_evaluation, ) self.image_only = image_only self.norm_coord = not normalized @@ -2217,9 +2044,9 @@ def __init__( self.mode = mode self.padding_mode: str = padding_mode - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool) -> None: - self.affine_grid.lazy = val + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: + self.affine_grid.lazy_evaluation = val self._lazy = val def __call__( @@ -2228,7 +2055,7 @@ def __call__( spatial_size: Sequence[int] | int | None = None, mode: str | int | None = None, padding_mode: str | None = None, - lazy: bool | None = None, + lazy_evaluation: bool | None = None, ) -> torch.Tensor | tuple[torch.Tensor, NdarrayOrTensor]: """ Args: @@ -2250,17 +2077,17 @@ def __call__( When `mode` is an integer, using numpy/cupy backends, this argument accepts {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}. See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. """ img = convert_to_tensor(img, track_meta=get_track_meta()) img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img_size) - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation _mode = mode if mode is not None else self.mode _padding_mode = padding_mode if padding_mode is not None else self.padding_mode - grid, affine = self.affine_grid(spatial_size=sp_size, lazy=lazy_) + grid, affine = self.affine_grid(spatial_size=sp_size, lazy_evaluation=lazy_) return affine_func( # type: ignore img, @@ -2272,7 +2099,7 @@ def __call__( _padding_mode, True, self.image_only, - lazy=lazy_, + lazy_evaluation=lazy_, transform_info=self.get_transform_info(), ) @@ -2334,7 +2161,7 @@ def __init__( padding_mode: str = GridSamplePadMode.REFLECTION, cache_grid: bool = False, device: torch.device | None = None, - lazy: bool = False, + lazy_evaluation: bool = False, ) -> None: """ Args: @@ -2384,7 +2211,7 @@ def __init__( If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. device: device on which the tensor will be allocated. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False See also: @@ -2393,33 +2220,33 @@ def __init__( """ RandomizableTransform.__init__(self, prob) - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation=lazy_evaluation) self.rand_affine_grid = RandAffineGrid( rotate_range=rotate_range, shear_range=shear_range, translate_range=translate_range, scale_range=scale_range, device=device, - lazy=lazy, + lazy_evaluation=lazy_evaluation, ) self.resampler = Resample(device=device) self.spatial_size = spatial_size self.cache_grid = cache_grid - self._cached_grid = self._init_identity_cache(lazy) + self._cached_grid = self._init_identity_cache(lazy_evaluation) self.mode = mode self.padding_mode: str = padding_mode - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool) -> None: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: self._lazy = val - self.rand_affine_grid.lazy = val + self.rand_affine_grid.lazy_evaluation = val - def _init_identity_cache(self, lazy: bool): + def _init_identity_cache(self, lazy_evaluation: bool): """ Create cache of the identity grid if cache_grid=True and spatial_size is known. """ - if lazy: + if lazy_evaluation: return None if self.spatial_size is None: if self.cache_grid: @@ -2439,14 +2266,14 @@ def _init_identity_cache(self, lazy: bool): return None return create_grid(spatial_size=_sp_size, device=self.rand_affine_grid.device, backend="torch") - def get_identity_grid(self, spatial_size: Sequence[int], lazy: bool): + def get_identity_grid(self, spatial_size: Sequence[int], lazy_evaluation: bool): """ Return a cached or new identity grid depends on the availability. Args: spatial_size: non-dynamic spatial size """ - if lazy: + if lazy_evaluation: return None ndim = len(spatial_size) if spatial_size != fall_back_tuple(spatial_size, [1] * ndim) or spatial_size != fall_back_tuple( @@ -2478,7 +2305,7 @@ def __call__( padding_mode: str | None = None, randomize: bool = True, grid=None, - lazy: bool | None = None, + lazy_evaluation: bool | None = None, ) -> torch.Tensor: """ Args: @@ -2502,8 +2329,8 @@ def __call__( See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html randomize: whether to execute `randomize()` function first, default to True. grid: precomputed grid to be used (mainly to accelerate `RandAffined`). - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. """ if randomize: @@ -2515,12 +2342,12 @@ def __call__( do_resampling = self._do_transform or (sp_size != ensure_tuple(ori_size)) _mode = mode if mode is not None else self.mode _padding_mode = padding_mode if padding_mode is not None else self.padding_mode - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation img = convert_to_tensor(img, track_meta=get_track_meta()) if lazy_: if self._do_transform: if grid is None: - self.rand_affine_grid(sp_size, randomize=randomize, lazy=True) + self.rand_affine_grid(sp_size, randomize=randomize, lazy_evaluation=True) affine = self.rand_affine_grid.get_transformation_matrix() else: affine = convert_to_dst_type(torch.eye(len(sp_size) + 1), img, dtype=self.rand_affine_grid.dtype)[0] @@ -2528,7 +2355,7 @@ def __call__( if grid is None: grid = self.get_identity_grid(sp_size, lazy_) if self._do_transform: - grid = self.rand_affine_grid(grid=grid, randomize=randomize, lazy=lazy_) + grid = self.rand_affine_grid(grid=grid, randomize=randomize, lazy_evaluation=lazy_) affine = self.rand_affine_grid.get_transformation_matrix() return affine_func( # type: ignore img, @@ -2540,7 +2367,7 @@ def __call__( _padding_mode, do_resampling, True, - lazy=lazy_, + lazy_evaluation=lazy_, transform_info=self.get_transform_info(), ) @@ -2656,7 +2483,7 @@ def __init__( translate_range=translate_range, scale_range=scale_range, device=device, - lazy=False, + lazy_evaluation=False, ) self.resampler = Resample(device=device) @@ -2824,7 +2651,7 @@ def __init__( translate_range=translate_range, scale_range=scale_range, device=device, - lazy=False, + lazy_evaluation=False, ) self.resampler = Resample(device=device) @@ -3544,3 +3371,24 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: else: return img + + +class TransformLike(InvertibleTransform, LazyTransform): + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, lazy_evaluation): + super().__init__(lazy_evaluation) + + + def __call__(self, data, reference, lazy_evaluation=None): + + lazy_ = self.lazy_evaluation if lazy_evaluation is None else self.lazy_evaluation + data = convert_to_tensor(data, track_meta=get_track_meta()) + + data_t = transform_like(data=data, reference=reference, lazy_evaluation=lazy_) + + return data_t + + def inverse(self, data): + return invert(data, self.lazy_evaluation) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 01fadcfb69..82edd5296d 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -31,6 +31,16 @@ from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop from monai.transforms.inverse import InvertibleTransform +from monai.transforms.lazy import invert +from monai.transforms.spatial.functional import ( + flip, + identity, + resize, + rotate, + rotate90, + transform_like, + zoom, +) from monai.transforms.spatial.array import ( Affine, Flip, @@ -176,7 +186,7 @@ def __init__( dtype: Sequence[DtypeLike] | DtypeLike = np.float64, dst_keys: KeysCollection | None = "dst_affine", allow_missing_keys: bool = False, - lazy: bool = False, + lazy_evaluation: bool = False, ) -> None: """ Args: @@ -204,37 +214,37 @@ def __init__( It also can be a sequence of dtypes, each element corresponds to a key in ``keys``. dst_keys: the key of the corresponding ``dst_affine`` in the metadata dictionary. allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False. """ MapTransform.__init__(self, keys, allow_missing_keys) - LazyTransform.__init__(self, lazy=lazy) - self.sp_transform = SpatialResample(lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation=lazy_evaluation) + self.sp_transform = SpatialResample(lazy_evaluation=lazy_evaluation) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.dst_keys = ensure_tuple_rep(dst_keys, len(self.keys)) - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool) -> None: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: self._lazy = val - self.sp_transform.lazy = val + self.sp_transform.lazy_evaluation = val - def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy_evaluation: bool | None = None) -> dict[Hashable, torch.Tensor]: """ Args: data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified in this dictionary must be tensor like arrays that are channel first and have at most three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: a dictionary containing the transformed data, as well as any other data present in the dictionary """ - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation d: dict = dict(data) for key, mode, padding_mode, align_corners, dtype, dst_key in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.dst_keys @@ -247,7 +257,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, - lazy=lazy_, + lazy_evaluation=lazy_, ) return d @@ -278,7 +288,7 @@ def __init__( align_corners: Sequence[bool] | bool = False, dtype: Sequence[DtypeLike] | DtypeLike = np.float64, allow_missing_keys: bool = False, - lazy: bool = False, + lazy_evaluation: bool = False, ): """ Args: @@ -306,37 +316,37 @@ def __init__( the output data type is always ``float32``. It also can be a sequence of dtypes, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ MapTransform.__init__(self, keys, allow_missing_keys) - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation=lazy_evaluation) self.key_dst = key_dst self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - self.resampler = ResampleToMatch(lazy=lazy) + self.resampler = ResampleToMatch(lazy_evaluation=lazy_evaluation) - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool) -> None: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: self._lazy = val - self.resampler.lazy = val + self.resampler.lazy_evaluation = val - def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy_evaluation: bool | None = None) -> dict[Hashable, torch.Tensor]: """ Args: data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified in this dictionary must be tensor like arrays that are channel first and have at most three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: a dictionary containing the transformed data, as well as any other data present in the dictionary """ - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation d = dict(data) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype @@ -348,7 +358,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, - lazy=lazy_, + lazy_evaluation=lazy_, ) return d @@ -393,7 +403,7 @@ def __init__( max_pixdim: Sequence[float] | float | None = None, ensure_same_shape: bool = True, allow_missing_keys: bool = False, - lazy: bool = False, + lazy_evaluation: bool = False, ) -> None: """ Args: @@ -453,18 +463,18 @@ def __init__( ensure_same_shape: when the inputs have the same spatial shape, and almost the same pixdim, whether to ensure exactly the same output spatial shape. Default to True. allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ MapTransform.__init__(self, keys, allow_missing_keys) - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation=lazy_evaluation) self.spacing_transform = Spacing( pixdim, diagonal=diagonal, recompute_affine=recompute_affine, min_pixdim=min_pixdim, max_pixdim=max_pixdim, - lazy=lazy, + lazy_evaluation=lazy_evaluation, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) @@ -473,19 +483,19 @@ def __init__( self.scale_extent = ensure_tuple_rep(scale_extent, len(self.keys)) self.ensure_same_shape = ensure_same_shape - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool) -> None: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: self._lazy = val - self.spacing_transform.lazy = val + self.spacing_transform.lazy_evaluation = val - def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy_evaluation: bool | None = None) -> dict[Hashable, torch.Tensor]: """ Args: data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified in this dictionary must be tensor like arrays that are channel first and have at most three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: @@ -495,7 +505,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No _init_shape, _pixdim, should_match = None, None, False output_shape_k = None # tracking output shape - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation for key, mode, padding_mode, align_corners, dtype, scale_extent in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.scale_extent @@ -515,7 +525,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No dtype=dtype, scale_extent=scale_extent, output_spatial_shape=output_shape_k if should_match else None, - lazy=lazy_, + lazy_evaluation=lazy_, ) if output_shape_k is None: output_shape_k = d[key].peek_pending_shape() if isinstance(d[key], MetaTensor) else d[key].shape[1:] @@ -549,7 +559,7 @@ def __init__( as_closest_canonical: bool = False, labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")), allow_missing_keys: bool = False, - lazy: bool = False, + lazy_evaluation: bool = False, ) -> None: """ Args: @@ -563,7 +573,7 @@ def __init__( (2,) sequences are labels for (beginning, end) of output axis. Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``. allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False See Also: @@ -571,33 +581,33 @@ def __init__( """ MapTransform.__init__(self, keys, allow_missing_keys) - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation=lazy_evaluation) self.ornt_transform = Orientation( - axcodes=axcodes, as_closest_canonical=as_closest_canonical, labels=labels, lazy=lazy + axcodes=axcodes, as_closest_canonical=as_closest_canonical, labels=labels, lazy_evaluation=lazy_evaluation ) - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool) -> None: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: self._lazy = val - self.ornt_transform.lazy = val + self.ornt_transform.lazy_evaluation = val - def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy_evaluation: bool | None = None) -> dict[Hashable, torch.Tensor]: """ Args: data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified in this dictionary must be tensor like arrays that are channel first and have at most three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: a dictionary containing the transformed data, as well as any other data present in the dictionary """ d: dict = dict(data) - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation for key in self.key_iterator(d): - d[key] = self.ornt_transform(d[key], lazy=lazy_) + d[key] = self.ornt_transform(d[key], lazy_evaluation=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -618,12 +628,12 @@ class Rotate90d(MapTransform, InvertibleTransform, LazyTransform): backend = Rotate90.backend def __init__( - self, - keys: KeysCollection, - k: int = 1, - spatial_axes: tuple[int, int] = (0, 1), - allow_missing_keys: bool = False, - lazy: bool = False, + self, + keys: KeysCollection, + k: int = 1, + spatial_axes: tuple[int, int] = (0, 1), + allow_missing_keys: bool = False, + lazy_evaluation: bool = False, ) -> None: """ Args: @@ -631,42 +641,42 @@ def __init__( spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ MapTransform.__init__(self, keys, allow_missing_keys) - LazyTransform.__init__(self, lazy=lazy) - self.rotator = Rotate90(k, spatial_axes, lazy=lazy) - - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool) -> None: - self._lazy = val - self.rotator.lazy = val + LazyTransform.__init__(self, lazy_evaluation) + self.k = k + self.spatial_axes = spatial_axes - def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + def __call__( + self, + data: Mapping[Hashable, torch.Tensor], + lazy_evaluation: bool | None = None, + ) -> dict[Hashable, torch.Tensor]: """ Args: data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified in this dictionary must be tensor like arrays that are channel first and have at most three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: a dictionary containing the transformed data, as well as any other data present in the dictionary """ - d = dict(data) - lazy_ = self.lazy if lazy is None else lazy - for key in self.key_iterator(d): - d[key] = self.rotator(d[key], lazy=lazy_) - return d + rd = dict(data) + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation + for key in self.key_iterator(rd): + rd[key] = rotate90(rd[key], self.k, self.spatial_axes, lazy_evaluation=self.lazy_evaluation) + return rd - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: - d = dict(data) - for key in self.key_iterator(d): - d[key] = self.rotator.inverse(d[key]) - return d + def inverse(self, data): + rd = dict(data) + for key in self.key_iterator(rd): + rd[key] = invert(rd[key], self.lazy_evaluation) + return rd class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): @@ -682,13 +692,15 @@ class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform, La backend = Rotate90.backend def __init__( - self, - keys: KeysCollection, - prob: float = 0.1, - max_k: int = 3, - spatial_axes: tuple[int, int] = (0, 1), - allow_missing_keys: bool = False, - lazy: bool = False, + self, + keys: KeysCollection, + prob: float = 0.1, + max_k: int = 3, + spatial_axes: tuple[int, int] = (0, 1), + allow_missing_keys: bool = False, + lazy_evaluation: bool = False, + seed: int | None = None, + state: np.random.RandomState | None = None, ) -> None: """ Args: @@ -701,12 +713,12 @@ def __init__( spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation) self.max_k = max_k self.spatial_axes = spatial_axes @@ -718,41 +730,38 @@ def randomize(self, data: Any | None = None) -> None: super().randomize(None) def __call__( - self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None + self, + data: Mapping[Hashable, torch.Tensor], + lazy_evaluation: bool | None = None, ) -> Mapping[Hashable, torch.Tensor]: """ Args: data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified in this dictionary must be tensor like arrays that are channel first and have at most three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: a dictionary containing the transformed data, as well as any other data present in the dictionary """ - self.randomize() - d = dict(data) + rd = dict(data) + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation + k = self.randomizer.sample() + for key in self.key_iterator(rd): + if self._do_transform: + rd[key] = rotate90(rd[key], k, self.spatial_axes, lazy_evaluation=lazy_) + else: + rd[key] = identity(rd[key], None, None, lazy_evaluation=lazy_) - # FIXME: here we didn't use array version `RandRotate90` transform as others, because we need - # to be compatible with the random status of some previous integration tests - lazy_ = self.lazy if lazy is None else lazy - rotator = Rotate90(self._rand_k, self.spatial_axes, lazy=lazy_) - for key in self.key_iterator(d): - d[key] = rotator(d[key]) if self._do_transform else convert_to_tensor(d[key], track_meta=get_track_meta()) - self.push_transform(d[key], replace=True, lazy=lazy_) - return d + return rd - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: - d = dict(data) - for key in self.key_iterator(d): - if not isinstance(d[key], MetaTensor): - continue - xform = self.pop_transform(d[key]) - if xform[TraceKeys.DO_TRANSFORM]: - d[key] = Rotate90().inverse_transform(d[key], xform[TraceKeys.EXTRA_INFO]) - return d + def inverse(self, data): + rd = dict(data) + for key in self.key_iterator(rd): + rd[key] = invert(rd[key], self.lazy_evaluation) + return rd class Resized(MapTransform, InvertibleTransform, LazyTransform): @@ -794,73 +803,66 @@ class Resized(MapTransform, InvertibleTransform, LazyTransform): dtype: data type for resampling computation. Defaults to ``float32``. If None, use the data type of input data. allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ backend = Resize.backend def __init__( - self, - keys: KeysCollection, - spatial_size: Sequence[int] | int, - size_mode: str = "all", - mode: SequenceStr = InterpolateMode.AREA, - align_corners: Sequence[bool | None] | bool | None = None, - anti_aliasing: Sequence[bool] | bool = False, - anti_aliasing_sigma: Sequence[Sequence[float] | float | None] | Sequence[float] | float | None = None, - dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, - allow_missing_keys: bool = False, - lazy: bool = False, + self, + keys: KeysCollection, + spatial_size: Sequence[int] | int, + size_mode: str = "all", + mode: SequenceStr = InterpolateMode.AREA, + align_corners: Sequence[bool | None] | bool | None = None, + anti_aliasing: Sequence[bool] | bool = False, + anti_aliasing_sigma: Sequence[Sequence[float] | float | None] | Sequence[float] | float | None = None, + allow_missing_keys: bool = False, + lazy_evaluation: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation) + + self.spatial_size = spatial_size + self.size_mode = size_mode self.mode = ensure_tuple_rep(mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) - self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.anti_aliasing = ensure_tuple_rep(anti_aliasing, len(self.keys)) self.anti_aliasing_sigma = ensure_tuple_rep(anti_aliasing_sigma, len(self.keys)) - self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode, lazy=lazy) - - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool) -> None: - self._lazy = val - self.resizer.lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + def __call__( + self, + data: Mapping[Hashable, torch.Tensor], + lazy_evaluation: bool | None = None, + ) -> dict[Hashable, torch.Tensor]: """ Args: data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified in this dictionary must be tensor like arrays that are channel first and have at most three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: a dictionary containing the transformed data, as well as any other data present in the dictionary """ - d = dict(data) - lazy_ = self.lazy if lazy is None else lazy - for key, mode, align_corners, anti_aliasing, anti_aliasing_sigma, dtype in self.key_iterator( - d, self.mode, self.align_corners, self.anti_aliasing, self.anti_aliasing_sigma, self.dtype + rd = dict(data) + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation + for key, mode, align_corners, anti_aliasing, anti_aliasing_sigma in self.key_iterator( + rd, self.mode, self.align_corners, self.anti_aliasing, self.anti_aliasing_sigma ): - d[key] = self.resizer( - d[key], - mode=mode, - align_corners=align_corners, - anti_aliasing=anti_aliasing, - anti_aliasing_sigma=anti_aliasing_sigma, - dtype=dtype, - lazy=lazy_, - ) - return d + rd[key] = resize(rd[key], spatial_size=self.spatial_size, size_mode=self.size_mode, mode=mode, + align_corners=align_corners, anti_aliasing=anti_aliasing, + anti_aliasing_sigma=anti_aliasing_sigma, lazy_evaluation=lazy_) + return rd - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: - d = dict(data) - for key in self.key_iterator(d): - d[key] = self.resizer.inverse(d[key]) - return d + def inverse(self, data): + rd = dict(data) + for key in self.key_iterator(rd): + rd[key] = invert(rd[key], self.lazy_evaluation) + return rd class Affined(MapTransform, InvertibleTransform, LazyTransform): @@ -888,7 +890,7 @@ def __init__( dtype: DtypeLike | torch.dtype = np.float32, align_corners: bool = False, allow_missing_keys: bool = False, - lazy: bool = False, + lazy_evaluation: bool = False, ) -> None: """ Args: @@ -939,7 +941,7 @@ def __init__( align_corners: Defaults to False. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False See also: @@ -948,7 +950,7 @@ def __init__( """ MapTransform.__init__(self, keys, allow_missing_keys) - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation=lazy_evaluation) self.affine = Affine( rotate_params=rotate_params, shear_params=shear_params, @@ -959,33 +961,33 @@ def __init__( device=device, dtype=dtype, # type: ignore align_corners=align_corners, - lazy=lazy, + lazy_evaluation=lazy_evaluation, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool) -> None: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: self._lazy = val - self.affine.lazy = val + self.affine.lazy_evaluation = val - def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy_evaluation: bool | None = None) -> dict[Hashable, torch.Tensor]: """ Args: data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified in this dictionary must be tensor like arrays that are channel first and have at most three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: a dictionary containing the transformed data, as well as any other data present in the dictionary """ - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - d[key], _ = self.affine(d[key], mode=mode, padding_mode=padding_mode, lazy=lazy_) + d[key], _ = self.affine(d[key], mode=mode, padding_mode=padding_mode, lazy_evaluation=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1019,7 +1021,7 @@ def __init__( cache_grid: bool = False, device: torch.device | None = None, allow_missing_keys: bool = False, - lazy: bool = False, + lazy_evaluation: bool = False, ) -> None: """ Args: @@ -1073,7 +1075,7 @@ def __init__( accelerate the transform. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False See also: @@ -1083,7 +1085,7 @@ def __init__( """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - LazyTransform.__init__(self, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation=lazy_evaluation) self.rand_affine = RandAffine( prob=1.0, # because probability handled in this class rotate_range=rotate_range, @@ -1093,15 +1095,15 @@ def __init__( spatial_size=spatial_size, cache_grid=cache_grid, device=device, - lazy=lazy, + lazy_evaluation=lazy_evaluation, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool) -> None: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: self._lazy = val - self.rand_affine.lazy = val + self.rand_affine.lazy_evaluation = val def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandAffined: self.rand_affine.set_random_state(seed, state) @@ -1109,15 +1111,15 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState return self def __call__( - self, data: Mapping[Hashable, NdarrayOrTensor], lazy: bool | None = None + self, data: Mapping[Hashable, NdarrayOrTensor], lazy_evaluation: bool | None = None ) -> dict[Hashable, NdarrayOrTensor]: """ Args: data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified in this dictionary must be tensor like arrays that are channel first and have at most three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: @@ -1135,7 +1137,7 @@ def __call__( item = d[first_key] spatial_size = item.peek_pending_shape() if isinstance(item, MetaTensor) else item.shape[1:] - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size) # change image size or do random transform @@ -1143,19 +1145,19 @@ def __call__( # converting affine to tensor because the resampler currently only support torch backend grid = None if do_resampling: # need to prepare grid - grid = self.rand_affine.get_identity_grid(sp_size, lazy=lazy_) + grid = self.rand_affine.get_identity_grid(sp_size, lazy_evaluation=lazy_) if self._do_transform: # add some random factors - grid = self.rand_affine.rand_affine_grid(sp_size, grid=grid, lazy=lazy_) + grid = self.rand_affine.rand_affine_grid(sp_size, grid=grid, lazy_evaluation=lazy_) grid = 0 if grid is None else grid # always provide a grid to self.rand_affine for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): # do the transform if do_resampling: - d[key] = self.rand_affine(d[key], None, mode, padding_mode, True, grid, lazy=lazy_) # type: ignore + d[key] = self.rand_affine(d[key], None, mode, padding_mode, True, grid, lazy_evaluation=lazy_) # type: ignore else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32) self._do_transform = do_resampling # TODO: unify self._do_transform and do_resampling - self.push_transform(d[key], replace=True, lazy=lazy_) + self.push_transform(d[key], replace=True, lazy_evaluation=lazy_) return d def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: @@ -1479,52 +1481,51 @@ class Flipd(MapTransform, InvertibleTransform, LazyTransform): keys: Keys to pick data for transformation. spatial_axis: Spatial axes along which to flip over. Default is None. allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ backend = Flip.backend def __init__( - self, - keys: KeysCollection, - spatial_axis: Sequence[int] | int | None = None, - allow_missing_keys: bool = False, - lazy: bool = False, + self, + keys: KeysCollection, + spatial_axis: Sequence[int] | int | None = None, + allow_missing_keys: bool = False, + lazy_evaluation: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) - LazyTransform.__init__(self, lazy=lazy) - self.flipper = Flip(spatial_axis=spatial_axis) - - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool): - self.flipper.lazy = val - self._lazy = val + LazyTransform.__init__(self, lazy_evaluation) + self.spatial_axis = spatial_axis - def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + def __call__( + self, + data: Mapping[Hashable, torch.Tensor], + lazy_evaluation: bool | None = None, + ) -> dict[Hashable, torch.Tensor]: """ Args: data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified in this dictionary must be tensor like arrays that are channel first and have at most three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: a dictionary containing the transformed data, as well as any other data present in the dictionary """ - d = dict(data) - lazy_ = self.lazy if lazy is None else lazy - for key in self.key_iterator(d): - d[key] = self.flipper(d[key], lazy=lazy_) - return d + rd = dict(data) + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation + for key in self.key_iterator(rd): + rd[key] = flip(rd[key], self.spatial_axis, lazy_evaluation=lazy_) + return rd - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: - d = dict(data) - for key in self.key_iterator(d): - d[key] = self.flipper.inverse(d[key]) - return d + def inverse(self, data): + rd = dict(data) + for key in self.key_iterator(rd): + rd[key] = invert(rd[key], self.lazy_evaluation) + return rd class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): @@ -1542,68 +1543,61 @@ class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTr prob: Probability of flipping. spatial_axis: Spatial axes along which to flip over. Default is None. allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ backend = Flip.backend def __init__( - self, - keys: KeysCollection, - prob: float = 0.1, - spatial_axis: Sequence[int] | int | None = None, - allow_missing_keys: bool = False, - lazy: bool = False, + self, + keys: KeysCollection, + prob: float = 0.1, + spatial_axis: Sequence[int] | int | None = None, + allow_missing_keys: bool = False, + seed: int | None = None, + state: np.random.RandomState | None = None, + lazy_evaluation: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - LazyTransform.__init__(self, lazy=lazy) - self.flipper = Flip(spatial_axis=spatial_axis, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation) - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool): - self.flipper.lazy = val - self._lazy = val + self.spatial_axis = spatial_axis - def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandFlipd: - super().set_random_state(seed, state) - return self - - def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + def __call__( + self, + data: Mapping[Hashable, torch.Tensor], + lazy_evaluation: bool | None = None, + ) -> dict[Hashable, torch.Tensor]: """ Args: data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified in this dictionary must be tensor like arrays that are channel first and have at most three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: a dictionary containing the transformed data, as well as any other data present in the dictionary """ - d = dict(data) - self.randomize(None) - - lazy_ = self.lazy if lazy is None else lazy - for key in self.key_iterator(d): + rd = dict(data) + self.randomize() + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation + for key in self.key_iterator(rd): if self._do_transform: - d[key] = self.flipper(d[key], lazy=lazy_) + rd[key] = flip(rd[key], self.spatial_axis, lazy_evaluation=lazy_) else: - d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) - self.push_transform(d[key], replace=True, lazy=lazy_) - return d + rd[key] = identity(rd[key], None, None, lazy_evaluation=lazy_) - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: - d = dict(data) - for key in self.key_iterator(d): - xform = self.pop_transform(d[key]) - if not xform[TraceKeys.DO_TRANSFORM]: - continue - with self.flipper.trace_transform(False): - d[key] = self.flipper(d[key]) - return d + return rd + + def inverse(self, data): + rd = dict(data) + for key in self.key_iterator(rd): + rd[key] = invert(rd[key], self.lazy_evaluation) + return rd class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): @@ -1620,70 +1614,70 @@ class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform, La keys: Keys to pick data for transformation. prob: Probability of flipping. allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ backend = RandAxisFlip.backend def __init__( - self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: bool = False, lazy: bool = False + self, + keys: KeysCollection, + prob: float = 0.1, + allow_missing_keys: bool = False, + lazy_evaluation: bool = False, + seed: int | None = None, + state: np.random.RandomState | None = None, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - LazyTransform.__init__(self, lazy=lazy) - self.flipper = RandAxisFlip(prob=1.0, lazy=lazy) - - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool): - self.flipper.lazy = val - self._lazy = val + LazyTransform.__init__(self, lazy_evaluation) + self._rand_axis = 0 - def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandAxisFlipd: - super().set_random_state(seed, state) - self.flipper.set_random_state(seed, state) - return self + def randomize(self, data: NdarrayOrTensor) -> None: + super().randomize(None) + if not self._do_transform: + return None + self._rand_axis = self.R.randint(data.ndim - 1) - def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + def __call__( + self, + data: Mapping[Hashable, torch.Tensor], + lazy_evaluation: bool | None = None, + ) -> dict[Hashable, torch.Tensor]: """ Args: data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified in this dictionary must be tensor like arrays that are channel first and have at most three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: a dictionary containing the transformed data, as well as any other data present in the dictionary """ - d = dict(data) - first_key: Hashable = self.first_key(d) + rd = dict(data) + first_key: Hashable = self.first_key(rd) if first_key == (): - return d - - self.randomize(None) + return rd - # all the keys share the same random selected axis - self.flipper.randomize(d[first_key]) + self.randomize(rd[first_key]) - lazy_ = self.lazy if lazy is None else lazy - for key in self.key_iterator(d): + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation + for key in self.key_iterator(rd): if self._do_transform: - d[key] = self.flipper(d[key], randomize=False, lazy=lazy_) + rd[key] = flip(rd[key], self._rand_axis, lazy_evaluation=self.lazy_evaluation) else: - d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) - self.push_transform(d[key], replace=True, lazy=lazy_) - return d + rd[key] = identity(rd[key], None, None, lazy_evaluation=lazy_) - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: - d = dict(data) - for key in self.key_iterator(d): - xform = self.pop_transform(d[key]) - if xform[TraceKeys.DO_TRANSFORM]: - d[key].applied_operations.append(xform[TraceKeys.EXTRA_INFO]) # type: ignore - d[key] = self.flipper.inverse(d[key]) - return d + return rd + + def inverse(self, data): + rd = dict(data) + for key in self.key_iterator(rd): + rd[key] = invert(rd[key], self.lazy_evaluation) + return rd class Rotated(MapTransform, InvertibleTransform, LazyTransform): @@ -1715,69 +1709,68 @@ class Rotated(MapTransform, InvertibleTransform, LazyTransform): the output data type is always ``float32``. It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ backend = Rotate.backend def __init__( - self, - keys: KeysCollection, - angle: Sequence[float] | float, - keep_size: bool = True, - mode: SequenceStr = GridSampleMode.BILINEAR, - padding_mode: SequenceStr = GridSamplePadMode.BORDER, - align_corners: Sequence[bool] | bool = False, - dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, - allow_missing_keys: bool = False, - lazy: bool = False, + self, + keys: KeysCollection, + angle: Sequence[float] | float, + keep_size: bool = True, + mode: SequenceStr = GridSampleMode.BILINEAR, + padding_mode: SequenceStr = GridSamplePadMode.BORDER, + align_corners: Sequence[bool] | bool = False, + dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, + allow_missing_keys: bool = False, + lazy_evaluation: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) - LazyTransform.__init__(self, lazy=lazy) - self.rotator = Rotate(angle=angle, keep_size=keep_size, lazy=lazy) + LazyTransform.__init__(self, lazy_evaluation) + self.angle = angle self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool): - self.rotator.lazy = val - self._lazy = val - - def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + def __call__( + self, + data: Mapping[Hashable, torch.Tensor], + lazy_evaluation: bool | None = None, + ) -> dict[Hashable, torch.Tensor]: """ Args: data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified in this dictionary must be tensor like arrays that are channel first and have at most three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: a dictionary containing the transformed data, as well as any other data present in the dictionary """ d = dict(data) - lazy_ = self.lazy if lazy is None else lazy + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): - d[key] = self.rotator( - d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, lazy=lazy_ + d[key] = rotate( + d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, lazy_evaluation=lazy_ ) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: - d = dict(data) - for key in self.key_iterator(d): - d[key] = self.rotator.inverse(d[key]) - return d + def inverse(self, data): + rd = dict(data) + for key in self.key_iterator(rd): + rd[key] = invert(rd[key], self.lazy_evaluation) + return rd -class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): +class RandRotated(RandomizableTransform,MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate` Randomly rotates the input arrays. @@ -1813,95 +1806,91 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform, Lazy the output data type is always ``float32``. It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False """ - backend = RandRotate.backend def __init__( - self, - keys: KeysCollection, - range_x: tuple[float, float] | float = 0.0, - range_y: tuple[float, float] | float = 0.0, - range_z: tuple[float, float] | float = 0.0, - prob: float = 0.1, - keep_size: bool = True, - mode: SequenceStr = GridSampleMode.BILINEAR, - padding_mode: SequenceStr = GridSamplePadMode.BORDER, - align_corners: Sequence[bool] | bool = False, - dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, - allow_missing_keys: bool = False, - lazy: bool = False, + self, + keys: KeysCollection, + range_x: tuple[float, float] | float = 0.0, + range_y: tuple[float, float] | float = 0.0, + range_z: tuple[float, float] | float = 0.0, + prob: float = 0.1, + keep_size: bool = True, + mode: SequenceStr = GridSampleMode.BILINEAR, + padding_mode: SequenceStr = GridSamplePadMode.BORDER, + align_corners: Sequence[bool] | bool = False, + dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, + allow_missing_keys: bool = False, + lazy_evaluation: bool = False, + seed: int | None = None, + state: np.random.RandomState | None = None, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - LazyTransform.__init__(self, lazy=lazy) - self.rand_rotate = RandRotate( - range_x=range_x, range_y=range_y, range_z=range_z, prob=1.0, keep_size=keep_size, lazy=lazy - ) - self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) - self.dtype = ensure_tuple_rep(dtype, len(self.keys)) + LazyTransform.__init__(self, lazy_evaluation) - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool): - self.rand_rotate.lazy = val - self._lazy = val + self.keys = keys + self.keep_size = keep_size + self.allow_missing_keys = allow_missing_keys + self.mode = ensure_tuple_rep(mode, len(keys)) + self.padding_mode = ensure_tuple_rep(padding_mode, len(keys)) + self.align_corners = align_corners + self.dtype = ensure_tuple_rep(dtype, len(keys)) - def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandRotated: - super().set_random_state(seed, state) - self.rand_rotate.set_random_state(seed, state) - return self + self.range_x, self.range_y, self.range_z = range_x, range_y, range_z + self._rand_x, self._rand_y, self._rand_z = 0, 0, 0 - def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + def randomize(self, data: Any | None = None) -> None: + super().randomize(None) + if not self._do_transform: + return None + self._rand_x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) + self._rand_y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) + self._rand_z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) + + def __call__( + self, + data: Mapping[Hashable, torch.Tensor], + lazy_evaluation: bool | None = None, + ) -> dict[Hashable, torch.Tensor]: """ Args: data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified in this dictionary must be tensor like arrays that are channel first and have at most three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: a dictionary containing the transformed data, as well as any other data present in the dictionary """ - d = dict(data) - self.randomize(None) - - # all the keys share the same random rotate angle - self.rand_rotate.randomize() - lazy_ = self.lazy if lazy is None else lazy + self.randomize() - for key, mode, padding_mode, align_corners, dtype in self.key_iterator( - d, self.mode, self.padding_mode, self.align_corners, self.dtype + rd = dict(data) + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation + for key_, mode_, padding_mode_, dtype_ in self.key_iterator( + rd, self.mode, self.padding_mode, self.dtype ): if self._do_transform: - d[key] = self.rand_rotate( - d[key], - mode=mode, - padding_mode=padding_mode, - align_corners=align_corners, - dtype=dtype, - randomize=False, - lazy=lazy_, - ) - else: - d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32) - self.push_transform(d[key], replace=True, lazy=lazy_) - return d + ndim = len(rd[key_].peek_pending_shape() if isinstance(rd[key_], MetaTensor) else rd[key_].shape[1:]) + angles=self._rand_x if ndim == 2 else (self._rand_x, self._rand_y, self._rand_z), - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: - d = dict(data) - for key in self.key_iterator(d): - xform = self.pop_transform(d[key]) - if xform[TraceKeys.DO_TRANSFORM]: - d[key].applied_operations.append(xform[TraceKeys.EXTRA_INFO]) # type: ignore - d[key] = self.rand_rotate.inverse(d[key]) - return d + rd[key_] = rotate(rd[key_], angles, self.keep_size, + mode_, padding_mode_, self.align_corners, dtype_, + None, lazy_evaluation=lazy_) + else: + rd[key_] = identity(rd[key_], None, None, lazy_evaluation=lazy_) + return rd + def inverse(self, data): + rd = dict(data) + for key in self.key_iterator(rd): + rd[key] = invert(rd[key], self.lazy_evaluation) + return rd class Zoomd(MapTransform, InvertibleTransform, LazyTransform): """ @@ -1934,7 +1923,7 @@ class Zoomd(MapTransform, InvertibleTransform, LazyTransform): If None, use the data type of input data. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -1944,60 +1933,57 @@ class Zoomd(MapTransform, InvertibleTransform, LazyTransform): backend = Zoom.backend def __init__( - self, - keys: KeysCollection, - zoom: Sequence[float] | float, - mode: SequenceStr = InterpolateMode.AREA, - padding_mode: SequenceStr = NumpyPadMode.EDGE, - align_corners: Sequence[bool | None] | bool | None = None, - dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, - keep_size: bool = True, - allow_missing_keys: bool = False, - lazy: bool = False, - **kwargs, + self, + keys: KeysCollection, + zoom: Sequence[float] | float, + mode: SequenceStr = InterpolateMode.AREA, + padding_mode: SequenceStr = NumpyPadMode.EDGE, + align_corners: Sequence[bool | None] | bool | None = None, + keep_size: bool = True, + allow_missing_keys: bool = False, + lazy_evaluation: bool = False, + **kwargs, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) - LazyTransform.__init__(self, lazy=lazy) - + LazyTransform.__init__(self, lazy_evaluation) + self.zoom = zoom self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) - self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, lazy=lazy, **kwargs) - - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool): - self.zoomer.lazy = val - self._lazy = val + self.keep_size = keep_size - def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + def __call__( + self, + data: Mapping[Hashable, torch.Tensor], + lazy_evaluation: bool | None = None, + ) -> dict[Hashable, torch.Tensor]: """ Args: data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified in this dictionary must be tensor like arrays that are channel first and have at most three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. Returns: a dictionary containing the transformed data, as well as any other data present in the dictionary """ - d = dict(data) - lazy_ = self.lazy if lazy is None else lazy - for key, mode, padding_mode, align_corners, dtype in self.key_iterator( - d, self.mode, self.padding_mode, self.align_corners, self.dtype + rd = dict(data) + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation + for key, mode, padding_mode, align_corners in self.key_iterator( + rd, self.mode, self.padding_mode, self.align_corners ): - d[key] = self.zoomer( - d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, lazy=lazy_ - ) - return d + rd[key] = zoom(rd[key], self.zoom, mode=mode, + padding_mode=padding_mode, align_corners=align_corners, + lazy_evaluation=lazy_) + return rd - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: - d = dict(data) - for key in self.key_iterator(d): - d[key] = self.zoomer.inverse(d[key]) - return d + def inverse(self, data): + rd = dict(data) + for key in self.key_iterator(rd): + rd[key] = invert(rd[key], self.lazy_evaluation) + return rd class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): @@ -2039,7 +2025,7 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTr If None, use the data type of input data. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. - lazy: a flag to indicate whether this transform should execute lazily or not. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not. Defaults to False kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html @@ -2048,92 +2034,99 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTr backend = RandZoom.backend def __init__( - self, - keys: KeysCollection, - prob: float = 0.1, - min_zoom: Sequence[float] | float = 0.9, - max_zoom: Sequence[float] | float = 1.1, - mode: SequenceStr = InterpolateMode.AREA, - padding_mode: SequenceStr = NumpyPadMode.EDGE, - align_corners: Sequence[bool | None] | bool | None = None, - dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, - keep_size: bool = True, - allow_missing_keys: bool = False, - lazy: bool = False, - **kwargs, + self, + keys: KeysCollection, + prob: float = 0.1, + min_zoom: Sequence[float] | float = 0.9, + max_zoom: Sequence[float] | float = 1.1, + mode: SequenceStr = InterpolateMode.AREA, + padding_mode: SequenceStr = NumpyPadMode.EDGE, + align_corners: Sequence[bool | None] | bool | None = None, + keep_size: bool = True, + allow_missing_keys: bool = False, + lazy_evaluation: bool = False, + seed: int | None = None, + state: np.random.RandomState | None = None, + **kwargs, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - LazyTransform.__init__(self, lazy=lazy) - self.rand_zoom = RandZoom( - prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, keep_size=keep_size, lazy=lazy, **kwargs - ) + LazyTransform.__init__(self, lazy_evaluation) + self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) - self.dtype = ensure_tuple_rep(dtype, len(self.keys)) + self.keep_size = keep_size + self.allow_missing_keys = allow_missing_keys - @LazyTransform.lazy.setter # type: ignore - def lazy(self, val: bool): - self.rand_zoom.lazy = val - self._lazy = val + self.min_zoom, self.max_zoom = min_zoom, max_zoom + self._rand_zoom = [1.0] - def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandZoomd: - super().set_random_state(seed, state) - self.rand_zoom.set_random_state(seed, state) - return self + def randomize(self, img: NdarrayOrTensor) -> None: + super().randomize(None) + if not self._do_transform: + return None + self._rand_zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)] + if len(self.rand_zoom) == 1: + # to keep the spatial shape ratio, use same random zoom factor for all dims + self._rand_zoom = ensure_tuple_rep(self._rand_zoom[0], img.ndim - 1) + elif len(self._rand_zoom) == 2 and img.ndim > 3: + # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim + self._rand_zoom = ensure_tuple_rep(self._rand_zoom[0], img.ndim - 2) + ensure_tuple(self._rand_zoom[-1]) - def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + def __call__( + self, + data: Mapping[Hashable, torch.Tensor], + lazy_evaluation: bool | None = None, + ) -> dict[Hashable, torch.Tensor]: """ Args: - data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified - in this dictionary must be tensor like arrays that are channel first and have at most - three spatial dimensions - lazy: a flag to indicate whether this transform should execute lazily or not - during this call. Setting this to False or True overrides the ``lazy`` flag set + img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D). + mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, + ``"area"``}, the interpolation mode. Defaults to ``self.mode``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + The mode to pad data after zooming. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + align_corners: This only has an effect when mode is + 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + dtype: data type for resampling computation. Defaults to ``self.dtype``. + If None, use the data type of input data. + randomize: whether to execute `randomize()` function first, default to True. + lazy_evaluation: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy_evaluation`` flag set during initialization for this call. Defaults to None. - - Returns: - a dictionary containing the transformed data, as well as any other data present in the dictionary """ - d = dict(data) - first_key: Hashable = self.first_key(d) + rd = dict(data) + first_key: Hashable = self.first_key(rd) if first_key == (): - out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta()) + out: dict[Hashable, torch.Tensor] = convert_to_tensor(rd, track_meta=get_track_meta()) return out - self.randomize(None) + lazy_ = self.lazy_evaluation if lazy_evaluation is None else lazy_evaluation - # all the keys share the same random zoom factor - self.rand_zoom.randomize(d[first_key]) - lazy_ = self.lazy if lazy is None else lazy + self.randomize(rd[first_key]) - for key, mode, padding_mode, align_corners, dtype in self.key_iterator( - d, self.mode, self.padding_mode, self.align_corners, self.dtype + for key, mode, padding_mode, align_corners in self.key_iterator( + rd, self.mode, self.padding_mode, self.align_corners ): if self._do_transform: - d[key] = self.rand_zoom( - d[key], - mode=mode, - padding_mode=padding_mode, - align_corners=align_corners, - dtype=dtype, - randomize=False, - lazy=lazy_, - ) + rd[key] = zoom(rd[key], self._rand_zoom, mode=mode, padding_mode=padding_mode, align_corners=align_corners, + lazy_evaluation=lazy_) else: - d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32) - self.push_transform(d[key], replace=True, lazy=lazy_) - return d + rd[key] = identity(rd[key], None, None, lazy_evaluation=lazy_) + return rd - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: - d = dict(data) - for key in self.key_iterator(d): - xform = self.pop_transform(d[key]) - if xform[TraceKeys.DO_TRANSFORM]: - d[key].applied_operations.append(xform[TraceKeys.EXTRA_INFO]) # type: ignore - d[key] = self.rand_zoom.inverse(d[key]) - return d + def inverse(self, data): + rd = dict(data) + for key in self.key_iterator(rd): + rd[key] = invert(rd[key], self.lazy_evaluation) + return rd class GridDistortiond(MapTransform): @@ -2611,6 +2604,36 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class TransformLiked(MapTransform, InvertibleTransform, LazyTransform): + + def __init__( + self, + keys: KeysCollection, + reference_key = Hashable, + allow_missing_keys: bool = False, + lazy_evaluation: bool = False, + ): + MapTransform().__init__(self, keys, allow_missing_keys) + LazyTransform().__init__(self, lazy_evaluation) + self.reference_key = reference_key + + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy_evaluation=None): + + lazy_ = self.lazy_evaluation if lazy_evaluation is None else self.lazy_evaluation + d: dict = dict(data) + reference = data[self.reference_key] + for key in self.key_iterator(d): + d[key] = transform_like(data=d[key], reference=reference, lazy_evaluation=lazy_) + + return d + + def inverse(self, data): + rd = dict(data) + for key in self.key_iterator(rd): + rd[key] = invert(rd[key], self.lazy_evaluation) + return rd + + SpatialResampleD = SpatialResampleDict = SpatialResampled ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd SpacingD = SpacingDict = Spacingd @@ -2635,3 +2658,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N GridPatchD = GridPatchDict = GridPatchd RandGridPatchD = RandGridPatchDict = RandGridPatchd RandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond +TransformLikeD = TransformLikeDict = TransformLiked diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index add4e7f5ea..f6bd207d55 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -14,6 +14,9 @@ from __future__ import annotations +from typing import Sequence, Tuple + +import copy import math import warnings from enum import Enum @@ -23,7 +26,7 @@ import monai from monai.config import USE_COMPILED -from monai.config.type_definitions import NdarrayOrTensor +from monai.config.type_definitions import DtypeLike, NdarrayOrTensor from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd @@ -31,7 +34,20 @@ from monai.transforms.croppad.array import ResizeWithPadOrCrop from monai.transforms.intensity.array import GaussianSmooth from monai.transforms.inverse import TraceableTransform -from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine +from monai.transforms.lazy.functional import lazily_apply_op +from monai.transforms.utils import ( + apply_align_corners, + create_identity, + create_flip, + create_rotate, + create_rotate_90, + create_scale, + create_translate, + get_input_shape_and_dtype, + resolves_modes, + scale_affine, + transform_shape +) from monai.transforms.utils_pytorch_numpy_unification import allclose from monai.utils import ( LazyAttr, @@ -41,16 +57,36 @@ convert_to_tensor, ensure_tuple, ensure_tuple_rep, + ensure_tuple_size, fall_back_tuple, + look_up_option, optional_import, ) +from monai.utils.enums import ( + GridSampleMode, + GridSamplePadMode, + InterpolateMode, + NumpyPadMode, +) +from monai.utils.type_conversion import ( + get_equivalent_dtype, +) nib, has_nib = optional_import("nibabel") cupy, _ = optional_import("cupy") cupy_ndi, _ = optional_import("cupyx.scipy.ndimage") np_ndi, _ = optional_import("scipy.ndimage") -__all__ = ["spatial_resample", "orientation", "flip", "resize", "rotate", "zoom", "rotate90", "affine_func"] +__all__ = [ + "spatial_resample", + "orientation", + "flip", + "resize", + "rotate", + "zoom", + "rotate90", + "affine_func" +] def _maybe_new_metatensor(img, dtype=None, device=None): @@ -64,13 +100,50 @@ def _maybe_new_metatensor(img, dtype=None, device=None): ) +def identity( + img: torch.Tensor, + mode: InterpolateMode | str = None, + padding_mode: NumpyPadMode | GridSamplePadMode | str = None, + dtype: DtypeLike | torch.dtype = None, + shape_override: Sequence[int] | None = None, + dtype_override: DtypeLike | torch.dtype | None = None, + lazy_evaluation: bool = False +): + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + + input_shape, input_dtype = get_input_shape_and_dtype(shape_override, dtype_override, img_) + + input_ndim = len(input_shape) - 1 + + mode_ = None if mode is None else look_up_option(mode, GridSampleMode) + padding_mode_ = None if padding_mode is None else look_up_option(padding_mode, GridSamplePadMode) + dtype_ = get_equivalent_dtype(dtype or img_.dtype, torch.Tensor) + + transform = create_identity(input_ndim) + + metadata = { + "transform": transform, + "op": "identity", + LazyAttr.IN_SHAPE: input_shape, + LazyAttr.IN_DTYPE: input_dtype, + LazyAttr.OUT_SHAPE: input_shape, + LazyAttr.OUT_DTYPE: dtype_, + } + if mode_ is not None: + metadata[LazyAttr.INTERP_MODE] = mode_ + if padding_mode_ is not None: + metadata[LazyAttr.PADDING_MODE] = padding_mode_ + # metadata[LazyAttr.DTYPE] = dtype_ + + return lazily_apply_op(img_, metadata, lazy_evaluation) + + def spatial_resample( - img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype_pt, lazy, transform_info + img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype_pt, lazy_evaluation, transform_info ) -> torch.Tensor: """ Functional implementation of resampling the input image to the specified ``dst_affine`` matrix and ``spatial_size``. - This function operates eagerly or lazily according to - ``lazy`` (default ``False``). + This function operates eagerly or lazily according to ``lazy_evaluation`` (default ``False``). Args: img: data to be resampled, assuming `img` is channel-first. @@ -91,7 +164,7 @@ def spatial_resample( align_corners: Geometrically, we consider the pixels of the input as squares rather than points. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype_pt: data `dtype` for resampling computation. - lazy: a flag that indicates whether the operation should be performed lazily or not + lazy_evaluation: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ original_spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] @@ -135,13 +208,13 @@ def spatial_resample( meta_info = TraceableTransform.track_transform_meta( img, sp_size=spatial_size, - affine=None if affine_unchanged and not lazy else xform, + affine=None if affine_unchanged and not lazy_evaluation else xform, extra_info=extra_info, orig_size=original_spatial_shape, transform_info=transform_info, - lazy=lazy, + lazy_evaluation=lazy_evaluation, ) - if lazy: + if lazy_evaluation: out = _maybe_new_metatensor(img) return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore if affine_unchanged: @@ -183,18 +256,17 @@ def spatial_resample( return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore -def orientation(img, original_affine, spatial_ornt, lazy, transform_info) -> torch.Tensor: +def orientation(img, original_affine, spatial_ornt, lazy_evaluation, transform_info) -> torch.Tensor: """ Functional implementation of changing the input image's orientation into the specified based on `spatial_ornt`. - This function operates eagerly or lazily according to - ``lazy`` (default ``False``). + This function operates eagerly or lazily according to ``lazy_evaluation`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. original_affine: original affine of the input image. spatial_ornt: orientations of the spatial axes, see also https://nipy.org/nibabel/reference/nibabel.orientations.html - lazy: a flag that indicates whether the operation should be performed lazily or not + lazy_evaluation: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] @@ -217,10 +289,10 @@ def orientation(img, original_affine, spatial_ornt, lazy, transform_info) -> tor extra_info=extra_info, orig_size=spatial_shape, transform_info=transform_info, - lazy=lazy, + lazy_evaluation=lazy_evaluation, ) out = _maybe_new_metatensor(img) - if lazy: + if lazy_evaluation: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore if axes: out = torch.flip(out, dims=axes) @@ -229,11 +301,16 @@ def orientation(img, original_affine, spatial_ornt, lazy, transform_info) -> tor return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore -def flip(img, sp_axes, lazy, transform_info): +def flip( + img: torch.Tensor, + spatial_axis: Sequence[int] | int, + shape_override: Sequence | None = None, + dtype_override: DtypeLike | torch.dtype | None = None, + lazy_evaluation: bool = False +): """ Functional implementation of flip. - This function operates eagerly or lazily according to - ``lazy`` (default ``False``). + This function operates eagerly or lazily according to ``lazy_evaluation`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. @@ -242,36 +319,49 @@ def flip(img, sp_axes, lazy, transform_info): If axis is negative it counts from the last to the first axis. If axis is a tuple of ints, flipping is performed on all of the axes specified in the tuple. - lazy: a flag that indicates whether the operation should be performed lazily or not + lazy_evaluation: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ - sp_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] - sp_size = convert_to_numpy(sp_size, wrap_sequence=True).tolist() - extra_info = {"axes": sp_axes} # track the spatial axes - axes = monai.transforms.utils.map_spatial_axes(img.ndim, sp_axes) # use the axes with channel dim - rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else torch.tensor(3.0, dtype=torch.double) - # axes include the channel dim - xform = torch.eye(int(rank) + 1, dtype=torch.double) - for axis in axes: - sp = axis - 1 - xform[sp, sp], xform[sp, -1] = xform[sp, sp] * -1, sp_size[sp] - 1 - meta_info = TraceableTransform.track_transform_meta( - img, sp_size=sp_size, affine=xform, extra_info=extra_info, transform_info=transform_info, lazy=lazy - ) - out = _maybe_new_metatensor(img) - if lazy: - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info - out = torch.flip(out, axes) - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + + input_shape, input_dtype = get_input_shape_and_dtype(shape_override, dtype_override, img_) + + input_ndim = len(input_shape) - 1 + + spatial_axis_ = spatial_axis + if spatial_axis_ is None: + spatial_axis_ = tuple(i for i in range(len(input_shape[1:]))) + + transform = create_flip(input_ndim, spatial_axis_) + + metadata = { + "transform": transform, + "op": "flip", + "spatial_axis": spatial_axis_, + LazyAttr.IN_SHAPE: input_shape, + LazyAttr.IN_DTYPE: input_dtype, + LazyAttr.OUT_SHAPE: input_shape, + LazyAttr.OUT_DTYPE: input_dtype, + } + return lazily_apply_op(img_, metadata, lazy_evaluation) def resize( - img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info + img: torch.Tensor, + spatial_size: Sequence[int] | int, + size_mode: str = "all", + mode: InterpolateMode | str = InterpolateMode.AREA, + align_corners: bool = False, + anti_aliasing: bool = None, + anti_aliasing_sigma: Sequence[float] | float | None = None, + dtype: DtypeLike | torch.dtype | None = None, + shape_override: Sequence[int] | None = None, + dtype_override: DtypeLike | torch.dtype | None = None, + lazy_evaluation: bool = False ): """ Functional implementation of resize. - This function operates eagerly or lazily according to - ``lazy`` (default ``False``). + This function operates eagerly or lazily according to ``lazy_evaluation`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. @@ -289,132 +379,147 @@ def resize( the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` anti_aliasing_sigma: {float, tuple of floats}, optional Standard deviation for Gaussian filtering used when anti-aliasing. - lazy: a flag that indicates whether the operation should be performed lazily or not - transform_info: a dictionary with the relevant information pertaining to an applied transform. + lazy_evaluation: a flag that indicates whether the operation should be performed lazily or not """ - img = convert_to_tensor(img, track_meta=get_track_meta()) - orig_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] - extra_info = { - "mode": mode, - "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, - "dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 - "new_dim": len(orig_size) - input_ndim, - } - meta_info = TraceableTransform.track_transform_meta( - img, - sp_size=out_size, - affine=scale_affine(orig_size, out_size), - extra_info=extra_info, - orig_size=orig_size, - transform_info=transform_info, - lazy=lazy, - ) - if lazy: - if anti_aliasing and lazy: - warnings.warn("anti-aliasing is not compatible with lazy evaluation.") - out = _maybe_new_metatensor(img) - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info - if tuple(convert_to_numpy(orig_size)) == out_size: - out = _maybe_new_metatensor(img, dtype=torch.float32) - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out - out = _maybe_new_metatensor(img) - img_ = convert_to_tensor(out, dtype=dtype, track_meta=False) # convert to a regular tensor - if anti_aliasing and any(x < y for x, y in zip(out_size, img_.shape[1:])): - factors = torch.div(torch.Tensor(list(img_.shape[1:])), torch.Tensor(out_size)) - if anti_aliasing_sigma is None: - # if sigma is not given, use the default sigma in skimage.transform.resize - anti_aliasing_sigma = torch.maximum(torch.zeros(factors.shape), (factors - 1) / 2).tolist() - else: - # if sigma is given, use the given value for downsampling axis - anti_aliasing_sigma = list(ensure_tuple_rep(anti_aliasing_sigma, len(out_size))) - for axis in range(len(out_size)): - anti_aliasing_sigma[axis] = anti_aliasing_sigma[axis] * int(factors[axis] > 1) - anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma) - img_ = convert_to_tensor(anti_aliasing_filter(img_), track_meta=False) - _, _m, _, _ = resolves_modes(mode, torch_interpolate_spatial_nd=len(img_.shape) - 1) - resized = torch.nn.functional.interpolate( - input=img_.unsqueeze(0), size=out_size, mode=_m, align_corners=align_corners - ) - out, *_ = convert_to_dst_type(resized.squeeze(0), out, dtype=torch.float32) - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out - -def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info): + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + + input_shape, input_dtype = get_input_shape_and_dtype(shape_override, dtype_override, img_) + + input_ndim = len(input_shape) - 1 + + if size_mode == "all": + spatial_size_ = fall_back_tuple(spatial_size, input_shape[1:]) + output_ndim = len(ensure_tuple(spatial_size_)) + if output_ndim > input_ndim: + input_shape = ensure_tuple_size(input_shape, output_ndim + 1, 1) + img = img.reshape(input_shape) + elif output_ndim < input_ndim: + raise ValueError( + "len(spatial_size) must be greater or equal to img spatial dimensions, " + f"got spatial_size={output_ndim} img={input_ndim}." + ) + else: # for the "longest" mode + img_size = input_shape[1:] + if not isinstance(spatial_size, int): + raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") + scale = spatial_size / max(img_size) + spatial_size_ = tuple(int(round(s * scale)) for s in img_size) + + mode_ = look_up_option(mode, InterpolateMode) + dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) + shape_zoom_factors = [i / j for i, j in zip(spatial_size_, input_shape[1:])] + pixel_zoom_factors = [j / i for i, j in zip(spatial_size_, input_shape[1:])] + + shape_transform = create_scale(input_ndim, shape_zoom_factors) + pixel_transform = create_scale(input_ndim, pixel_zoom_factors) + + output_shape = transform_shape(input_shape, shape_transform) + + metadata = { + "transform": pixel_transform, + "op": "resize", + "spatial_size": spatial_size, + "size_mode": size_mode, + LazyAttr.INTERP_MODE: mode_, + LazyAttr.ALIGN_CORNERS: align_corners, + "anti_aliasing": anti_aliasing, + "anti_aliasing_sigma": anti_aliasing_sigma, + LazyAttr.IN_SHAPE: input_shape, + LazyAttr.IN_DTYPE: input_dtype, + LazyAttr.OUT_SHAPE: output_shape, + LazyAttr.OUT_DTYPE: dtype_, + } + return lazily_apply_op(img_, metadata, lazy_evaluation) + + +def rotate( + img: torch.Tensor, + angle: Sequence[float] | float, + keep_size: bool = True, + mode: InterpolateMode | str = InterpolateMode.AREA, + padding_mode: NumpyPadMode | GridSamplePadMode | str = NumpyPadMode.EDGE, + align_corners: bool = False, + dtype: DtypeLike | torch.dtype = None, + shape_override: Sequence[int] | None = None, + dtype_override: DtypeLike | torch.dtype = None, + lazy_evaluation: bool = False +): """ - Functional implementation of rotate. - This function operates eagerly or lazily according to - ``lazy`` (default ``False``). - Args: - img: data to be changed, assuming `img` is channel-first. + img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. angle: Rotation angle(s) in radians. should a float for 2D, three floats for 3D. - output_shape: output shape of the rotated data. + keep_size: If it is True, the output shape is kept the same as the input. + If it is False, the output shape is adapted so that the + input array is contained completely in the output. Default is True. mode: {``"bilinear"``, ``"nearest"``} - Interpolation mode to calculate output values. + Interpolation mode to calculate output values. Defaults to ``self.mode``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} - Padding mode for outside grid values. + Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - align_corners: See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - dtype: data type for resampling computation. + align_corners: Defaults to ``self.align_corners``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: Defaults to ``self.align_corners``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. To be compatible with other modules, - the output data type is always ``float32``. - lazy: a flag that indicates whether the operation should be performed lazily or not - transform_info: a dictionary with the relevant information pertaining to an applied transform. + the output data type is always ``np.float32``. + + Raises: + ValueError: When ``img`` spatially is not one of [2D, 3D]. """ - im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] - input_ndim = len(im_shape) + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + mode_ = look_up_option(mode, GridSampleMode) + padding_mode_ = look_up_option(padding_mode, GridSamplePadMode) + dtype_ = get_equivalent_dtype(dtype or img_.dtype, torch.Tensor) + # img_ = img_.to(dtype_) + + input_shape, input_dtype = get_input_shape_and_dtype(shape_override, dtype_override, img_) + + input_ndim = len(input_shape) - 1 if input_ndim not in (2, 3): raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") - _angle = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3) - transform = create_rotate(input_ndim, _angle) - if output_shape is None: - corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape((len(im_shape), -1)) - corners = transform[:-1, :-1] @ corners # type: ignore - output_shape = np.asarray(corners.ptp(axis=1) + 0.5, dtype=int) - else: - output_shape = np.asarray(output_shape, dtype=int) - shift = create_translate(input_ndim, ((np.array(im_shape) - 1) / 2).tolist()) - shift_1 = create_translate(input_ndim, (-(np.asarray(output_shape, dtype=int) - 1) / 2).tolist()) - transform = shift @ transform @ shift_1 - extra_info = { - "rot_mat": transform, - "mode": mode, - "padding_mode": padding_mode, - "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, - "dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 - } - meta_info = TraceableTransform.track_transform_meta( - img, - sp_size=output_shape, - affine=transform, - extra_info=extra_info, - orig_size=im_shape, - transform_info=transform_info, - lazy=lazy, - ) - out = _maybe_new_metatensor(img) - if lazy: - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info - _, _m, _p, _ = resolves_modes(mode, padding_mode) - xform = AffineTransform( - normalized=False, mode=_m, padding_mode=_p, align_corners=align_corners, reverse_indexing=True - ) - img_t = out.to(dtype) - transform_t, *_ = convert_to_dst_type(transform, img_t) - output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=tuple(int(i) for i in output_shape)) - output = output.float().squeeze(0) - out, *_ = convert_to_dst_type(output, dst=out, dtype=torch.float32) - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out - -def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info): + angle_ = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3) + + # rotate_tx = compatible_rotate(img_, angle_) + rotate_tx = create_rotate(input_ndim, angle_).astype(np.float64) + output_shape = input_shape if keep_size is True else transform_shape(input_shape, rotate_tx) + + metadata = { + "transform": rotate_tx, + "op": "rotate", + "angle": angle, + "keep_size": keep_size, + LazyAttr.INTERP_MODE: mode_, + LazyAttr.PADDING_MODE: padding_mode_, + LazyAttr.ALIGN_CORNERS: align_corners, + LazyAttr.IN_SHAPE: input_shape, + LazyAttr.IN_DTYPE: input_dtype, + LazyAttr.OUT_SHAPE: output_shape, + LazyAttr.OUT_DTYPE: dtype_, + } + return lazily_apply_op(img_, metadata, lazy_evaluation) + + +def zoom( + img: torch.Tensor, + factor: Sequence[float] | float, + mode: InterpolateMode | str = InterpolateMode.BILINEAR, + padding_mode: NumpyPadMode | GridSamplePadMode | str = NumpyPadMode.EDGE, + align_corners: bool = False, + keep_size: bool = True, + dtype: DtypeLike | torch.dtype | None = None, + shape_override: Sequence[int] | None = None, + dtype_override: DtypeLike | torch.dtype | None = None, + lazy_evaluation: bool = False +): """ Functional implementation of zoom. - This function operates eagerly or lazily according to - ``lazy`` (default ``False``). + This function operates eagerly or lazily according to ``lazy_evaluation`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. @@ -432,125 +537,127 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, dtype: data type for resampling computation. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. - lazy: a flag that indicates whether the operation should be performed lazily or not + lazy_evaluation: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ - im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] - output_size = [int(math.floor(float(i) * z)) for i, z in zip(im_shape, scale_factor)] - xform = scale_affine(im_shape, output_size) - extra_info = { - "mode": mode, - "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, - "dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 - "do_padcrop": False, - "padcrop": {}, + + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + + input_shape, input_dtype = get_input_shape_and_dtype(shape_override, dtype_override, img_) + + input_ndim = len(input_shape) - 1 + + zoom_factors = ensure_tuple_rep(factor, input_ndim) + # zoom_factors = ensure_tuple(factor) + zoom_factors = [1 / f for f in zoom_factors] + shape_zoom_factors = [1 / z for z in zoom_factors] + + # TODO: Remove this after consolidated resampling + mode_ = 'bilinear' if mode == 'area' else mode + mode_ = look_up_option(mode_, GridSampleMode) + # TODO: Remove this after consolidated resampling + padding_mode_ = 'border' if padding_mode == 'edge' else padding_mode + padding_mode_ = look_up_option(padding_mode_, GridSamplePadMode) + dtype_ = get_equivalent_dtype(dtype or img_.dtype, torch.Tensor) + + transform = create_scale(input_ndim, zoom_factors) + shape_transform = create_scale(input_ndim, shape_zoom_factors) + + output_shape = input_shape if keep_size is True else transform_shape(input_shape, + shape_transform) + + if align_corners is True: + transform_ = apply_align_corners(transform, output_shape[1:], + lambda scale_factors: create_scale(input_ndim, scale_factors)) + # TODO: confirm whether a second transform shape is required or not + output_shape = transform_shape(output_shape, transform) + else: + transform_ = transform + + + metadata = { + "transform": transform_, + "op": "zoom", + "factor": zoom_factors, + LazyAttr.INTERP_MODE: mode_, + LazyAttr.PADDING_MODE: padding_mode_, + LazyAttr.ALIGN_CORNERS: align_corners, + "keep_size": keep_size, + LazyAttr.IN_SHAPE: input_shape, + LazyAttr.IN_DTYPE: input_dtype, + LazyAttr.OUT_SHAPE: output_shape, + LazyAttr.OUT_DTYPE: dtype_, } - if keep_size: - do_pad_crop = not np.allclose(output_size, im_shape) - if do_pad_crop and lazy: # update for lazy evaluation - _pad_crop = ResizeWithPadOrCrop(spatial_size=im_shape, mode=padding_mode) - _pad_crop.lazy = True - _tmp_img = MetaTensor([], affine=torch.eye(len(output_size) + 1)) - _tmp_img.push_pending_operation({LazyAttr.SHAPE: list(output_size), LazyAttr.AFFINE: xform}) - lazy_cropped = _pad_crop(_tmp_img) - if isinstance(lazy_cropped, MetaTensor): - xform = lazy_cropped.peek_pending_affine() - extra_info["padcrop"] = lazy_cropped.pending_operations[-1] - extra_info["do_padcrop"] = do_pad_crop - output_size = [int(i) for i in im_shape] - meta_info = TraceableTransform.track_transform_meta( - img, - sp_size=output_size, - affine=xform, - extra_info=extra_info, - orig_size=im_shape, - transform_info=transform_info, - lazy=lazy, - ) - out = _maybe_new_metatensor(img) - if lazy: - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info - img_t = out.to(dtype) - _, _m, _, _ = resolves_modes(mode, torch_interpolate_spatial_nd=len(img_t.shape) - 1) - zoomed: NdarrayOrTensor = torch.nn.functional.interpolate( - recompute_scale_factor=True, - input=img_t.unsqueeze(0), - scale_factor=list(scale_factor), - mode=_m, - align_corners=align_corners, - ).squeeze(0) - out, *_ = convert_to_dst_type(zoomed, dst=out, dtype=torch.float32) - if isinstance(out, MetaTensor): - out = out.copy_meta_from(meta_info) - do_pad_crop = not np.allclose(output_size, zoomed.shape[1:]) - if do_pad_crop: - _pad_crop = ResizeWithPadOrCrop(spatial_size=img_t.shape[1:], mode=padding_mode) - out = _pad_crop(out) - if get_track_meta() and do_pad_crop: - padcrop_xform = out.applied_operations.pop() - out.applied_operations[-1]["extra_info"]["do_padcrop"] = True - out.applied_operations[-1]["extra_info"]["padcrop"] = padcrop_xform - return out - - -def rotate90(img, axes, k, lazy, transform_info): + + return lazily_apply_op(img_, metadata, lazy_evaluation) + + +def rotate90( + img: torch.Tensor, + k: int = 1, + spatial_axes: Tuple[int, int] = (0, 1), + shape_override: Sequence[int] | None = None, + dtype_override: DtypeLike | torch.dtype | None = None, + lazy_evaluation: bool = False +): """ Functional implementation of rotate90. - This function operates eagerly or lazily according to - ``lazy`` (default ``False``). + This function operates eagerly or lazily according to ``lazy_evaluation`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. If axis is negative it counts from the last to the first axis. k: number of times to rotate by 90 degrees. - lazy: a flag that indicates whether the operation should be performed lazily or not + lazy_evaluation: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ - extra_info = {"axes": [d - 1 for d in axes], "k": k} - ori_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] - sp_shape = list(ori_shape) - if k in (1, 3): - a_0, a_1 = axes[0] - 1, axes[1] - 1 - sp_shape[a_0], sp_shape[a_1] = ori_shape[a_1], ori_shape[a_0] - rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else torch.tensor(3.0, dtype=torch.double) - r, sp_r = int(rank), len(ori_shape) - xform = to_affine_nd(r, create_translate(sp_r, [-float(d - 1) / 2 for d in sp_shape])) - s = -1.0 if int(axes[0]) - int(axes[1]) in (-1, 2) else 1.0 - if sp_r == 2: - rot90 = to_affine_nd(r, create_rotate(sp_r, [s * np.pi / 2])) + if len(spatial_axes) != 2: + raise ValueError("'spatial_axes' must be a tuple of two integers indicating") + + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + + # if shape_override is set, it always wins + input_shape = shape_override + + input_shape, input_dtype = get_input_shape_and_dtype(shape_override, dtype_override, img_) + + input_ndim = len(input_shape) - 1 + + transform = create_rotate_90(input_ndim, spatial_axes, k) + + # TODO: this could be calculated from the transform like the other functions do + if k % 2 == 1: + output_shape_order = [i for i in range(input_ndim)] + for i in range(input_ndim): + if i == spatial_axes[0]: + output_shape_order[i] = spatial_axes[1] + elif i == spatial_axes[1]: + output_shape_order[i] = spatial_axes[0] + output_shape = (input_shape[0],) + tuple(input_shape[output_shape_order[i] + 1] for i in range(input_ndim)) else: - idx = {1, 2, 3} - set(axes) - angle: list[float] = [0, 0, 0] - angle[idx.pop() - 1] = s * np.pi / 2 - rot90 = to_affine_nd(r, create_rotate(sp_r, angle)) - for _ in range(k): - xform = rot90 @ xform - xform = to_affine_nd(r, create_translate(sp_r, [float(d - 1) / 2 for d in ori_shape])) @ xform - meta_info = TraceableTransform.track_transform_meta( - img, - sp_size=sp_shape, - affine=xform, - extra_info=extra_info, - orig_size=ori_shape, - transform_info=transform_info, - lazy=lazy, - ) - out = _maybe_new_metatensor(img) - if lazy: - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info - out = torch.rot90(out, k, axes) - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + output_shape = input_shape + + metadata = { + "transform": "transform", + "op": "rotate90", + "k": k, + "spatial_axes": spatial_axes, + LazyAttr.IN_SHAPE: input_shape, + LazyAttr.IN_DTYPE: input_dtype, + LazyAttr.OUT_SHAPE: output_shape, + LazyAttr.OUT_DTYPE: input_dtype, + } + return lazily_apply_op(img_, metadata, lazy_evaluation) def affine_func( - img, affine, grid, resampler, sp_size, mode, padding_mode, do_resampling, image_only, lazy, transform_info + img, affine, grid, resampler, sp_size, mode, padding_mode, do_resampling, image_only, lazy_evaluation, transform_info ): """ Functional implementation of affine. - This function operates eagerly or lazily according to - ``lazy`` (default ``False``). + This function operates eagerly or lazily according to ``lazy_evaluation`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. @@ -574,7 +681,7 @@ def affine_func( do_resampling: whether to do the resampling, this is a flag for the use case of updating metadata but skipping the actual (potentially heavy) resampling operation. image_only: if True return only the image volume, otherwise return (image, affine). - lazy: a flag that indicates whether the operation should be performed lazily or not + lazy_evaluation: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ @@ -597,9 +704,9 @@ def affine_func( extra_info=extra_info, orig_size=img_size, transform_info=transform_info, - lazy=lazy, + lazy_evaluation=lazy_evaluation, ) - if lazy: + if lazy_evaluation: out = _maybe_new_metatensor(img) out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info return out if image_only else (out, affine) @@ -610,3 +717,73 @@ def affine_func( out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device) out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out return out if image_only else (out, affine) + + +def transform_like( + data, + data_applied_ops, + data_pending_ops, + reference, + reference_applied_ops, + reference_pending_ops, + lazy_evaluation=False +): + """ + Functional implementation for adapting a tensor. + This function operates eagerly or lazily accoring to ``lazy`` (default ``False``). + + `transform_like` takes a tensor ``data`` and a reference tensor ``reference``, and applies the + latest transform from ``reference`` to ``data``. + + ``transform_like`` is designed to work with MONAI's ``MetaTensor`` as well as with standard + pytorch ``tensor``. + - in the case that a pytorch ``tensor`` is provided for ``data``, one or both of + ``data_applied_ops`` and ``data_pending_ops`` must also be provided. + - ``reference`` should only be supplied if it is a ``MetaTensor``. Otherwise, one or both of + ``reference_applied_ops`` and ``reference_pending_ops`` should be provided instead + + Args: + data: the ``tensor`` / ``MetaTensor`` to be transformed + data_applied_ops: a list of ops that have been applied to ``data``, if data is not a + ``MetaTensor``. It should be None if ``data`` is a ``MetaTensor`` + data_pending_ops: a list of ops that are due to be applied to ``data``, if data is not a + ``MetaTensor``. It should be None if ``data`` is a ``MetaTensor`` + reference: the ``MetaTensor`` that ``data`` is being matched to. If the reference tensor + is not a ``MetaTensor``, ``reference`` should be `None` + reference_applied_ops: a list of ops that have been applied to a reference tensor + (which is not supplied if this parameter is set) + reference_pending_ops: a list of ops that are due to be applied to a reference tensor (not + (which is not supplied if this parameter is set) + """ + + if not isinstance(data, (MetaTensor, torch.Tensor)): + raise TypeError( + f"If 'data' is provided, it must be one of (MetaTensor, torch.Tensor) but is of type {type(data)}" + ) + if isinstance(data, MetaTensor) and any(p is not None for p in (data_applied_ops, data_pending_ops)): + raise ValueError(f"is 'data' is a MetaTensor, 'data_applied_ops' and 'data_pending_ops' must be None") + + if reference is not None and not isinstance(reference, MetaTensor): + raise TypeError( + f"If 'reference' is provided, it must be of type MetaTensor, but is of type {type(reference)}" + ) + if reference is not None and any(p is not None for p in (reference_applied_ops, reference_pending_ops)): + raise ValueError(f"If 'reference' is set, 'reference_applied_ops' and 'reference_pending_ops' must be None") + + op_to_apply = None + if reference is MetaTensor: + if reference.has_pending_operations(): + op_to_apply = copy.deepcopy(reference.pending_operations[-1]) + elif len(reference.applied_operations) > 0: + op_to_apply = copy.deepcopy(reference.applied_operations[-1]) + else: + if reference_pending_ops is not None and len(reference_pending_ops) > 0: + op_to_apply = copy.deepcopy(reference_pending_ops[-1]) + elif reference_applied_ops is not None and len(reference_applied_ops) > 0: + op_to_apply = copy.deepcopy(reference_applied_ops[-1]) + + if op_to_apply is not None: + if isinstance(data, MetaTensor): + data.push_pending_operation(op_to_apply) + + return lazily_apply_op(data, op_to_apply, lazy_evaluation) diff --git a/monai/transforms/traits.py b/monai/transforms/traits.py index 016effc59d..7bb6d70701 100644 --- a/monai/transforms/traits.py +++ b/monai/transforms/traits.py @@ -29,7 +29,7 @@ class LazyTrait: """ @property - def lazy(self): + def lazy_evaluation(self): """ Get whether lazy evaluation is enabled for this transform instance. Returns: @@ -37,8 +37,8 @@ def lazy(self): """ raise NotImplementedError() - @lazy.setter - def lazy(self, enabled: bool): + @lazy_evaluation.setter + def lazy_evaluation(self, enabled: bool): """ Set whether lazy evaluation is enabled for this transform instance. Args: diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 3d09cea545..c0391c135c 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -294,22 +294,22 @@ class LazyTransform(Transform, LazyTrait): dictionary transforms to simplify implementation of new lazy transforms. """ - def __init__(self, lazy: bool | None = False): - if lazy is not None: - if not isinstance(lazy, bool): - raise TypeError(f"lazy must be a bool but is of type {type(lazy)}") - self._lazy = lazy + def __init__(self, lazy_evaluation: bool | None = False): + if lazy_evaluation is not None: + if not isinstance(lazy_evaluation, bool): + raise TypeError(f"lazy must be a bool but is of type {type(lazy_evaluation)}") + self._lazy = lazy_evaluation @property - def lazy(self): + def lazy_evaluation(self): return self._lazy - @lazy.setter - def lazy(self, lazy: bool | None): - if lazy is not None: - if not isinstance(lazy, bool): - raise TypeError(f"lazy must be a bool but is of type {type(lazy)}") - self._lazy = lazy + @lazy_evaluation.setter + def lazy_evaluation(self, lazy_evaluation: bool | None): + if lazy_evaluation is not None: + if not isinstance(lazy_evaluation, bool): + raise TypeError(f"lazy_evaluation must be a bool but is of type {type(lazy_evaluation)}") + self._lazy = lazy_evaluation @property def requires_current_data(self): diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 44e5b25a34..aa95456e11 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -26,6 +26,7 @@ import monai from monai.config import DtypeLike, IndexSelection from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor +from monai.data.meta_tensor import MetaTensor from monai.networks.layers import GaussianFilter from monai.networks.utils import meshgrid_ij from monai.transforms.compose import Compose @@ -131,6 +132,10 @@ "resolves_modes", "has_status_keys", "distance_transform_edt", + "extents_from_shape", + "shape_from_extents", + "transform_shape", + "get_input_shape_and_dtype", ] @@ -2194,5 +2199,293 @@ def distance_transform_edt( return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0] # type: ignore +def extents_from_shape( + shape: Sequence[int], + dtype=torch.float32 +): + """ + This method calculates a set of extents given a shape. Each extent is a point in a coordinate + system that can be multiplied with a homogeneous matrix. As such, extents for 2D data have + three values, and extends for 3D data have four values. + + For shapes representing 2D data, this is an array of four extents, for shape s: + - (0, 0, 1), (0, s[1], 1), (s[0], 0, 1), (s[0], s[1], 1). + + For shapes representing 3D data, this is an array of eight extents, representing a cuboid: + - (0, 0, 0, 1), (0, 0, s[2], 1), (0, s[1], 0, 1), (0, s[1], s[2], 1), + - (s[0], 0, 0, 1), (s[0], 0, s[2], 1), (s[0], s[1], 0, 1), (s[0], s[1], s[2], 1) + + Args: + shape: A shape from a numpy array or tensor + dtype: The dtype to use for the resulting extents + + Returns: + An array of arrays representing the shape extents + """ + extents = [[0, shape[i]] for i in range(1, len(shape))] + + extents = itertools.product(*extents) + # return [torch.as_tensor(e + (1,), dtype=dtype) for e in extents] + return [np.asarray(e + (1,), dtype=dtype) for e in extents] + + +def shape_from_extents( + src_shape: Sequence, + extents: Sequence[np.ndarray] | Sequence[torch.Tensor] | np.ndarray | torch.Tensor +): + """ + This method, given a sequence of homogeneous vertices representing the corners of a rectangle + or cuboid, will calculate the resulting shape values from those extents. + + Args: + src_shape: The shape into which the resulting spatial shape values will be written. Note + that initial shape value is appended to the spatial shape components. + extents: The extents from which the spatial shape values should be calculated + + Returns: + A tuple composed of the first element of `src_shape` with the spatial shape values appended + to it. + """ + if isinstance(extents, (list, tuple)): + if isinstance(extents[0], np.ndarray): + extents_ = np.asarray(extents) + else: + extents_ = torch.stack(extents) + extents_ = extents_.numpy() + else: + if isinstance(extents, np.ndarray): + extents_ = extents + else: + extents_ = extents.numpy() + + mins = extents_.min(axis=0) + maxes = extents_.max(axis=0) + values = np.round(maxes - mins).astype(int)[:-1].tolist() + return (src_shape[0],) + tuple(values) + + +def transform_shape(input_shape: Sequence[int], matrix: torch.Tensor): + """ + TODO: this method should accept Matrix and Grid types also + TODO: this method should be moved to transforms.utils + Transform `input_shape` according to `transform`. This can be used for any transforms that + widen / narrow the resulting region of interest (typically transforms that have a 'keep_size' + parameter such as rotate. + + Args: + input_shape: the shape to be transformed + matrix: the matrix to apply to it + + Returns: + The resulting shape + """ + if not Affine.is_affine_shaped(matrix): + raise ValueError("'matrix' must have a valid 2d or 3d homogenous matrix shape but has shape " + f"{matrix.shape}") + im_extents = extents_from_shape(input_shape, matrix.dtype) + im_extents = [matrix @ e for e in im_extents] + output_shape = shape_from_extents(input_shape, im_extents) + return output_shape + + +def get_input_shape_and_dtype(shape_override, dtype_override, img): + # if shape_override is set, it always wins + input_shape = shape_override + input_dtype = dtype_override + + if input_shape is None: + if isinstance(img, MetaTensor) and len(img.pending_operations) > 0: + input_shape = img.peek_pending_shape() + else: + input_shape = img.shape + if input_dtype is None: + if isinstance(img, MetaTensor) and len(img.pending_operations) > 0: + input_dtype = img.peek_pending_dtype() + else: + input_dtype = img.dtype + return input_shape, input_dtype + + +def apply_align_corners(matrix, spatial_size, op): + """ + TODO: ensure that this functionality is correct and produces the same result as the existing ways of handling align corners + """ + inflated_spatial_size = tuple(s + 1 for s in spatial_size) + scale_factors = tuple(s / i for s, i in zip(spatial_size, inflated_spatial_size)) + scale_mat = op(scale_factors) + # scale_mat = scale_mat.double() + return matmul(scale_mat, matrix) + + +class Affine: + """A class to represent an affine transform matrix.""" + + __slots__ = ("data",) + + def __init__(self, data): + self.data = data + + @staticmethod + def is_affine_shaped(data): + """Check if the data is an affine matrix.""" + if isinstance(data, Affine): + return True + if isinstance(data, DisplacementField): + return False + if not hasattr(data, "shape") or len(data.shape) < 2: + return False + return data.shape[-1] in (3, 4) and data.shape[-1] == data.shape[-2] + + +class DisplacementField: + """A class to represent a dense displacement field.""" + + __slots__ = ("data",) + + def __init__(self, data): + self.data = data + + @staticmethod + def is_ddf_shaped(data): + """Check if the data is a DDF.""" + if isinstance(data, DisplacementField): + return True + if isinstance(data, Affine): + return False + if not hasattr(data, "shape") or len(data.shape) < 3: + return False + return not Affine.is_affine_shaped(data) + + +def combine_transforms(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor: + """Given transforms A and B to be applied to x, return the combined transform (AB), so that A(B(x)) becomes AB(x)""" + if Affine.is_affine_shaped(left) and Affine.is_affine_shaped(right): # linear transforms + left = convert_to_tensor(left.data if isinstance(left, Affine) else left, wrap_sequence=True) + right = convert_to_tensor(right.data if isinstance(right, Affine) else right, wrap_sequence=True) + return torch.matmul(left, right) + if DisplacementField.is_ddf_shaped(left) and DisplacementField.is_ddf_shaped( + right + ): # adds DDFs, do we need metadata if metatensor input? + left = convert_to_tensor(left.data if isinstance(left, DisplacementField) else left, wrap_sequence=True) + right = convert_to_tensor(right.data if isinstance(right, DisplacementField) else right, wrap_sequence=True) + return left + right + raise NotImplementedError + + +def matmul( + left: Affine | DisplacementField | NdarrayOrTensor, + right: Affine | DisplacementField | NdarrayOrTensor +): + matrix_types = (Affine, DisplacementField, torch.Tensor, np.ndarray) + + if not isinstance(left, matrix_types): + raise TypeError(f"'left' must be one of {matrix_types} but is {type(left)}") + if not isinstance(right, matrix_types): + raise TypeError(f"'second' must be one of {matrix_types} but is {type(right)}") + + left_ = left + right_ = right + + put_in_grid = isinstance(left, DisplacementField) or isinstance(right, DisplacementField) + + put_in_matrix = isinstance(left, Affine) or isinstance(right, Affine) + put_in_matrix = False if put_in_grid is True else put_in_matrix + + promote_to_tensor = not (isinstance(left_, np.ndarray) and isinstance(right_, np.ndarray)) + + left_raw = left_.data if isinstance(left_, (Affine, DisplacementField)) else left_ + right_raw = right_.data if isinstance(right_, (Affine, DisplacementField)) else right_ + + if promote_to_tensor: + left_raw = torch.as_tensor(left_raw) + right_raw = torch.as_tensor(right_raw) + + if isinstance(left_, DisplacementField): + if isinstance(right_, DisplacementField): + raise RuntimeError("Unable to matrix multiply two Grids") + else: + result = matmul_grid_matrix(left_raw, right_raw) + else: + if isinstance(right_, DisplacementField): + result = matmul_matrix_grid(left_raw, right_raw) + else: + result = matmul_matrix_matrix(left_raw, right_raw) + + if put_in_grid: + result = DisplacementField(result) + elif put_in_matrix: + result = Affine(result) + + return result + + +def matmul_matrix_grid( + left: NdarrayOrTensor, + right: NdarrayOrTensor +): + if not Affine.is_affine_shaped(left): + raise ValueError(f"'left' should be a 2D or 3D homogenous matrix but has shape {left.shape}") + + if not DisplacementField.is_ddf_shaped(right): + raise ValueError( + "'right' should be a 3D array with shape[0] == 2 or a " + f"4D array with shape[0] == 3 but has shape {right.shape}" + ) + + # flatten the grid to take advantage of torch batch matrix multiply + right_flat = right.reshape(right.shape[0], -1) + result_flat = left @ right_flat + # restore the grid shape + result = result_flat.reshape((-1,) + result_flat.shape[1:]) + return result + + +def matmul_grid_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor): + if not DisplacementField.is_ddf_shaped(left): + raise ValueError( + "'left' should be a 3D array with shape[0] == 2 or a " + f"4D array with shape[0] == 3 but has shape {left.shape}" + ) + + if not Affine.is_affine_shaped(right): + raise ValueError(f"'right' should be a 2D or 3D homogenous matrix but has shape {right.shape}") + + try: + inv_matrix = torch.inverse(right) + except RuntimeError: + # the matrix is not invertible, so we will have to perform a slow grid to matrix operation + return matmul_grid_matrix_slow(left, right) + + # invert the matrix and swap the arguments, taking advantage of + # matrix @ vector == vector_transposed @ matrix_inverse + return matmul_matrix_grid(inv_matrix, left) + + +def matmul_grid_matrix_slow(left: NdarrayOrTensor, right: NdarrayOrTensor): + if not DisplacementField.is_ddf_shaped(left): + raise ValueError( + "'left' should be a 3D array with shape[0] == 2 or a " + f"4D array with shape[0] == 3 but has shape {left.shape}" + ) + + if not Affine.is_affine_shaped(right): + raise ValueError(f"'right' should be a 2D or 3D homogenous matrix but has shape {right.shape}") + + flat_left = left.reshape(left.shape[0], -1) + result_flat = torch.zeros_like(flat_left) + for i in range(flat_left.shape[1]): + vector = flat_left[:, i][None, :] + result_vector = vector @ right + result_flat[:, i] = result_vector[0, :] + + # restore the grid shape + result = result_flat.reshape((-1,) + result_flat.shape[1:]) + return result + + +def matmul_matrix_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor): + return left @ right + + if __name__ == "__main__": print_transform_backends() diff --git a/tests/test_resample.py b/tests/test_resample.py index c90dc5f13d..af1b1ac102 100644 --- a/tests/test_resample.py +++ b/tests/test_resample.py @@ -16,7 +16,7 @@ import torch from parameterized import parameterized -from monai.transforms.lazy.functional import resample +from monai.transforms.lazy.functional import resample_image from monai.utils import convert_to_tensor from tests.utils import assert_allclose, get_arange_img @@ -37,12 +37,12 @@ def rotate_90_2d(): class TestResampleFunction(unittest.TestCase): @parameterized.expand(RESAMPLE_FUNCTION_CASES) def test_resample_function_impl(self, img, matrix, expected): - out = resample(convert_to_tensor(img), matrix, {"lazy_shape": img.shape[1:], "lazy_padding_mode": "border"}) + out = resample_image(convert_to_tensor(img), matrix, {"lazy_shape": img.shape[1:], "lazy_padding_mode": "border"}) assert_allclose(out[0], expected, type_test=False) img = convert_to_tensor(img, dtype=torch.uint8) - out = resample(img, matrix, {"lazy_resample_mode": "auto", "lazy_dtype": torch.float}) - out_1 = resample(img, matrix, {"lazy_resample_mode": "other value", "lazy_dtype": torch.float}) + out = resample_image(img, matrix, {"lazy_resample_mode": "auto", "lazy_dtype": torch.float}) + out_1 = resample_image(img, matrix, {"lazy_resample_mode": "other value", "lazy_dtype": torch.float}) self.assertIs(out.dtype, out_1.dtype) # testing dtype in different lazy_resample_mode