From 0f937b0b1149288ba760ec3a89d74a5eab5db135 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Thu, 8 Apr 2021 18:30:01 +0800
Subject: [PATCH 01/14] [DLMED] add TransformInverter handler

Signed-off-by: Nic Ma
---
 docs/source/handlers.rst              |  5 ++
 monai/data/inverse_batch_transform.py |  5 +-
 monai/handlers/__init__.py            |  1 +
 monai/handlers/transform_inverter.py  | 82 +++++++++++++++++++++++++++
 4 files changed, 92 insertions(+), 1 deletion(-)
 create mode 100644 monai/handlers/transform_inverter.py

diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst
index 9030fa3ced..7c8498e37a 100644
--- a/docs/source/handlers.rst
+++ b/docs/source/handlers.rst
@@ -125,3 +125,8 @@ GarbageCollector handler
 ------------------------
 .. autoclass:: GarbageCollector
     :members:
+
+Transform inverter
+------------------
+.. autoclass:: TransformInverter
+    :members:
diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py
index fa88114c84..02be39ccab 100644
--- a/monai/data/inverse_batch_transform.py
+++ b/monai/data/inverse_batch_transform.py
@@ -10,7 +10,7 @@
 # limitations under the License.
 
 from typing import Any, Callable, Dict, Hashable, Optional, Sequence
-
+import warnings
 import numpy as np
 from torch.utils.data.dataloader import DataLoader as TorchDataLoader
 
@@ -42,6 +42,9 @@ def _transform(self, index: int) -> Dict[Hashable, np.ndarray]:
         if self.pad_collation_used:
             data = PadListDataCollate.inverse(data)
 
+        if not isinstance(self.invertible_transform, InvertibleTransform):
+            warnings.warn("transform is not invertible, can't invert transform for the input data.")
+            return data
         return self.invertible_transform.inverse(data)
 
 
diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py
index f88531ea8e..b0dbb82127 100644
--- a/monai/handlers/__init__.py
+++ b/monai/handlers/__init__.py
@@ -28,6 +28,7 @@
 from .stats_handler import StatsHandler
 from .surface_distance import SurfaceDistance
 from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler
+from .transform_inverter import TransformInverter
 from .utils import (
     evenly_divisible_all_gather,
     stopping_fn_from_loss,
diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py
new file mode 100644
index 0000000000..86c08872e7
--- /dev/null
+++ b/monai/handlers/transform_inverter.py
@@ -0,0 +1,82 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Callable, Optional
+import warnings
+from torch.utils.data import DataLoader as TorchDataLoader
+from monai.transforms import InvertibleTransform, allow_missing_keys_mode
+from monai.data import BatchInverseTransform
+from monai.utils import InverseKeys, exact_version, optional_import
+from monai.engines.utils import CommonKeys
+
+Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
+if TYPE_CHECKING:
+    from ignite.engine import Engine
+else:
+    Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")
+
+
+class TransformInverter:
+    """
+    Ignite handler to automatically invert all the pre-transforms that support `inverse`.
+    It takes `engine.state.output` as the input data and uses the transform information from `engine.state.batch`.
+
+    """
+    def __init__(
+        self,
+        transform: InvertibleTransform,
+        loader: TorchDataLoader,
+        collate_fn: Optional[Callable] = lambda x: x,
+        batch_key: str = CommonKeys.IMAGE,
+        output_key: str = CommonKeys.PRED,
+        postfix: str = "inverted",
+    ) -> None:
+        """
+        Args:
+            transform: a callable data transform on input data.
+            loader: data loader used to generate the batch of data.
+            collate_fn: how to collate data after inverse transformations.
+                default won't do any collation, so the output will be a list of size batch size.
+            batch_key: the key of input image in `ignite.engine.batch`. will get the applied transforms
+                for this input image, then invert them for the model output, default to "image".
+            output_key: the key of model output in `ignite.engine.output`, invert transforms on it.
+            postfix: will save the inverted result into `ignite.engine.output` with key `{output_key}_{postfix}`.
+
+        """
+        self.transform = transform
+        self.inverter = BatchInverseTransform(transform=transform, loader=loader, collate_fn=collate_fn)
+        self.batch_key = batch_key
+        self.output_key = output_key
+        self.postfix = postfix
+
+    def attach(self, engine: Engine) -> None:
+        """
+        Args:
+            engine: Ignite Engine, it can be a trainer, validator or evaluator.
+        """
+        engine.add_event_handler(Events.ITERATION_COMPLETED, self)
+
+    def __call__(self, engine: Engine) -> None:
+        """
+        Args:
+            engine: Ignite Engine, it can be a trainer, validator or evaluator.
+        """
+        transform_key = self.batch_key + InverseKeys.KEY_SUFFIX
+        if transform_key not in engine.state.batch:
+            warnings.warn("all the pre-transforms doesn't support inverse or no need to inverse.")
+            return
+
+        segs_dict = {
+            self.batch_key: engine.state.output[self.output_key].detach().cpu(),
+            transform_key: engine.state.batch[transform_key]}
+
+        with allow_missing_keys_mode(self.transform):
+            engine.state.output[f"{self.output_key}_{self.postfix}"] = self.inverter(segs_dict)

From 3ad8aad3061be6effae4ed91311e6e5bc10b576a Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Thu, 8 Apr 2021 18:33:56 +0800
Subject: [PATCH 02/14] [DLMED] fix typo

Signed-off-by: Nic Ma
---
 monai/handlers/transform_inverter.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py
index 86c08872e7..4bbb5669a2 100644
--- a/monai/handlers/transform_inverter.py
+++ b/monai/handlers/transform_inverter.py
@@ -71,7 +71,7 @@ def __call__(self, engine: Engine) -> None:
         """
         transform_key = self.batch_key + InverseKeys.KEY_SUFFIX
         if transform_key not in engine.state.batch:
-            warnings.warn("all the pre-transforms doesn't support inverse or no need to inverse.")
+            warnings.warn("all the pre-transforms are not InvertibleTransform or no need to invert.")
             return
 
         segs_dict = {

From 13e1617caea4eb38a2549e13068ff186bc6a212a Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Thu, 8 Apr 2021 18:56:59 +0800
Subject: [PATCH 03/14] [DLMED] add support in SegmentationSaver handler

Signed-off-by: Nic Ma
---
 monai/handlers/segmentation_saver.py | 12 +++++++++---
 monai/handlers/transform_inverter.py |  3 ++-
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py
index 6a98abf3ca..4baeac4e81 100644
--- a/monai/handlers/segmentation_saver.py
+++ b/monai/handlers/segmentation_saver.py
@@ -13,7 +13,6 @@
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import numpy as np
-
 from monai.config import DtypeLike
 from monai.transforms import SaveImage
 from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, exact_version, optional_import
@@ -119,7 +118,6 @@ def __init__(
             output_dtype=output_dtype,
             squeeze_end_dims=squeeze_end_dims,
             data_root_dir=data_root_dir,
-            save_batch=True,
         )
         self.batch_transform = batch_transform
         self.output_transform = output_transform
@@ -147,5 +145,13 @@ def __call__(self, engine: Engine) -> None:
         """
         meta_data = self.batch_transform(engine.state.batch)
         engine_output = self.output_transform(engine.state.output)
-        self._saver(engine_output, meta_data)
+        if isinstance(engine_output, (tuple, list)):
+            # if a list of data in shape: [channel, H, W, [D]], save every item separately
+            self._saver.save_batch = False
+            for i, d in enumerate(engine_output):
+                self._saver(d, {k: meta_data[k][i] for k in meta_data} if meta_data is not None else None)
+        else:
+            # if the data is in shape: [batch, channel, H, W, [D]]
+            self._saver.save_batch = True
+            self._saver(engine_output, meta_data)
         self.logger.info("saved all the model outputs into files.")
diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py
index 4bbb5669a2..16bcb386ba 100644
--- a/monai/handlers/transform_inverter.py
+++ b/monai/handlers/transform_inverter.py
@@ -79,4 +79,5 @@ def __call__(self, engine: Engine) -> None:
             transform_key: engine.state.batch[transform_key]}
 
         with allow_missing_keys_mode(self.transform):
-
engine.state.output[f"{self.output_key}_{self.postfix}"] = self.inverter(segs_dict) + inverted_key = f"{self.output_key}_{self.postfix}" + engine.state.output[inverted_key] = [i[self.batch_key] for i in self.inverter(segs_dict)] From 2d00fcf6bf0a2dbd2b74c3fbad44726905bb9210 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 8 Apr 2021 20:22:29 +0800 Subject: [PATCH 04/14] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- monai/data/inverse_batch_transform.py | 3 +- monai/handlers/segmentation_saver.py | 1 + monai/handlers/transform_inverter.py | 14 ++-- monai/transforms/utility/dictionary.py | 17 ++++- tests/test_handler_transform_inverter.py | 81 ++++++++++++++++++++++++ tests/test_inverse_collation.py | 3 +- 6 files changed, 111 insertions(+), 8 deletions(-) create mode 100644 tests/test_handler_transform_inverter.py diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 02be39ccab..a9f09b896d 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Hashable, Optional, Sequence import warnings +from typing import Any, Callable, Dict, Hashable, Optional, Sequence + import numpy as np from torch.utils.data.dataloader import DataLoader as TorchDataLoader diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index 4baeac4e81..279b514bd7 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union import numpy as np + from monai.config import DtypeLike from monai.transforms import SaveImage from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, exact_version, optional_import diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py index 16bcb386ba..cbb57609bd 100644 --- a/monai/handlers/transform_inverter.py +++ b/monai/handlers/transform_inverter.py @@ -9,13 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Optional import warnings +from typing import TYPE_CHECKING, Callable, Optional + from torch.utils.data import DataLoader as TorchDataLoader -from monai.transforms import InvertibleTransform, allow_missing_keys_mode + from monai.data import BatchInverseTransform -from monai.utils import InverseKeys, exact_version, optional_import from monai.engines.utils import CommonKeys +from monai.transforms import InvertibleTransform, allow_missing_keys_mode +from monai.utils import InverseKeys, exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: @@ -30,6 +32,7 @@ class TransformInverter: It takes `engine.state.output` as the input data and uses the transforms infomation from `engine.state.batch`. 
""" + def __init__( self, transform: InvertibleTransform, @@ -76,8 +79,9 @@ def __call__(self, engine: Engine) -> None: segs_dict = { self.batch_key: engine.state.output[self.output_key].detach().cpu(), - transform_key: engine.state.batch[transform_key]} + transform_key: engine.state.batch[transform_key], + } - with allow_missing_keys_mode(self.transform): + with allow_missing_keys_mode(self.transform): # type: ignore inverted_key = f"{self.output_key}_{self.postfix}" engine.state.output[inverted_key] = [i[self.batch_key] for i in self.inverter(segs_dict)] diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 9464faa503..f671070811 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -17,12 +17,14 @@ import copy import logging +from copy import deepcopy from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch from monai.config import DtypeLike, KeysCollection, NdarrayTensor +from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utility.array import ( AddChannel, @@ -379,7 +381,7 @@ def __call__( return d -class ToTensord(MapTransform): +class ToTensord(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.ToTensor`. """ @@ -397,9 +399,22 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: d = dict(data) for key in self.key_iterator(d): + self.push_transform(d, key) d[key] = self.converter(d[key]) return d + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, Any]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + inverse_transform = ToNumpy() + # Apply inverse + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + return d + class ToNumpyd(MapTransform): """ diff --git a/tests/test_handler_transform_inverter.py b/tests/test_handler_transform_inverter.py new file mode 100644 index 0000000000..48efd5df53 --- /dev/null +++ b/tests/test_handler_transform_inverter.py @@ -0,0 +1,81 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import unittest + +import numpy as np +import torch +from ignite.engine import Engine + +from monai.data import CacheDataset, DataLoader, create_test_image_3d +from monai.handlers import TransformInverter +from monai.transforms import ( + AddChanneld, + Compose, + LoadImaged, + RandAffined, + RandAxisFlipd, + RandFlipd, + RandRotate90d, + RandRotated, + RandZoomd, + ResizeWithPadOrCropd, + ToTensord, +) +from tests.utils import make_nifti_image + +KEYS = ["image", "label"] + + +class TestTransformInverter(unittest.TestCase): + def test_invert(self): + im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107)] + transform = Compose( + [ + LoadImaged(KEYS), + AddChanneld(KEYS), + RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]), + RandAxisFlipd(KEYS, prob=0.5), + RandRotate90d(KEYS, spatial_axes=(1, 2)), + RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), + RandRotated(KEYS, prob=0.5, range_x=np.pi), + RandAffined(KEYS, prob=0.5, rotate_range=np.pi), + ResizeWithPadOrCropd(KEYS, 100), + ToTensord(KEYS), + ] + ) + data = [{"image": im_fname, "label": seg_fname} for _ in range(12)] + + # num workers = 0 for mac or gpu transforms + num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 + + dataset = CacheDataset(data, transform=transform, progress=False) + loader = DataLoader(dataset, num_workers=num_workers, batch_size=5) + + # set up engine + def _train_func(engine, batch): + self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100)) + return batch + + engine = Engine(_train_func) + + # set up testing handler + TransformInverter(transform=transform, loader=loader, output_key="image").attach(engine) + + engine.run(loader, max_epochs=1) + self.assertTupleEqual(engine.state.output["image"].shape, (2, 1, 100, 100, 100)) + for i in engine.state.output["image_inverted"]: + self.assertTupleEqual(i.shape, (1, 100, 101, 107)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py index 3e07a8f0e2..c302e04017 100644 --- a/tests/test_inverse_collation.py +++ b/tests/test_inverse_collation.py @@ -29,6 +29,7 @@ RandRotated, RandZoomd, ResizeWithPadOrCropd, + ToTensord, ) from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image @@ -113,7 +114,7 @@ def test_collation(self, _, transform, collate_fn, ndim): if collate_fn: modified_transform = transform else: - modified_transform = Compose([transform, ResizeWithPadOrCropd(KEYS, 100)]) + modified_transform = Compose([transform, ResizeWithPadOrCropd(KEYS, 100), ToTensord(KEYS)]) # num workers = 0 for mac or gpu transforms num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 From 03032f01655f4423838fd278cf23886d570b8a93 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 8 Apr 2021 20:45:19 +0800 Subject: [PATCH 05/14] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- monai/transforms/utility/dictionary.py | 2 +- tests/min_tests.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index f671070811..7c4ea398f6 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -403,7 +403,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: d[key] = self.converter(d[key]) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, Any]: + def 
inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) diff --git a/tests/min_tests.py b/tests/min_tests.py index 98f6d822a7..d346afc7e3 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -116,6 +116,7 @@ def run_testsuit(): "test_ensure_channel_first", "test_ensure_channel_firstd", "test_handler_early_stop", + "test_handler_transform_inverter", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From f3413c5b5230995688c8679099cc02c9ef681ddf Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 8 Apr 2021 21:53:22 +0800 Subject: [PATCH 06/14] [DLMED] fix CI test Signed-off-by: Nic Ma --- monai/data/inverse_batch_transform.py | 5 +---- monai/data/utils.py | 8 ++++++++ monai/handlers/transform_inverter.py | 13 ++++++++++--- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index a9f09b896d..15106c2e07 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -17,6 +17,7 @@ from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset +from monai.data.utils import no_collation from monai.data.utils import decollate_batch, pad_list_data_collate from monai.transforms.croppad.batch import PadListDataCollate from monai.transforms.inverse import InvertibleTransform @@ -49,10 +50,6 @@ def _transform(self, index: int) -> Dict[Hashable, np.ndarray]: return self.invertible_transform.inverse(data) -def no_collation(x): - return x - - class BatchInverseTransform(Transform): """Perform inverse on a batch of data. This is useful if you have inferred a batch of images and want to invert them all.""" diff --git a/monai/data/utils.py b/monai/data/utils.py index 938365460b..d39f2702ff 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -65,6 +65,7 @@ "sorted_dict", "decollate_batch", "pad_list_data_collate", + "no_collation", ] @@ -379,6 +380,13 @@ def pad_list_data_collate( return PadListDataCollate(method, mode)(batch) +def no_collation(x): + """ + No any collation operation. + """ + return x + + def worker_init_fn(worker_id: int) -> None: """ Callback function for PyTorch DataLoader `worker_init_fn`. diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py index cbb57609bd..42c5bdcf92 100644 --- a/monai/handlers/transform_inverter.py +++ b/monai/handlers/transform_inverter.py @@ -15,6 +15,7 @@ from torch.utils.data import DataLoader as TorchDataLoader from monai.data import BatchInverseTransform +from monai.data.utils import no_collation from monai.engines.utils import CommonKeys from monai.transforms import InvertibleTransform, allow_missing_keys_mode from monai.utils import InverseKeys, exact_version, optional_import @@ -31,13 +32,19 @@ class TransformInverter: Ignite handler to automatically invert all the pre-transforms that support `inverse`. It takes `engine.state.output` as the input data and uses the transforms infomation from `engine.state.batch`. + Note: + This handler is experimental API in v0.5, the interpolation mode in the transforms + and inverse transforms are the same, so maybe it's not correct as we may want to use `bilinear` + for input image but use `nearest` when inverting transforms for model outout. + For this case, a solution is to set `batch_key` to the label field if we have labels. 
+ """ def __init__( self, transform: InvertibleTransform, loader: TorchDataLoader, - collate_fn: Optional[Callable] = lambda x: x, + collate_fn: Optional[Callable] = no_collation, batch_key: str = CommonKeys.IMAGE, output_key: str = CommonKeys.PRED, postfix: str = "inverted", @@ -48,8 +55,8 @@ def __init__( loader: data loader used to generate the batch of data. collate_fn: how to collate data after inverse transformations. default won't do any collation, so the output will be a list of size batch size. - batch_key: the key of input image in `ignite.engine.batch`. will get the applied transforms - for this input image, then invert them for the model output, default to "image". + batch_key: the key of input data in `ignite.engine.batch`. will get the applied transforms + for this input data, then invert them for the model output, default to "image". output_key: the key of model output in `ignite.engine.output`, invert transforms on it. postfix: will save the inverted result into `ignite.engine.output` with key `{ouput_key}_{postfix}`. From d108dd20325b434727a227f353d2b3b54fed5130 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Thu, 8 Apr 2021 13:58:10 +0000 Subject: [PATCH 07/14] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/data/inverse_batch_transform.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 15106c2e07..edfaee3758 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -17,8 +17,7 @@ from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset -from monai.data.utils import no_collation -from monai.data.utils import decollate_batch, pad_list_data_collate +from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate from monai.transforms.croppad.batch import PadListDataCollate from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import Transform From dbf3f6ab0c6c57eddeac1c08a36794af2b8bceff Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 9 Apr 2021 00:02:55 +0800 Subject: [PATCH 08/14] [DLMED] save mode into inverse dict Signed-off-by: Nic Ma --- monai/transforms/croppad/dictionary.py | 21 +++-- monai/transforms/spatial/dictionary.py | 103 ++++++++++++++++++------- 2 files changed, 88 insertions(+), 36 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index c8d5ceea40..5969dfb7db 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -16,6 +16,7 @@ """ from copy import deepcopy +from enum import Enum from itertools import chain from math import floor from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union @@ -125,7 +126,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in self.key_iterator(d, self.mode): - self.push_transform(d, key) + self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d @@ -193,7 +194,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in self.key_iterator(d, self.mode): - self.push_transform(d, key) + self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d @@ -259,7 
+260,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in self.key_iterator(d, self.mode): - self.push_transform(d, key) + self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d @@ -826,6 +827,7 @@ class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} One of the listed string values or a user supplied function for padding. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + It also can be a sequence of string, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. """ @@ -834,18 +836,21 @@ def __init__( self, keys: KeysCollection, spatial_size: Union[Sequence[int], int], - mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, mode=mode) + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.key_iterator(d): + for key, m in self.key_iterator(d, self.mode): orig_size = d[key].shape[1:] - d[key] = self.padcropper(d[key]) - self.push_transform(d, key, orig_size=orig_size) + d[key] = self.padcropper(d[key], mode=m) + self.push_transform(d, key, orig_size=orig_size, extra_info={ + "mode": m.value if isinstance(m, Enum) else m, + }) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 86c94302a1..53bbeae209 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -17,7 +17,7 @@ from copy import deepcopy from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union - +from enum import Enum import numpy as np import torch @@ -203,21 +203,25 @@ def __call__( d[key], old_affine, new_affine = self.spacing_transform( data_array=np.asarray(d[key]), affine=meta_data["affine"], - mode=mode, + mode=mode.value if isinstance(mode, Enum) else mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, ) - self.push_transform(d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine}) + self.push_transform(d, key, extra_info={ + "meta_data_key": meta_data_key, + "old_affine": old_affine, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners, + }) # set the 'affine' key meta_data["affine"] = new_affine return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, padding_mode, align_corners, dtype in self.key_iterator( - d, self.mode, self.padding_mode, self.align_corners, self.dtype - ): + for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) if self.spacing_transform.diagonal: raise RuntimeError( @@ -227,6 +231,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, 
np.ndar # Create inverse transform meta_data = d[transform[InverseKeys.EXTRA_INFO]["meta_data_key"]] old_affine = np.array(transform[InverseKeys.EXTRA_INFO]["old_affine"]) + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1] inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal) # Apply inverse @@ -483,15 +490,20 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): - self.push_transform(d, key) + self.push_transform(d, key, extra_info={ + "mode": mode.value if isinstance(mode, Enum) else mode, + "align_corners": align_corners, + }) d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) orig_size = transform[InverseKeys.ORIG_SIZE] + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] # Create inverse transform inverse_transform = Resize(orig_size, mode, align_corners) # Apply inverse transform @@ -573,17 +585,23 @@ def __call__( for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): orig_size = d[key].shape[1:] d[key], affine = self.affine(d[key], mode=mode, padding_mode=padding_mode) - self.push_transform(d, key, orig_size=orig_size, extra_info={"affine": affine}) + self.push_transform(d, key, orig_size=orig_size, extra_info={ + "affine": affine, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + }) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) orig_size = transform[InverseKeys.ORIG_SIZE] # Create inverse transform fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) @@ -701,18 +719,24 @@ def __call__( affine = torch.as_tensor(np.eye(len(sp_size) + 1), device=self.rand_affine.rand_affine_grid.device) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - self.push_transform(d, key, extra_info={"affine": affine}) + self.push_transform(d, key, extra_info={ + "affine": affine, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + }) d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + for key in self.key_iterator(d): transform = 
self.get_most_recent_transform(d, key) orig_size = transform[InverseKeys.ORIG_SIZE] # Create inverse transform fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) @@ -1171,17 +1195,23 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda dtype=dtype, ) rot_mat = self.rotator.get_rotation_matrix() - self.push_transform(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) + self.push_transform(d, key, orig_size=orig_size, extra_info={ + "rot_mat": rot_mat, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners, + }) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, padding_mode, align_corners, dtype in self.key_iterator( - d, self.mode, self.padding_mode, self.align_corners, self.dtype - ): + for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) # Create inverse transform fwd_rot_mat = transform[InverseKeys.EXTRA_INFO]["rot_mat"] + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( @@ -1304,19 +1334,25 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda dtype=dtype, ) rot_mat = rotator.get_rotation_matrix() - self.push_transform(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) + self.push_transform(d, key, orig_size=orig_size, extra_info={ + "rot_mat": rot_mat, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners, + }) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, padding_mode, align_corners, dtype in self.key_iterator( - d, self.mode, self.padding_mode, self.align_corners, self.dtype - ): + for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform[InverseKeys.DO_TRANSFORM]: # Create inverse transform fwd_rot_mat = transform[InverseKeys.EXTRA_INFO]["rot_mat"] + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( @@ -1384,7 +1420,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): - self.push_transform(d, key) + self.push_transform(d, key, extra_info={ + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners, + }) d[key] = self.zoomer( d[key], mode=mode, @@ -1395,13 +1435,14 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> 
Dict[Hashable, np.nda def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, padding_mode, align_corners in self.key_iterator( - d, self.mode, self.padding_mode, self.align_corners - ): + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform zoom = np.array(self.zoomer.zoom) inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.zoomer.keep_size) + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] # Apply inverse d[key] = inverse_transform( d[key], @@ -1496,7 +1537,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): - self.push_transform(d, key, extra_info={"zoom": self._zoom}) + self.push_transform(d, key, extra_info={ + "zoom": self._zoom, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners, + }) if self._do_transform: d[key] = zoomer( d[key], @@ -1508,14 +1554,15 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key, mode, padding_mode, align_corners in self.key_iterator( - d, self.mode, self.padding_mode, self.align_corners - ): + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform[InverseKeys.DO_TRANSFORM]: # Create inverse transform zoom = np.array(transform[InverseKeys.EXTRA_INFO]["zoom"]) + mode = transform[InverseKeys.EXTRA_INFO]["mode"] + padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.keep_size) # Apply inverse d[key] = inverse_transform( From fbd5775e5b62d1cce2b7990d4e3953fdb3be1f65 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 9 Apr 2021 02:36:42 +0800 Subject: [PATCH 09/14] [DLMED] add unit tests Signed-off-by: Nic Ma --- monai/handlers/transform_inverter.py | 25 ++-- monai/transforms/croppad/dictionary.py | 11 +- monai/transforms/spatial/dictionary.py | 161 ++++++++++++++--------- tests/test_handler_transform_inverter.py | 14 +- 4 files changed, 134 insertions(+), 77 deletions(-) diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py index 42c5bdcf92..d5d8fc6805 100644 --- a/monai/handlers/transform_inverter.py +++ b/monai/handlers/transform_inverter.py @@ -18,7 +18,7 @@ from monai.data.utils import no_collation from monai.engines.utils import CommonKeys from monai.transforms import InvertibleTransform, allow_missing_keys_mode -from monai.utils import InverseKeys, exact_version, optional_import +from monai.utils import GridSampleMode, InterpolateMode, InverseKeys, exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: @@ -32,12 +32,6 @@ class TransformInverter: Ignite handler to automatically invert all the pre-transforms that support `inverse`. 
It takes `engine.state.output` as the input data and uses the transforms infomation from `engine.state.batch`. - Note: - This handler is experimental API in v0.5, the interpolation mode in the transforms - and inverse transforms are the same, so maybe it's not correct as we may want to use `bilinear` - for input image but use `nearest` when inverting transforms for model outout. - For this case, a solution is to set `batch_key` to the label field if we have labels. - """ def __init__( @@ -48,6 +42,7 @@ def __init__( batch_key: str = CommonKeys.IMAGE, output_key: str = CommonKeys.PRED, postfix: str = "inverted", + nearest_interp: bool = True, ) -> None: """ Args: @@ -59,6 +54,8 @@ def __init__( for this input data, then invert them for the model output, default to "image". output_key: the key of model output in `ignite.engine.output`, invert transforms on it. postfix: will save the inverted result into `ignite.engine.output` with key `{ouput_key}_{postfix}`. + nearest_interp: whether to use `nearest` interpolation mode when inverting spatial transforms, + default to `True`. if `False`, use the same interpolation mode as the original transform. """ self.transform = transform @@ -66,6 +63,7 @@ def __init__( self.batch_key = batch_key self.output_key = output_key self.postfix = postfix + self.nearest_interp = nearest_interp def attach(self, engine: Engine) -> None: """ @@ -84,9 +82,20 @@ def __call__(self, engine: Engine) -> None: warnings.warn("all the pre-transforms are not InvertibleTransform or no need to invert.") return + transform_info = engine.state.batch[transform_key] + if self.nearest_interp: + interp_modes = [i.value for i in InterpolateMode] + [i.value for i in GridSampleMode] + for item in transform_info: + if InverseKeys.EXTRA_INFO in item: + mode = item[InverseKeys.EXTRA_INFO].get("mode", None) + if mode is not None and mode[0] in interp_modes: + item[InverseKeys.EXTRA_INFO]["mode"] = ["nearest" for _ in range(len(mode))] + if "align_corners" in item[InverseKeys.EXTRA_INFO]: + item[InverseKeys.EXTRA_INFO]["align_corners"] = ["none" for _ in range(len(mode))] + segs_dict = { self.batch_key: engine.state.output[self.output_key].detach().cpu(), - transform_key: engine.state.batch[transform_key], + transform_key: transform_info, } with allow_missing_keys_mode(self.transform): # type: ignore diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 5969dfb7db..c4ef659c69 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -848,9 +848,14 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key, m in self.key_iterator(d, self.mode): orig_size = d[key].shape[1:] d[key] = self.padcropper(d[key], mode=m) - self.push_transform(d, key, orig_size=orig_size, extra_info={ - "mode": m.value if isinstance(m, Enum) else m, - }) + self.push_transform( + d, + key, + orig_size=orig_size, + extra_info={ + "mode": m.value if isinstance(m, Enum) else m, + }, + ) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 53bbeae209..9f782bf8fc 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -16,8 +16,9 @@ """ from copy import deepcopy -from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union from enum import Enum +from typing import Any, Dict, Hashable, Mapping, Optional, 
Sequence, Tuple, Union + import numpy as np import torch @@ -203,18 +204,22 @@ def __call__( d[key], old_affine, new_affine = self.spacing_transform( data_array=np.asarray(d[key]), affine=meta_data["affine"], - mode=mode.value if isinstance(mode, Enum) else mode, + mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, ) - self.push_transform(d, key, extra_info={ - "meta_data_key": meta_data_key, - "old_affine": old_affine, - "mode": mode.value if isinstance(mode, Enum) else mode, - "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners, - }) + self.push_transform( + d, + key, + extra_info={ + "meta_data_key": meta_data_key, + "old_affine": old_affine, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners if align_corners is not None else "none", + }, + ) # set the 'affine' key meta_data["affine"] = new_affine return d @@ -242,7 +247,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar affine=meta_data["affine"], mode=mode, padding_mode=padding_mode, - align_corners=align_corners, + align_corners=False if align_corners == "none" else align_corners, dtype=dtype, ) meta_data["affine"] = new_affine @@ -490,10 +495,14 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): - self.push_transform(d, key, extra_info={ - "mode": mode.value if isinstance(mode, Enum) else mode, - "align_corners": align_corners, - }) + self.push_transform( + d, + key, + extra_info={ + "mode": mode.value if isinstance(mode, Enum) else mode, + "align_corners": align_corners if align_corners is not None else "none", + }, + ) d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners) return d @@ -505,7 +514,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar mode = transform[InverseKeys.EXTRA_INFO]["mode"] align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] # Create inverse transform - inverse_transform = Resize(orig_size, mode, align_corners) + inverse_transform = Resize(orig_size, mode, None if align_corners == "none" else align_corners) # Apply inverse transform d[key] = inverse_transform(d[key]) # Remove the applied transform @@ -585,11 +594,16 @@ def __call__( for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): orig_size = d[key].shape[1:] d[key], affine = self.affine(d[key], mode=mode, padding_mode=padding_mode) - self.push_transform(d, key, orig_size=orig_size, extra_info={ - "affine": affine, - "mode": mode.value if isinstance(mode, Enum) else mode, - "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - }) + self.push_transform( + d, + key, + orig_size=orig_size, + extra_info={ + "affine": affine, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + }, + ) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -719,11 +733,15 @@ def __call__( affine = torch.as_tensor(np.eye(len(sp_size) + 1), device=self.rand_affine.rand_affine_grid.device) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - self.push_transform(d, key, 
extra_info={ - "affine": affine, - "mode": mode.value if isinstance(mode, Enum) else mode, - "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - }) + self.push_transform( + d, + key, + extra_info={ + "affine": affine, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + }, + ) d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) return d @@ -1195,12 +1213,17 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda dtype=dtype, ) rot_mat = self.rotator.get_rotation_matrix() - self.push_transform(d, key, orig_size=orig_size, extra_info={ - "rot_mat": rot_mat, - "mode": mode.value if isinstance(mode, Enum) else mode, - "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners, - }) + self.push_transform( + d, + key, + orig_size=orig_size, + extra_info={ + "rot_mat": rot_mat, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners if align_corners is not None else "none", + }, + ) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -1218,7 +1241,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar normalized=False, mode=mode, padding_mode=padding_mode, - align_corners=align_corners, + align_corners=False if align_corners == "none" else align_corners, reverse_indexing=True, ) output = xform( @@ -1313,10 +1336,6 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: self.randomize() d = dict(data) - if not self._do_transform: - for key in self.keys: - self.push_transform(d, key, extra_info={"rot_mat": np.eye(d[key].ndim)}) - return d angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) rotator = Rotate( angle=angle, @@ -1326,20 +1345,28 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d, self.mode, self.padding_mode, self.align_corners, self.dtype ): orig_size = d[key].shape[1:] - d[key] = rotator( - d[key], - mode=mode, - padding_mode=padding_mode, - align_corners=align_corners, - dtype=dtype, + if self._do_transform: + d[key] = rotator( + d[key], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, + ) + rot_mat = rotator.get_rotation_matrix() + else: + rot_mat = np.eye(d[key].ndim) + self.push_transform( + d, + key, + orig_size=orig_size, + extra_info={ + "rot_mat": rot_mat, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners if align_corners is not None else "none", + }, ) - rot_mat = rotator.get_rotation_matrix() - self.push_transform(d, key, orig_size=orig_size, extra_info={ - "rot_mat": rot_mat, - "mode": mode.value if isinstance(mode, Enum) else mode, - "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners, - }) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -1359,7 +1386,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar normalized=False, 
mode=mode, padding_mode=padding_mode, - align_corners=align_corners, + align_corners=False if align_corners == "none" else align_corners, reverse_indexing=True, ) output = xform( @@ -1420,11 +1447,15 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): - self.push_transform(d, key, extra_info={ - "mode": mode.value if isinstance(mode, Enum) else mode, - "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners, - }) + self.push_transform( + d, + key, + extra_info={ + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners if align_corners is not None else "none", + }, + ) d[key] = self.zoomer( d[key], mode=mode, @@ -1448,7 +1479,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar d[key], mode=mode, padding_mode=padding_mode, - align_corners=align_corners, + align_corners=None if align_corners == "none" else align_corners, ) # Size might be out by 1 voxel so pad d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE])(d[key]) @@ -1537,12 +1568,16 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): - self.push_transform(d, key, extra_info={ - "zoom": self._zoom, - "mode": mode.value if isinstance(mode, Enum) else mode, - "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners, - }) + self.push_transform( + d, + key, + extra_info={ + "zoom": self._zoom, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners if align_corners is not None else "none", + }, + ) if self._do_transform: d[key] = zoomer( d[key], @@ -1569,7 +1604,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar d[key], mode=mode, padding_mode=padding_mode, - align_corners=align_corners, + align_corners=None if align_corners == "none" else align_corners, ) # Size might be out by 1 voxel so pad d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE])(d[key]) diff --git a/tests/test_handler_transform_inverter.py b/tests/test_handler_transform_inverter.py index 48efd5df53..87414319cf 100644 --- a/tests/test_handler_transform_inverter.py +++ b/tests/test_handler_transform_inverter.py @@ -20,6 +20,7 @@ from monai.handlers import TransformInverter from monai.transforms import ( AddChanneld, + CastToTyped, Compose, LoadImaged, RandAffined, @@ -29,8 +30,10 @@ RandRotated, RandZoomd, ResizeWithPadOrCropd, + ScaleIntensityd, ToTensord, ) +from monai.utils.misc import set_determinism from tests.utils import make_nifti_image KEYS = ["image", "label"] @@ -38,19 +41,22 @@ class TestTransformInverter(unittest.TestCase): def test_invert(self): - im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107)] + set_determinism(seed=0) + im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)] transform = Compose( [ LoadImaged(KEYS), AddChanneld(KEYS), + ScaleIntensityd(KEYS, minv=1, maxv=10), RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]), 
RandAxisFlipd(KEYS, prob=0.5), RandRotate90d(KEYS, spatial_axes=(1, 2)), RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), - RandRotated(KEYS, prob=0.5, range_x=np.pi), + RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True), RandAffined(KEYS, prob=0.5, rotate_range=np.pi), ResizeWithPadOrCropd(KEYS, 100), ToTensord(KEYS), + CastToTyped(KEYS, dtype=torch.uint8), ] ) data = [{"image": im_fname, "label": seg_fname} for _ in range(12)] @@ -69,11 +75,13 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - TransformInverter(transform=transform, loader=loader, output_key="image").attach(engine) + TransformInverter(transform=transform, loader=loader, output_key="image", nearest_interp=True).attach(engine) engine.run(loader, max_epochs=1) + set_determinism(seed=None) self.assertTupleEqual(engine.state.output["image"].shape, (2, 1, 100, 100, 100)) for i in engine.state.output["image_inverted"]: + np.testing.assert_allclose(i.astype(np.uint8).astype(np.float32), i, rtol=1e-4) self.assertTupleEqual(i.shape, (1, 100, 101, 107)) From 5d3b7048c01f6527d08cdcb8a467b0ae46a0afc7 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 9 Apr 2021 19:54:03 +0800 Subject: [PATCH 10/14] [DLMED] fix ToTensor inverse issue Signed-off-by: Nic Ma --- monai/transforms/inverse.py | 2 +- monai/transforms/utility/dictionary.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 3e5b68e8e4..3baef91717 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -76,7 +76,7 @@ def push_transform( info = { InverseKeys.CLASS_NAME: self.__class__.__name__, InverseKeys.ID: id(self), - InverseKeys.ORIG_SIZE: orig_size or data[key].shape[1:], + InverseKeys.ORIG_SIZE: orig_size or (data[key].shape[1:] if hasattr(data[key], "shape") else None), } if extra_info is not None: info[InverseKeys.EXTRA_INFO] = extra_info diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 7c4ea398f6..67da9ceb35 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -406,7 +406,6 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: d = deepcopy(dict(data)) for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) # Create inverse transform inverse_transform = ToNumpy() # Apply inverse From f4bc282c815ea89ec2c64494db2976f7ad7c3352 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 9 Apr 2021 20:30:06 +0800 Subject: [PATCH 11/14] [DLMED] change the replacement logic into util function Signed-off-by: Nic Ma --- monai/handlers/transform_inverter.py | 13 ++------ monai/transforms/__init__.py | 1 + monai/transforms/utils.py | 46 ++++++++++++++++++++++++++-- 3 files changed, 48 insertions(+), 12 deletions(-) diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py index d5d8fc6805..68201e44be 100644 --- a/monai/handlers/transform_inverter.py +++ b/monai/handlers/transform_inverter.py @@ -17,8 +17,8 @@ from monai.data import BatchInverseTransform from monai.data.utils import no_collation from monai.engines.utils import CommonKeys -from monai.transforms import InvertibleTransform, allow_missing_keys_mode -from monai.utils import GridSampleMode, InterpolateMode, InverseKeys, exact_version, optional_import +from monai.transforms import 
InvertibleTransform, allow_missing_keys_mode, convert_inverse_interp_mode +from monai.utils import InverseKeys, exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: @@ -84,14 +84,7 @@ def __call__(self, engine: Engine) -> None: transform_info = engine.state.batch[transform_key] if self.nearest_interp: - interp_modes = [i.value for i in InterpolateMode] + [i.value for i in GridSampleMode] - for item in transform_info: - if InverseKeys.EXTRA_INFO in item: - mode = item[InverseKeys.EXTRA_INFO].get("mode", None) - if mode is not None and mode[0] in interp_modes: - item[InverseKeys.EXTRA_INFO]["mode"] = ["nearest" for _ in range(len(mode))] - if "align_corners" in item[InverseKeys.EXTRA_INFO]: - item[InverseKeys.EXTRA_INFO]["align_corners"] = ["none" for _ in range(len(mode))] + convert_inverse_interp_mode(trans_info=transform_info, mode="nearest", align_corners=None) segs_dict = { self.batch_key: engine.state.output[self.output_key].detach().cpu(), diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index b66567e71a..195cf235db 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -395,4 +395,5 @@ resize_center, weighted_patch_samples, zero_margins, + convert_inverse_interp_mode, ) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index eb1b194c96..80c1e6e86a 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -22,8 +22,18 @@ from monai.networks.layers import GaussianFilter from monai.transforms.compose import Compose from monai.transforms.transform import MapTransform -from monai.utils import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, min_version, optional_import -from monai.utils.misc import issequenceiterable +from monai.utils import ( + ensure_tuple, + ensure_tuple_rep, + ensure_tuple_size, + fall_back_tuple, + min_version, + optional_import, + issequenceiterable, + GridSampleMode, + InterpolateMode, + InverseKeys, +) measure, _ = optional_import("skimage.measure", "0.14.2", min_version) @@ -53,6 +63,7 @@ "extreme_points_to_image", "map_spatial_axes", "allow_missing_keys_mode", + "convert_inverse_interp_mode", ] @@ -756,3 +767,34 @@ def allow_missing_keys_mode(transform: Union[MapTransform, Compose, Tuple[MapTra # Revert for t, o_s in zip(transforms, orig_states): t.allow_missing_keys = o_s + + +def convert_inverse_interp_mode(trans_info: List, mode: str = "nearest", align_corners: Optional[bool] = None): + """ + Change the interpolation mode when inverting spatial transforms, default to "nearest". + It can support both single data or batch data. + + Args: + trans_info: transforms inverse information list, contains context of every invertible transform. + mode: target interpolation mode to convert, default to "nearest" as it's usually used to save the mode output. + align_corners: target align corner value in PyTorch interpolation API, need to align with the `mode`. 
+ + """ + interp_modes = [i.value for i in InterpolateMode] + [i.value for i in GridSampleMode] + if align_corners is None: + # set to string for DataLoader collation + align_corners = "none" + + for item in trans_info: + if InverseKeys.EXTRA_INFO in item: + orig_mode = item[InverseKeys.EXTRA_INFO].get("mode", None) + if orig_mode is not None: + if orig_mode[0] in interp_modes: + item[InverseKeys.EXTRA_INFO]["mode"] = [mode for _ in range(len(mode))] + elif orig_mode in interp_modes: + item[InverseKeys.EXTRA_INFO]["mode"] = mode + if "align_corners" in item[InverseKeys.EXTRA_INFO]: + if issequenceiterable(item[InverseKeys.EXTRA_INFO]["align_corners"]): + item[InverseKeys.EXTRA_INFO]["align_corners"] = [align_corners for _ in range(len(mode))] + else: + item[InverseKeys.EXTRA_INFO]["align_corners"] = align_corners From 959329c242a90f939e2d9519e02f35e5ce79d631 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 9 Apr 2021 20:43:14 +0800 Subject: [PATCH 12/14] [DLMED] add more tests Signed-off-by: Nic Ma --- monai/transforms/__init__.py | 2 +- monai/transforms/utils.py | 10 +++++----- tests/test_inverse.py | 5 ++++- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 195cf235db..f96194c262 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -371,6 +371,7 @@ ) from .utils import ( allow_missing_keys_mode, + convert_inverse_interp_mode, copypaste_arrays, create_control_grid, create_grid, @@ -395,5 +396,4 @@ resize_center, weighted_patch_samples, zero_margins, - convert_inverse_interp_mode, ) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 80c1e6e86a..1bbebee972 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -23,16 +23,16 @@ from monai.transforms.compose import Compose from monai.transforms.transform import MapTransform from monai.utils import ( + GridSampleMode, + InterpolateMode, + InverseKeys, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, + issequenceiterable, min_version, optional_import, - issequenceiterable, - GridSampleMode, - InterpolateMode, - InverseKeys, ) measure, _ = optional_import("skimage.measure", "0.14.2", min_version) @@ -785,7 +785,7 @@ def convert_inverse_interp_mode(trans_info: List, mode: str = "nearest", align_c # set to string for DataLoader collation align_corners = "none" - for item in trans_info: + for item in ensure_tuple(trans_info): if InverseKeys.EXTRA_INFO in item: orig_mode = item[InverseKeys.EXTRA_INFO].get("mode", None) if orig_mode is not None: diff --git a/tests/test_inverse.py b/tests/test_inverse.py index ccc4f366c2..358bf0176a 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -54,6 +54,7 @@ SpatialPadd, Zoomd, allow_missing_keys_mode, + convert_inverse_interp_mode, ) from monai.utils import first, get_seed, optional_import, set_determinism from monai.utils.enums import InverseKeys @@ -572,9 +573,11 @@ def test_inverse_inferred_seg(self): segs_dict = {"label": segs, label_transform_key: data[label_transform_key]} segs_dict_decollated = decollate_batch(segs_dict) - # inverse of individual segmentation seg_dict = first(segs_dict_decollated) + # test to convert interpolation mode for 1 data of model output batch + convert_inverse_interp_mode(seg_dict, mode="nearest", align_corners=None) + with allow_missing_keys_mode(transforms): inv_seg = transforms.inverse(seg_dict)["label"] self.assertEqual(len(data["label_transforms"]), num_invertible_transforms) From 
76034b43682cff52c12e756e604fe6fcf8a33eda Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 9 Apr 2021 21:33:47 +0800 Subject: [PATCH 13/14] [DLMED] fix flake8 Signed-off-by: Nic Ma --- monai/transforms/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 1bbebee972..3c8daa69d1 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -781,9 +781,9 @@ def convert_inverse_interp_mode(trans_info: List, mode: str = "nearest", align_c """ interp_modes = [i.value for i in InterpolateMode] + [i.value for i in GridSampleMode] - if align_corners is None: - # set to string for DataLoader collation - align_corners = "none" + + # set to string for DataLoader collation + align_corners_ = "none" if align_corners is None else align_corners for item in ensure_tuple(trans_info): if InverseKeys.EXTRA_INFO in item: @@ -795,6 +795,6 @@ def convert_inverse_interp_mode(trans_info: List, mode: str = "nearest", align_c item[InverseKeys.EXTRA_INFO]["mode"] = mode if "align_corners" in item[InverseKeys.EXTRA_INFO]: if issequenceiterable(item[InverseKeys.EXTRA_INFO]["align_corners"]): - item[InverseKeys.EXTRA_INFO]["align_corners"] = [align_corners for _ in range(len(mode))] + item[InverseKeys.EXTRA_INFO]["align_corners"] = [align_corners_ for _ in range(len(mode))] else: - item[InverseKeys.EXTRA_INFO]["align_corners"] = align_corners + item[InverseKeys.EXTRA_INFO]["align_corners"] = align_corners_ From 8d78a21db74b6d639c7db55d0e52b1ba84ad030a Mon Sep 17 00:00:00 2001 From: monai-bot Date: Fri, 9 Apr 2021 13:41:02 +0000 Subject: [PATCH 14/14] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/transforms/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 3c8daa69d1..b73a899153 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -781,7 +781,7 @@ def convert_inverse_interp_mode(trans_info: List, mode: str = "nearest", align_c """ interp_modes = [i.value for i in InterpolateMode] + [i.value for i in GridSampleMode] - + # set to string for DataLoader collation align_corners_ = "none" if align_corners is None else align_corners
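
For readers following the last few patches, a minimal, self-contained sketch of what the new convert_inverse_interp_mode utility does to a recorded transform trace before inverse() is called. The trace entry below is hand-built for illustration only (roughly the shape push_transform stores for a spatial transform); the class name, id and sizes are placeholder assumptions, not values taken from the tests or the patches above.

from monai.transforms import convert_inverse_interp_mode
from monai.utils.enums import InverseKeys

# Hypothetical trace entry, shaped like what push_transform records for a spatial transform
# (class name, id and original size are placeholders for illustration only).
trans_info = [
    {
        InverseKeys.CLASS_NAME: "RandRotated",
        InverseKeys.ID: 0,
        InverseKeys.ORIG_SIZE: (100, 101, 107),
        InverseKeys.EXTRA_INFO: {"mode": "bilinear", "padding_mode": "border", "align_corners": True},
    }
]

# Force nearest-neighbour interpolation for the inverse pass, e.g. when the output being
# inverted is a discrete prediction that should not be smoothed by bilinear resampling.
convert_inverse_interp_mode(trans_info, mode="nearest", align_corners=None)

print(trans_info[0][InverseKeys.EXTRA_INFO]["mode"])           # nearest
print(trans_info[0][InverseKeys.EXTRA_INFO]["align_corners"])  # none

Note that align_corners is rewritten to the string "none" rather than the Python value None, mirroring the push_transform changes earlier in the series, so the stored value survives the default DataLoader collation.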