diff --git a/docs/source/data.rst b/docs/source/data.rst index 8071bb1585..66fadd549b 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -188,3 +188,7 @@ ThreadBuffer BatchInverseTransform ~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: monai.data.BatchInverseTransform + +TestTimeAugmentation +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: monai.data.TestTimeAugmentation diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 2001ccfc8f..adb27a608e 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -34,6 +34,7 @@ from .png_writer import write_png from .samplers import DistributedSampler, DistributedWeightedRandomSampler from .synthetic import create_test_image_2d, create_test_image_3d +from .test_time_augmentation import TestTimeAugmentation from .thread_buffer import ThreadBuffer, ThreadDataLoader from .utils import ( compute_importance_map, diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py new file mode 100644 index 0000000000..51b95adc58 --- /dev/null +++ b/monai/data/test_time_augmentation.py @@ -0,0 +1,178 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from monai.data.dataloader import DataLoader
from monai.data.dataset import Dataset
from monai.data.inverse_batch_transform import BatchInverseTransform
from monai.data.utils import list_data_collate, pad_list_data_collate
from monai.transforms.compose import Compose
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import RandomizableTransform
from monai.transforms.utils import allow_missing_keys_mode
from monai.utils.enums import CommonKeys, InverseKeys

__all__ = ["TestTimeAugmentation"]


class TestTimeAugmentation:
    """
    Class for performing test time augmentations. This will pass the same image through the network multiple times.

    The user passes transform(s) to be applied to each realisation, and provided that at least one of those transforms
    is random, the network's output will vary. Provided that inverse transformations exist for all supplied spatial
    transforms, the inverse can be applied to each realisation of the network's output. Once in the same spatial
    reference, the results can then be combined and metrics computed.

    Test time augmentations are a useful feature for computing network uncertainty, as well as observing the network's
    dependency on the applied random transforms.

    Reference:
        Wang et al.,
        Aleatoric uncertainty estimation with test-time augmentation for medical image segmentation with convolutional
        neural networks,
        https://doi.org/10.1016/j.neucom.2019.01.103

    Args:
        transform: transform (or composed) to be applied to each realisation. At least one transform must be of type
            `RandomizableTransform`. All random transforms must be of type `InvertibleTransform`.
        batch_size: number of realisations to infer at once.
        num_workers: how many subprocesses to use for data.
        inferrer_fn: function to use to perform inference.
        device: device on which to perform inference.
        image_key: key used to extract image from input dictionary.
        label_key: key used to extract label from input dictionary.
        return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` will
            return the full data. Dimensions will be same size as when passing a single image through `inferrer_fn`,
            with a leading dimension equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.

    Raises:
        RuntimeError: if no supplied transform is random, or if a random transform is not invertible.

    Example:
        .. code-block:: python

            transform = RandAffined(keys, ...)
            post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

            tt_aug = TestTimeAugmentation(
                transform, batch_size=5, num_workers=0, inferrer_fn=lambda x: post_trans(model(x)), device=device
            )
            mode, mean, std, vvc = tt_aug(test_data)
    """

    def __init__(
        self,
        transform: InvertibleTransform,
        batch_size: int,
        num_workers: int,
        inferrer_fn: Callable,
        device: Optional[Union[str, torch.device]] = "cuda" if torch.cuda.is_available() else "cpu",
        image_key=CommonKeys.IMAGE,
        label_key=CommonKeys.LABEL,
        return_full_data: bool = False,
    ) -> None:
        self.transform = transform
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.inferrer_fn = inferrer_fn
        self.device = device
        self.image_key = image_key
        self.label_key = label_key
        self.return_full_data = return_full_data

        # fail fast: the transform needs at least one random component, and every random
        # component must be invertible so the outputs can be mapped back to a common space
        self._check_transforms()

    def _check_transforms(self):
        """
        Validate the supplied transform(s).

        Raises:
            RuntimeError: if there is not at least one `RandomizableTransform`, or if any
                `RandomizableTransform` is not also an `InvertibleTransform`.
        """
        ts = self.transform.transforms if isinstance(self.transform, Compose) else [self.transform]
        randoms = [isinstance(t, RandomizableTransform) for t in ts]
        # check at least 1 random
        if not any(randoms):
            raise RuntimeError(
                "Requires a `Randomizable` transform or a `Compose` containing at least one `Randomizable` transform."
            )
        # check that every random transform is also invertible; iterate over the transforms
        # themselves so that the error message can name the offending transform class
        for t, is_random in zip(ts, randoms):
            if is_random and not isinstance(t, InvertibleTransform):
                raise RuntimeError(
                    f"All applied random transform(s) must be invertible. Problematic transform: {type(t).__name__}"
                )

    def __call__(
        self, data: Dict[str, Any], num_examples: int = 10
    ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]:
        """
        Args:
            data: dictionary data to be processed.
            num_examples: number of realisations to be processed and results combined. Must be a positive
                multiple of `batch_size`.

        Returns:
            - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are
              calculated across `num_examples` outputs at each voxel. The volume variation coefficient (VVC) is
              `std/mean` across the whole output, including `num_examples`. See original paper for clarification.
            - if `return_full_data==True`: data is returned as-is after applying the `inferrer_fn` and then
              concatenating across the first dimension containing `num_examples`. This allows the user to perform
              their own analysis if desired.

        Raises:
            ValueError: if `num_examples` is not positive or is not a multiple of `batch_size`.
        """
        # without at least one realisation there is nothing to concatenate further down
        if num_examples < 1:
            raise ValueError("num_examples should be a positive integer.")

        # check num examples is multiple of batch size
        if num_examples % self.batch_size != 0:
            raise ValueError("num_examples should be multiple of batch size.")

        d = dict(data)

        # repeat the same input `num_examples` times; the random transform(s) make each realisation differ
        data_in = [d] * num_examples
        ds = Dataset(data_in, self.transform)
        # pad when collating, as the random transforms may produce differently sized outputs
        dl = DataLoader(
            ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate
        )

        # key under which the transforms applied to the label are stored in each batch
        label_transform_key = self.label_key + InverseKeys.KEY_SUFFIX

        # create inverter
        inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate)

        outputs: List[np.ndarray] = []

        for batch_data in dl:

            batch_images = batch_data[self.image_key].to(self.device)

            # do model forward pass
            batch_output = self.inferrer_fn(batch_images)
            if isinstance(batch_output, torch.Tensor):
                batch_output = batch_output.detach().cpu()
            if isinstance(batch_output, np.ndarray):
                batch_output = torch.Tensor(batch_output)

            # pair the inferred batch with the transforms recorded for the label so they can be inverted
            inferred_dict = {self.label_key: batch_output, label_transform_key: batch_data[label_transform_key]}

            # do inverse transformation (allow missing keys as only inverting label)
            with allow_missing_keys_mode(self.transform):  # type: ignore
                inv_batch = inverter(inferred_dict)

            outputs.append(inv_batch[self.label_key])

        # stack every realisation along the first dimension
        output: np.ndarray = np.concatenate(outputs)

        if self.return_full_data:
            return output

        # calculate metrics
        # NOTE: `np.bincount` only accepts non-negative integers, so the mode assumes label-like output
        mode: np.ndarray = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=output.astype(np.int64))
        mean: np.ndarray = np.mean(output, axis=0)  # type: ignore
        std: np.ndarray = np.std(output, axis=0)  # type: ignore
        vvc: float = (np.std(output) / np.mean(output)).item()
        return mode, mean, std, vvc
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from functools import partial
from typing import TYPE_CHECKING

import numpy as np
import torch

from monai.data import CacheDataset, DataLoader, create_test_image_2d
from monai.data.test_time_augmentation import TestTimeAugmentation
from monai.data.utils import pad_list_data_collate
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.transforms import Activations, AddChanneld, AsDiscrete, Compose, CropForegroundd, DivisiblePadd, RandAffined
from monai.transforms.croppad.dictionary import SpatialPadd
from monai.transforms.spatial.dictionary import Rand2DElasticd, RandFlipd
from monai.utils import optional_import, set_determinism

if TYPE_CHECKING:
    import tqdm

    has_tqdm = True
else:
    tqdm, has_tqdm = optional_import("tqdm")

# show a progress bar over the training epochs when tqdm is installed, otherwise use a plain range
if has_tqdm:
    trange = partial(tqdm.trange, desc="training")
else:
    trange = range


class TestTestTimeAugmentation(unittest.TestCase):
    """Tests for `TestTimeAugmentation`: one end-to-end run plus constructor validation checks."""

    @staticmethod
    def get_data(num_examples, input_size):
        """Create synthetic 2D image/label dictionaries; a single dict is returned when `num_examples` is 1."""
        make_pair = partial(create_test_image_2d, *input_size, rad_max=7, num_seg_classes=1, num_objs=1)
        samples = [
            {"image": image, "label": seg} for image, seg in (make_pair() for _ in range(num_examples))
        ]
        if num_examples == 1:
            return samples[0]
        return samples

    def setUp(self) -> None:
        set_determinism(seed=0)

    def tearDown(self) -> None:
        set_determinism(None)

    def test_test_time_augmentation(self):
        """Briefly train a small UNet, then check the shapes and value ranges of the returned metrics."""
        input_size = (20, 20)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        keys = ["image", "label"]
        num_training_ims = 10
        train_data = self.get_data(num_training_ims, input_size)
        test_data = self.get_data(1, input_size)

        rand_affine = RandAffined(
            keys,
            prob=1.0,
            spatial_size=(30, 30),
            rotate_range=(np.pi / 3, np.pi / 3),
            translate_range=(3, 3),
            scale_range=((0.8, 1), (0.8, 1)),
            padding_mode="zeros",
            mode=("bilinear", "nearest"),
            as_tensor_output=False,
        )
        transforms = Compose(
            [
                AddChanneld(keys),
                rand_affine,
                CropForegroundd(keys, source_key="image"),
                DivisiblePadd(keys, 4),
            ]
        )

        train_ds = CacheDataset(train_data, transforms)
        # cropped outputs may differ in size, so pad within each batch so that they match
        train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)

        model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
        loss_function = DiceLoss(sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        num_epochs = 10
        for _ in trange(num_epochs):
            epoch_loss = 0

            for batch in train_loader:
                inputs = batch["image"].to(device)
                labels = batch["label"].to(device)
                optimizer.zero_grad()
                loss = loss_function(model(inputs), labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

            epoch_loss /= len(train_loader)

        post_trans = Compose(
            [
                Activations(sigmoid=True),
                AsDiscrete(threshold_values=True),
            ]
        )

        def infer(x):
            return post_trans(model(x))

        tt_aug = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=infer, device=device)
        mode, mean, std, vvc = tt_aug(test_data)

        expected_shape = (1,) + input_size
        self.assertEqual(mode.shape, expected_shape)
        self.assertEqual(mean.shape, expected_shape)
        self.assertTrue(all(np.unique(mode) == (0, 1)))
        self.assertEqual((mean.min(), mean.max()), (0.0, 1.0))
        self.assertEqual(std.shape, expected_shape)
        self.assertIsInstance(vvc, float)

    def test_fail_non_random(self):
        """A pipeline without any random transform must be rejected at construction."""
        transforms = Compose([AddChanneld("im"), SpatialPadd("im", 1)])
        with self.assertRaises(RuntimeError):
            TestTimeAugmentation(transforms, None, None, None)

    def test_fail_random_but_not_invertible(self):
        """A random transform that is not invertible must be rejected at construction."""
        transforms = Compose([AddChanneld("im"), Rand2DElasticd("im", None, None)])
        with self.assertRaises(RuntimeError):
            TestTimeAugmentation(transforms, None, None, None)

    def test_single_transform(self):
        """A single transform that is not wrapped in `Compose` is accepted and can be run."""
        flip = RandFlipd(["image", "label"])
        tta = TestTimeAugmentation(flip, batch_size=5, num_workers=0, inferrer_fn=lambda x: x)
        tta(self.get_data(1, (20, 20)))


if __name__ == "__main__":
    unittest.main()