diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py index 42c5bdcf92..68201e44be 100644 --- a/monai/handlers/transform_inverter.py +++ b/monai/handlers/transform_inverter.py @@ -17,7 +17,7 @@ from monai.data import BatchInverseTransform from monai.data.utils import no_collation from monai.engines.utils import CommonKeys -from monai.transforms import InvertibleTransform, allow_missing_keys_mode +from monai.transforms import InvertibleTransform, allow_missing_keys_mode, convert_inverse_interp_mode from monai.utils import InverseKeys, exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") @@ -32,12 +32,6 @@ class TransformInverter: Ignite handler to automatically invert all the pre-transforms that support `inverse`. It takes `engine.state.output` as the input data and uses the transforms infomation from `engine.state.batch`. - Note: - This handler is experimental API in v0.5, the interpolation mode in the transforms - and inverse transforms are the same, so maybe it's not correct as we may want to use `bilinear` - for input image but use `nearest` when inverting transforms for model outout. - For this case, a solution is to set `batch_key` to the label field if we have labels. - """ def __init__( @@ -48,6 +42,7 @@ def __init__( batch_key: str = CommonKeys.IMAGE, output_key: str = CommonKeys.PRED, postfix: str = "inverted", + nearest_interp: bool = True, ) -> None: """ Args: @@ -59,6 +54,8 @@ def __init__( for this input data, then invert them for the model output, default to "image". output_key: the key of model output in `ignite.engine.output`, invert transforms on it. postfix: will save the inverted result into `ignite.engine.output` with key `{ouput_key}_{postfix}`. + nearest_interp: whether to use `nearest` interpolation mode when inverting spatial transforms, + default to `True`. if `False`, use the same interpolation mode as the original transform. 
""" self.transform = transform @@ -66,6 +63,7 @@ def __init__( self.batch_key = batch_key self.output_key = output_key self.postfix = postfix + self.nearest_interp = nearest_interp def attach(self, engine: Engine) -> None: """ @@ -84,9 +82,13 @@ def __call__(self, engine: Engine) -> None: warnings.warn("all the pre-transforms are not InvertibleTransform or no need to invert.") return + transform_info = engine.state.batch[transform_key] + if self.nearest_interp: + convert_inverse_interp_mode(trans_info=transform_info, mode="nearest", align_corners=None) + segs_dict = { self.batch_key: engine.state.output[self.output_key].detach().cpu(), - transform_key: engine.state.batch[transform_key], + transform_key: transform_info, } with allow_missing_keys_mode(self.transform): # type: ignore diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index b66567e71a..f96194c262 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -371,6 +371,7 @@ ) from .utils import ( allow_missing_keys_mode, + convert_inverse_interp_mode, copypaste_arrays, create_control_grid, create_grid, diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index c8d5ceea40..c4ef659c69 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -16,6 +16,7 @@ """ from copy import deepcopy +from enum import Enum from itertools import chain from math import floor from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union @@ -125,7 +126,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in self.key_iterator(d, self.mode): - self.push_transform(d, key) + self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d @@ -193,7 +194,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) 
-> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in self.key_iterator(d, self.mode): - self.push_transform(d, key) + self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d @@ -259,7 +260,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in self.key_iterator(d, self.mode): - self.push_transform(d, key) + self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d @@ -826,6 +827,7 @@ class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} One of the listed string values or a user supplied function for padding. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + It also can be a sequence of string, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. 
""" @@ -834,18 +836,26 @@ def __init__( self, keys: KeysCollection, spatial_size: Union[Sequence[int], int], - mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, mode=mode) + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.key_iterator(d): + for key, m in self.key_iterator(d, self.mode): orig_size = d[key].shape[1:] - d[key] = self.padcropper(d[key]) - self.push_transform(d, key, orig_size=orig_size) + d[key] = self.padcropper(d[key], mode=m) + self.push_transform( + d, + key, + orig_size=orig_size, + extra_info={ + "mode": m.value if isinstance(m, Enum) else m, + }, + ) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 3e5b68e8e4..3baef91717 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -76,7 +76,7 @@ def push_transform( info = { InverseKeys.CLASS_NAME: self.__class__.__name__, InverseKeys.ID: id(self), - InverseKeys.ORIG_SIZE: orig_size or data[key].shape[1:], + InverseKeys.ORIG_SIZE: orig_size or (data[key].shape[1:] if hasattr(data[key], "shape") else None), } if extra_info is not None: info[InverseKeys.EXTRA_INFO] = extra_info diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 86c94302a1..9f782bf8fc 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -16,6 +16,7 @@ """ from copy import deepcopy +from enum import Enum from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ 
-208,16 +209,24 @@ def __call__( align_corners=align_corners, dtype=dtype, ) - self.push_transform(d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine}) + self.push_transform( + d, + key, + extra_info={ + "meta_data_key": meta_data_key, + "old_affine": old_affine, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners if align_corners is not None else "none", + }, + ) # set the 'affine' key meta_data["affine"] = new_affine return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, padding_mode, align_corners, dtype in self.key_iterator( - d, self.mode, self.padding_mode, self.align_corners, self.dtype - ): + for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) if self.spacing_transform.diagonal: raise RuntimeError( @@ -227,6 +236,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Create inverse transform meta_data = d[transform[InverseKeys.EXTRA_INFO]["meta_data_key"]] old_affine = np.array(transform[InverseKeys.EXTRA_INFO]["old_affine"]) + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1] inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal) # Apply inverse @@ -235,7 +247,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar affine=meta_data["affine"], mode=mode, padding_mode=padding_mode, - align_corners=align_corners, + align_corners=False if align_corners == "none" else align_corners, dtype=dtype, ) meta_data["affine"] = new_affine @@ -483,17 +495,26 @@ def __init__( def __call__(self, data: 
Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): - self.push_transform(d, key) + self.push_transform( + d, + key, + extra_info={ + "mode": mode.value if isinstance(mode, Enum) else mode, + "align_corners": align_corners if align_corners is not None else "none", + }, + ) d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) orig_size = transform[InverseKeys.ORIG_SIZE] + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] # Create inverse transform - inverse_transform = Resize(orig_size, mode, align_corners) + inverse_transform = Resize(orig_size, mode, None if align_corners == "none" else align_corners) # Apply inverse transform d[key] = inverse_transform(d[key]) # Remove the applied transform @@ -573,17 +594,28 @@ def __call__( for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): orig_size = d[key].shape[1:] d[key], affine = self.affine(d[key], mode=mode, padding_mode=padding_mode) - self.push_transform(d, key, orig_size=orig_size, extra_info={"affine": affine}) + self.push_transform( + d, + key, + orig_size=orig_size, + extra_info={ + "affine": affine, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + }, + ) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + for key in self.key_iterator(d): transform = 
self.get_most_recent_transform(d, key) orig_size = transform[InverseKeys.ORIG_SIZE] # Create inverse transform fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) @@ -701,18 +733,28 @@ def __call__( affine = torch.as_tensor(np.eye(len(sp_size) + 1), device=self.rand_affine.rand_affine_grid.device) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - self.push_transform(d, key, extra_info={"affine": affine}) + self.push_transform( + d, + key, + extra_info={ + "affine": affine, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + }, + ) d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) orig_size = transform[InverseKeys.ORIG_SIZE] # Create inverse transform fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) @@ -1171,24 +1213,35 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda dtype=dtype, ) rot_mat = self.rotator.get_rotation_matrix() - self.push_transform(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) + self.push_transform( + d, + key, + orig_size=orig_size, + extra_info={ + "rot_mat": rot_mat, + "mode": mode.value if isinstance(mode, Enum) else mode, + 
"padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners if align_corners is not None else "none", + }, + ) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, padding_mode, align_corners, dtype in self.key_iterator( - d, self.mode, self.padding_mode, self.align_corners, self.dtype - ): + for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) # Create inverse transform fwd_rot_mat = transform[InverseKeys.EXTRA_INFO]["rot_mat"] + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( normalized=False, mode=mode, padding_mode=padding_mode, - align_corners=align_corners, + align_corners=False if align_corners == "none" else align_corners, reverse_indexing=True, ) output = xform( @@ -1283,10 +1336,6 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: self.randomize() d = dict(data) - if not self._do_transform: - for key in self.keys: - self.push_transform(d, key, extra_info={"rot_mat": np.eye(d[key].ndim)}) - return d angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) rotator = Rotate( angle=angle, @@ -1296,34 +1345,48 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d, self.mode, self.padding_mode, self.align_corners, self.dtype ): orig_size = d[key].shape[1:] - d[key] = rotator( - d[key], - mode=mode, - padding_mode=padding_mode, - align_corners=align_corners, - dtype=dtype, + if self._do_transform: + d[key] = rotator( + d[key], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, 
+ dtype=dtype, + ) + rot_mat = rotator.get_rotation_matrix() + else: + rot_mat = np.eye(d[key].ndim) + self.push_transform( + d, + key, + orig_size=orig_size, + extra_info={ + "rot_mat": rot_mat, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners if align_corners is not None else "none", + }, ) - rot_mat = rotator.get_rotation_matrix() - self.push_transform(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, padding_mode, align_corners, dtype in self.key_iterator( - d, self.mode, self.padding_mode, self.align_corners, self.dtype - ): + for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform[InverseKeys.DO_TRANSFORM]: # Create inverse transform fwd_rot_mat = transform[InverseKeys.EXTRA_INFO]["rot_mat"] + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( normalized=False, mode=mode, padding_mode=padding_mode, - align_corners=align_corners, + align_corners=False if align_corners == "none" else align_corners, reverse_indexing=True, ) output = xform( @@ -1384,7 +1447,15 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): - self.push_transform(d, key) + self.push_transform( + d, + key, + extra_info={ + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) 
else padding_mode, + "align_corners": align_corners if align_corners is not None else "none", + }, + ) d[key] = self.zoomer( d[key], mode=mode, @@ -1395,19 +1466,20 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, padding_mode, align_corners in self.key_iterator( - d, self.mode, self.padding_mode, self.align_corners - ): + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform zoom = np.array(self.zoomer.zoom) inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.zoomer.keep_size) + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] # Apply inverse d[key] = inverse_transform( d[key], mode=mode, padding_mode=padding_mode, - align_corners=align_corners, + align_corners=None if align_corners == "none" else align_corners, ) # Size might be out by 1 voxel so pad d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE])(d[key]) @@ -1496,7 +1568,16 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): - self.push_transform(d, key, extra_info={"zoom": self._zoom}) + self.push_transform( + d, + key, + extra_info={ + "zoom": self._zoom, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners if align_corners is not None else "none", + }, + ) if self._do_transform: d[key] = zoomer( d[key], @@ -1508,21 +1589,22 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda def inverse(self, data: Mapping[Hashable, np.ndarray]) -> 
Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, padding_mode, align_corners in self.key_iterator( - d, self.mode, self.padding_mode, self.align_corners - ): + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform[InverseKeys.DO_TRANSFORM]: # Create inverse transform zoom = np.array(transform[InverseKeys.EXTRA_INFO]["zoom"]) + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.keep_size) # Apply inverse d[key] = inverse_transform( d[key], mode=mode, padding_mode=padding_mode, - align_corners=align_corners, + align_corners=None if align_corners == "none" else align_corners, ) # Size might be out by 1 voxel so pad d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE])(d[key]) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 7c4ea398f6..67da9ceb35 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -406,7 +406,6 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: d = deepcopy(dict(data)) for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) # Create inverse transform inverse_transform = ToNumpy() # Apply inverse diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index eb1b194c96..b73a899153 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -22,8 +22,18 @@ from monai.networks.layers import GaussianFilter from monai.transforms.compose import Compose from monai.transforms.transform import MapTransform -from monai.utils import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, 
min_version, optional_import -from monai.utils.misc import issequenceiterable +from monai.utils import ( + GridSampleMode, + InterpolateMode, + InverseKeys, + ensure_tuple, + ensure_tuple_rep, + ensure_tuple_size, + fall_back_tuple, + issequenceiterable, + min_version, + optional_import, +) measure, _ = optional_import("skimage.measure", "0.14.2", min_version) @@ -53,6 +63,7 @@ "extreme_points_to_image", "map_spatial_axes", "allow_missing_keys_mode", + "convert_inverse_interp_mode", ] @@ -756,3 +767,34 @@ def allow_missing_keys_mode(transform: Union[MapTransform, Compose, Tuple[MapTra # Revert for t, o_s in zip(transforms, orig_states): t.allow_missing_keys = o_s + + +def convert_inverse_interp_mode(trans_info: List, mode: str = "nearest", align_corners: Optional[bool] = None): + """ + Change the interpolation mode when inverting spatial transforms, default to "nearest". + It can support both single data or batch data. + + Args: + trans_info: transforms inverse information list, contains context of every invertible transform. + mode: target interpolation mode to convert, default to "nearest" as it's usually used to save the model output. + align_corners: target align corner value in PyTorch interpolation API, need to align with the `mode`. 
+ + """ + interp_modes = [i.value for i in InterpolateMode] + [i.value for i in GridSampleMode] + + # set to string for DataLoader collation + align_corners_ = "none" if align_corners is None else align_corners + + for item in ensure_tuple(trans_info): + if InverseKeys.EXTRA_INFO in item: + orig_mode = item[InverseKeys.EXTRA_INFO].get("mode", None) + if orig_mode is not None: + if orig_mode[0] in interp_modes: + item[InverseKeys.EXTRA_INFO]["mode"] = [mode for _ in range(len(orig_mode))] + elif orig_mode in interp_modes: + item[InverseKeys.EXTRA_INFO]["mode"] = mode + if "align_corners" in item[InverseKeys.EXTRA_INFO]: + if issequenceiterable(item[InverseKeys.EXTRA_INFO]["align_corners"]): + item[InverseKeys.EXTRA_INFO]["align_corners"] = [align_corners_ for _ in range(len(item[InverseKeys.EXTRA_INFO]["align_corners"]))] + else: + item[InverseKeys.EXTRA_INFO]["align_corners"] = align_corners_ diff --git a/tests/test_handler_transform_inverter.py b/tests/test_handler_transform_inverter.py index 48efd5df53..87414319cf 100644 --- a/tests/test_handler_transform_inverter.py +++ b/tests/test_handler_transform_inverter.py @@ -20,6 +20,7 @@ from monai.handlers import TransformInverter from monai.transforms import ( AddChanneld, + CastToTyped, Compose, LoadImaged, RandAffined, @@ -29,8 +30,10 @@ RandRotated, RandZoomd, ResizeWithPadOrCropd, + ScaleIntensityd, ToTensord, ) +from monai.utils.misc import set_determinism from tests.utils import make_nifti_image KEYS = ["image", "label"] @@ -38,19 +41,22 @@ class TestTransformInverter(unittest.TestCase): def test_invert(self): - im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107)] + set_determinism(seed=0) + im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)] transform = Compose( [ LoadImaged(KEYS), AddChanneld(KEYS), + ScaleIntensityd(KEYS, minv=1, maxv=10), RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]), RandAxisFlipd(KEYS, prob=0.5), RandRotate90d(KEYS, spatial_axes=(1, 2)), 
RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), - RandRotated(KEYS, prob=0.5, range_x=np.pi), + RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True), RandAffined(KEYS, prob=0.5, rotate_range=np.pi), ResizeWithPadOrCropd(KEYS, 100), ToTensord(KEYS), + CastToTyped(KEYS, dtype=torch.uint8), ] ) data = [{"image": im_fname, "label": seg_fname} for _ in range(12)] @@ -69,11 +75,13 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - TransformInverter(transform=transform, loader=loader, output_key="image").attach(engine) + TransformInverter(transform=transform, loader=loader, output_key="image", nearest_interp=True).attach(engine) engine.run(loader, max_epochs=1) + set_determinism(seed=None) self.assertTupleEqual(engine.state.output["image"].shape, (2, 1, 100, 100, 100)) for i in engine.state.output["image_inverted"]: + np.testing.assert_allclose(i.astype(np.uint8).astype(np.float32), i, rtol=1e-4) self.assertTupleEqual(i.shape, (1, 100, 101, 107)) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index ccc4f366c2..358bf0176a 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -54,6 +54,7 @@ SpatialPadd, Zoomd, allow_missing_keys_mode, + convert_inverse_interp_mode, ) from monai.utils import first, get_seed, optional_import, set_determinism from monai.utils.enums import InverseKeys @@ -572,9 +573,11 @@ def test_inverse_inferred_seg(self): segs_dict = {"label": segs, label_transform_key: data[label_transform_key]} segs_dict_decollated = decollate_batch(segs_dict) - # inverse of individual segmentation seg_dict = first(segs_dict_decollated) + # test to convert interpolation mode for 1 data of model output batch + convert_inverse_interp_mode(seg_dict, mode="nearest", align_corners=None) + with allow_missing_keys_mode(transforms): inv_seg = transforms.inverse(seg_dict)["label"] self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)