From 0d24a6d9a3586ae34554536af44df7741249cf55 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 4 Mar 2021 15:50:01 +0000 Subject: [PATCH 01/14] tta Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/__init__.py | 1 + monai/data/test_time_augmentation.py | 116 +++++++++++++++++++++++ tests/test_testtimeaugmentation.py | 134 +++++++++++++++++++++++++++ 3 files changed, 251 insertions(+) create mode 100644 monai/data/test_time_augmentation.py create mode 100644 tests/test_testtimeaugmentation.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 3dd0a980ef..bc723340a0 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -31,6 +31,7 @@ from .png_saver import PNGSaver from .png_writer import write_png from .synthetic import create_test_image_2d, create_test_image_3d +from .test_time_augmentation import TestTimeAugmentation from .thread_buffer import ThreadBuffer from .utils import ( DistributedSampler, diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py new file mode 100644 index 0000000000..48cbc54843 --- /dev/null +++ b/monai/data/test_time_augmentation.py @@ -0,0 +1,116 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict + +import numpy as np +import torch + +from monai.data.dataloader import DataLoader +from monai.data.dataset import Dataset +from monai.data.inverse_batch_transform import BatchInverseTransform +from monai.data.utils import pad_list_data_collate +from monai.transforms.compose import Compose +from monai.transforms.inverse_transform import InvertibleTransform +from monai.transforms.transform import Randomizable + +__all__ = ["TestTimeAugmentation"] + + +def is_transform_rand(transform): + if not isinstance(transform, Compose): + return isinstance(transform, Randomizable) + # call recursively for each sub-transform + return any(is_transform_rand(t) for t in transform.transforms) + + +class TestTimeAugmentation: + def __init__( + self, + transform: InvertibleTransform, + batch_size, + num_workers, + inferrer_fn, + device, + ) -> None: + self.transform = transform + self.batch_size = batch_size + self.num_workers = num_workers + self.inferrer_fn = inferrer_fn + self.device = device + + # check that the transform has at least one random component + if not is_transform_rand(self.transform): + raise RuntimeError( + type(self).__name__ + + " requires a `Randomizable` transform or a" + + " `Compose` containing at least one `Randomizable` transform." 
+ ) + + def __call__( + self, data: Dict[str, Any], num_examples=10, image_key="image", label_key="label", return_full_data=False + ): + d = dict(data) + + # check num examples is multiple of batch size + if num_examples % self.batch_size != 0: + raise ValueError("num_examples should be multiple of batch size.") + + # generate batch of data of size == batch_size, dataset and dataloader + data_in = [d for _ in range(num_examples)] + ds = Dataset(data_in, self.transform) + dl = DataLoader(ds, self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) + + label_transform_key = label_key + "_transforms" + + # create inverter + inverter = BatchInverseTransform(self.transform, dl) + + outputs = [] + + for batch_data in dl: + + batch_images = batch_data[image_key].to(self.device) + + # do model forward pass + batch_output = self.inferrer_fn(batch_images) + if isinstance(batch_output, torch.Tensor): + batch_output = batch_output.detach().cpu() + if isinstance(batch_output, np.ndarray): + batch_output = torch.Tensor(batch_output) + + # check binary labels are extracted + if not all(torch.unique(batch_output.int()) == torch.Tensor([0, 1])): + raise RuntimeError( + "Test-time augmentation requires binary channels. If this is " + "not binary segmentation, then you should one-hot your output." + ) + + # create a dictionary containing the inferred batch and their transforms + inferred_dict = {label_key: batch_output, label_transform_key: batch_data[label_transform_key]} + + # do inverse transformation (only for the label key) + inv_batch = inverter(inferred_dict, label_key) + + # append + outputs.append(inv_batch) + + # calculate mean and standard deviation + output = np.concatenate(outputs) + + if return_full_data: + return output + + mode = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=output.astype(np.int64)) + mean = np.mean(output, axis=0) + std = np.std(output, axis=0) + vvc = np.std(output) / np.mean(output) + return mode, mean, std, vvc diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py new file mode 100644 index 0000000000..93027e1ac8 --- /dev/null +++ b/tests/test_testtimeaugmentation.py @@ -0,0 +1,134 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from functools import partial +from typing import TYPE_CHECKING + +import numpy as np +import torch +from torch._C import has_cuda + +from monai.data import CacheDataset, DataLoader, create_test_image_2d +from monai.data.test_time_augmentation import TestTimeAugmentation +from monai.data.utils import pad_list_data_collate +from monai.losses import DiceLoss +from monai.networks.nets import UNet +from monai.transforms import ( + Activations, + AddChanneld, + AsDiscrete, + Compose, + CropForegroundd, + DivisiblePadd, + KeepLargestConnectedComponent, + RandAffined, +) +from monai.transforms.croppad.dictionary import SpatialPadd +from monai.utils import optional_import, set_determinism + +if TYPE_CHECKING: + import tqdm + + has_tqdm = True +else: + tqdm, has_tqdm = optional_import("tqdm") + +trange = partial(tqdm.trange, desc="training") if has_tqdm else range + +set_determinism(seed=0) + + +class TestTestTimeAugmentation(unittest.TestCase): + def test_test_time_augmentation(self): + input_size = (20, 20) + device = "cuda" if has_cuda else "cpu" + num_training_ims = 10 + data = [] + custom_create_test_image_2d = partial( + create_test_image_2d, *input_size, rad_max=7, num_seg_classes=1, num_objs=1 + ) + keys = ["image", "label"] + + for _ in range(num_training_ims): + im, label = custom_create_test_image_2d() + data.append({"image": im, "label": label}) + + transforms = Compose( + [ + AddChanneld(keys), + RandAffined( + keys, + prob=1.0, + spatial_size=(30, 30), + rotate_range=(np.pi / 3, np.pi / 3), + translate_range=(3, 3), + scale_range=((0.8, 1), (0.8, 1)), + padding_mode="zeros", + mode=("bilinear", "nearest"), + as_tensor_output=False, + ), + CropForegroundd(keys, source_key="image"), + DivisiblePadd(keys, 4), + ] + ) + + train_ds = CacheDataset(data, transforms) + # output might be different size, so pad so that they match + train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate) + + model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device) + loss_function = DiceLoss(sigmoid=True) + optimizer = torch.optim.Adam(model.parameters(), 1e-3) + + num_epochs = 10 + for _ in trange(num_epochs): + epoch_loss = 0 + + for batch_data in train_loader: + inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_function(outputs, labels) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + + epoch_loss /= len(train_loader) + + image, label = custom_create_test_image_2d() + test_data = {"image": image, "label": label} + + post_trans = Compose( + [ + Activations(sigmoid=True), + AsDiscrete(threshold_values=True), + KeepLargestConnectedComponent(applied_labels=1), + ] + ) + + def inferrer_fn(x): + return post_trans(model(x)) + + tt_aug = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=inferrer_fn, device=device) + mean, std = tt_aug(test_data) + self.assertEqual(mean.shape, (1,) + input_size) + self.assertEqual((mean.min(), mean.max()), (0.0, 1.0)) + self.assertEqual(std.shape, (1,) + input_size) + + def test_fail_non_random(self): + transforms = Compose([AddChanneld("im"), SpatialPadd("im", 1)]) + with self.assertRaises(RuntimeError): + TestTimeAugmentation(transforms, None, None, None, None) + + +if __name__ == "__main__": + unittest.main() From d19c3a66d7d000fb0563e0f67067ece1ba287383 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 17 Mar 2021 16:49:15 +0000 Subject: 
[PATCH 02/14] update Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/test_time_augmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 48cbc54843..2db017f250 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -19,7 +19,7 @@ from monai.data.inverse_batch_transform import BatchInverseTransform from monai.data.utils import pad_list_data_collate from monai.transforms.compose import Compose -from monai.transforms.inverse_transform import InvertibleTransform +from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import Randomizable __all__ = ["TestTimeAugmentation"] From 0eaea7f4653bdb01e5ec978bae7cfd8822656c75 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 22 Mar 2021 10:24:46 +0000 Subject: [PATCH 03/14] tta code review Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/test_time_augmentation.py | 132 ++++++++++++++++++++------- tests/test_testtimeaugmentation.py | 21 ++++- 2 files changed, 118 insertions(+), 35 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 2db017f250..2ff34df20f 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -17,38 +17,95 @@ from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset from monai.data.inverse_batch_transform import BatchInverseTransform -from monai.data.utils import pad_list_data_collate +from monai.data.utils import list_data_collate, pad_list_data_collate +from monai.engines.utils import CommonKeys from monai.transforms.compose import Compose from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import Randomizable +from monai.transforms.transform import RandomizableTransform +from monai.transforms.utils import allow_missing_keys_mode +from monai.utils.enums import InverseKeys __all__ = ["TestTimeAugmentation"] -def is_transform_rand(transform): - if not isinstance(transform, Compose): - return isinstance(transform, Randomizable) - # call recursively for each sub-transform - return any(is_transform_rand(t) for t in transform.transforms) +def is_transform_rand_invertible(transform): + if isinstance(transform, Compose): + # call recursively for each sub-transform + return any(is_transform_rand_invertible(t) for t in transform.transforms) + is_random = isinstance(transform, RandomizableTransform) + is_invertible = isinstance(transform, InvertibleTransform) + if is_random and not is_invertible: + raise RuntimeError( + f"All applied random transform(s) must be invertible. Problematic transform: {type(transform).__name__}" + ) + return is_random class TestTimeAugmentation: + """ + Class for performing test time augmentations. This will pass the same image through the network multiple times. + + The user passes transform(s) to be applied to each realisation, and provided that at least one of those transforms + is random, the network's output will vary. 
Provided that inverse transformations exist for all supplied spatial + transforms, the inverse can be applied to each realisation of the network's output. Once in the same spatial + reference, the results can then be combined and metrics computed. + + Test time augmentations are a useful feature for computing network uncertainty, as well as observing the network's + dependency on the applied random transforms. + + Reference: + Wang et al., + Aleatoric uncertainty estimation with test-time augmentation for medical image segmentation with convolutional + neural networks, + https://doi.org/10.1016/j.neucom.2019.01.103 + + Args: + transform: transform (or composed) to be applied to each realisation. At least one transform must be of type + `RandomizableTransform`. All random transforms must be of type `InvertibleTransform`. + batch_size: number of realisations to infer at once. + num_workers: how many subprocesses to use for data. + inferrer_fn: function to use to perform inference. + device: device on which to perform inference. + image_key: key used to extract image from input dictionary. + label_key: key used to extract label from input dictionary. + return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` will return the + full data. Dimensions will be same size as when passing a single image through `inferrer_fn`, with a dimension appended + equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`. + + Example: + .. code-block:: python + + transform = RandAffined(keys, ...) + post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)]) + + tt_aug = TestTimeAugmentation( + transform, batch_size=5, num_workers=0, inferrer_fn=lambda x: post_trans(model(x)), device=device + ) + mode, mean, std, vvc = tt_aug(test_data) + """ + def __init__( self, transform: InvertibleTransform, - batch_size, - num_workers, - inferrer_fn, - device, + batch_size: int, + num_workers: int, + inferrer_fn: Callable, + device: Optional[Union[str, torch.device]] = "cuda" if torch.cuda.is_available() else "cpu", + image_key=CommonKeys.IMAGE, + label_key=CommonKeys.LABEL, + return_full_data: bool = False, ) -> None: self.transform = transform self.batch_size = batch_size self.num_workers = num_workers self.inferrer_fn = inferrer_fn self.device = device + self.image_key = image_key + self.label_key = label_key + self.return_full_data = return_full_data - # check that the transform has at least one random component - if not is_transform_rand(self.transform): + # check that the transform has at least one random component, and that all random transforms are invertible + if not is_transform_rand_invertible(self.transform): raise RuntimeError( type(self).__name__ + " requires a `Randomizable` transform or a" @@ -56,8 +113,20 @@ def __init__( ) def __call__( - self, data: Dict[str, Any], num_examples=10, image_key="image", label_key="label", return_full_data=False - ): + self, data: Dict[str, Any], num_examples: int = 10 + ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]: + """ + Args: + data: dictionary data to be processed. + num_examples: number of realisations to be processed and results combined. + + Returns: + - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are calculated across + `num_examples` outputs at each voxel. The volume variation coefficient (VVC) is `std/mean` across the whole output, + including `num_examples`. See original paper for clarification. 
+ - if `return_full_data==False`: data is returned as-is after applying the `inferrer_fn` and then concatenating across + the first dimension containing `num_examples`. This allows the user to perform their own analysis if desired. + """ d = dict(data) # check num examples is multiple of batch size @@ -65,20 +134,20 @@ def __call__( raise ValueError("num_examples should be multiple of batch size.") # generate batch of data of size == batch_size, dataset and dataloader - data_in = [d for _ in range(num_examples)] + data_in = [d] * num_examples ds = Dataset(data_in, self.transform) dl = DataLoader(ds, self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) - label_transform_key = label_key + "_transforms" + label_transform_key = self.label_key + InverseKeys.KEY_SUFFIX.value # create inverter - inverter = BatchInverseTransform(self.transform, dl) + inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate) - outputs = [] + outputs: List[np.ndarray] = [] for batch_data in dl: - batch_images = batch_data[image_key].to(self.device) + batch_images = batch_data[self.image_key].to(self.device) # do model forward pass batch_output = self.inferrer_fn(batch_images) @@ -95,22 +164,23 @@ def __call__( ) # create a dictionary containing the inferred batch and their transforms - inferred_dict = {label_key: batch_output, label_transform_key: batch_data[label_transform_key]} + inferred_dict = {self.label_key: batch_output, label_transform_key: batch_data[label_transform_key]} - # do inverse transformation (only for the label key) - inv_batch = inverter(inferred_dict, label_key) + # do inverse transformation (allow missing keys as only inverting label) + with allow_missing_keys_mode(self.transform): # type: ignore + inv_batch = inverter(inferred_dict) # append - outputs.append(inv_batch) + outputs.append(inv_batch[self.label_key]) # calculate mean and standard deviation - output = np.concatenate(outputs) + output: np.ndarray = np.concatenate(outputs) - if return_full_data: + if self.return_full_data: return output - mode = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=output.astype(np.int64)) - mean = np.mean(output, axis=0) - std = np.std(output, axis=0) - vvc = np.std(output) / np.mean(output) + mode: np.ndarray = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=output.astype(np.int64)) + mean: np.ndarray = np.mean(output, axis=0) # type: ignore + std: np.ndarray = np.std(output, axis=0) # type: ignore + vvc: float = (np.std(output) / np.mean(output)).item() return mode, mean, std, vvc diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index 93027e1ac8..8d4b7b8530 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -33,6 +33,7 @@ RandAffined, ) from monai.transforms.croppad.dictionary import SpatialPadd +from monai.transforms.spatial.dictionary import Rand2DElasticd from monai.utils import optional_import, set_determinism if TYPE_CHECKING: @@ -44,10 +45,14 @@ trange = partial(tqdm.trange, desc="training") if has_tqdm else range -set_determinism(seed=0) - class TestTestTimeAugmentation(unittest.TestCase): + def setUp(self) -> None: + set_determinism(seed=0) + + def tearDown(self) -> None: + set_determinism(None) + def test_test_time_augmentation(self): input_size = (20, 20) device = "cuda" if has_cuda else "cpu" @@ -119,15 +124,23 @@ def inferrer_fn(x): return post_trans(model(x)) tt_aug = TestTimeAugmentation(transforms, batch_size=5, 
num_workers=0, inferrer_fn=inferrer_fn, device=device) - mean, std = tt_aug(test_data) + mode, mean, std, vvc = tt_aug(test_data) + self.assertEqual(mode.shape, (1,) + input_size) self.assertEqual(mean.shape, (1,) + input_size) + self.assertTrue(all(np.unique(mode) == (0, 1))) self.assertEqual((mean.min(), mean.max()), (0.0, 1.0)) self.assertEqual(std.shape, (1,) + input_size) + self.assertIsInstance(vvc, float) def test_fail_non_random(self): transforms = Compose([AddChanneld("im"), SpatialPadd("im", 1)]) with self.assertRaises(RuntimeError): - TestTimeAugmentation(transforms, None, None, None, None) + TestTimeAugmentation(transforms, None, None, None) + + def test_fail_random_but_not_invertible(self): + transforms = Compose([AddChanneld("im"), Rand2DElasticd("im", None, None)]) + with self.assertRaises(RuntimeError): + TestTimeAugmentation(transforms, None, None, None) if __name__ == "__main__": From 9a1277a92994ef6abbb442b9fe3ce28af1569b25 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 22 Mar 2021 10:34:04 +0000 Subject: [PATCH 04/14] remove segmentation requirement Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/test_time_augmentation.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 2ff34df20f..8f2c6c908c 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -156,13 +156,6 @@ def __call__( if isinstance(batch_output, np.ndarray): batch_output = torch.Tensor(batch_output) - # check binary labels are extracted - if not all(torch.unique(batch_output.int()) == torch.Tensor([0, 1])): - raise RuntimeError( - "Test-time augmentation requires binary channels. If this is " - "not binary segmentation, then you should one-hot your output." - ) - # create a dictionary containing the inferred batch and their transforms inferred_dict = {self.label_key: batch_output, label_transform_key: batch_data[label_transform_key]} @@ -173,12 +166,13 @@ def __call__( # append outputs.append(inv_batch[self.label_key]) - # calculate mean and standard deviation + # output output: np.ndarray = np.concatenate(outputs) if self.return_full_data: return output + # calculate metrics mode: np.ndarray = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=output.astype(np.int64)) mean: np.ndarray = np.mean(output, axis=0) # type: ignore std: np.ndarray = np.std(output, axis=0) # type: ignore From eb8f57dedf92a4fb6aa913402504b28a2b71ef3b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 22 Mar 2021 10:40:26 +0000 Subject: [PATCH 05/14] add to data.rst Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- docs/source/data.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/data.rst b/docs/source/data.rst index 8071bb1585..66fadd549b 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -188,3 +188,7 @@ ThreadBuffer BatchInverseTransform ~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: monai.data.BatchInverseTransform + +TestTimeAugmentation +~~~~~~~~~~~~~~~~~~~~ +.. 
autoclass:: monai.data.TestTimeAugmentation From 8d79231af071a82ca1f2aad3acd278b1fb0098d5 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 22 Mar 2021 11:19:17 +0000 Subject: [PATCH 06/14] update w master Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/test_time_augmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 8f2c6c908c..faa0f409ae 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -138,7 +138,7 @@ def __call__( ds = Dataset(data_in, self.transform) dl = DataLoader(ds, self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) - label_transform_key = self.label_key + InverseKeys.KEY_SUFFIX.value + label_transform_key = self.label_key + InverseKeys.KEY_SUFFIX # create inverter inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate) From fec04fbf1a8cad06a82358bf1875b29f1542b977 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 22 Mar 2021 17:59:39 +0000 Subject: [PATCH 07/14] explicit cast Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/test_time_augmentation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index faa0f409ae..5980205e0e 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -30,8 +30,8 @@ def is_transform_rand_invertible(transform): if isinstance(transform, Compose): - # call recursively for each sub-transform - return any(is_transform_rand_invertible(t) for t in transform.transforms) + # call recursively for each sub-transform. 
cast to Compose shouldn't be necessary + return any(is_transform_rand_invertible(t) for t in Compose(transform).transforms) is_random = isinstance(transform, RandomizableTransform) is_invertible = isinstance(transform, InvertibleTransform) if is_random and not is_invertible: From 6b488c136fa2e9fab578252ad44abe0da807b079 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 22 Mar 2021 20:14:15 +0000 Subject: [PATCH 08/14] update for CommonKeys Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/test_time_augmentation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 5980205e0e..f7c14e4a32 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -18,12 +18,11 @@ from monai.data.dataset import Dataset from monai.data.inverse_batch_transform import BatchInverseTransform from monai.data.utils import list_data_collate, pad_list_data_collate -from monai.engines.utils import CommonKeys from monai.transforms.compose import Compose from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import RandomizableTransform from monai.transforms.utils import allow_missing_keys_mode -from monai.utils.enums import InverseKeys +from monai.utils.enums import CommonKeys, InverseKeys __all__ = ["TestTimeAugmentation"] From e02a67ce4f94a81027d7006976aea2a5601e92d4 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 22 Mar 2021 21:42:31 +0000 Subject: [PATCH 09/14] torch cuda is available Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_testtimeaugmentation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index 8d4b7b8530..18a7348902 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -15,7 +15,6 @@ import numpy as np import torch -from torch._C import has_cuda from monai.data import CacheDataset, DataLoader, create_test_image_2d from monai.data.test_time_augmentation import TestTimeAugmentation @@ -55,7 +54,7 @@ def tearDown(self) -> None: def test_test_time_augmentation(self): input_size = (20, 20) - device = "cuda" if has_cuda else "cpu" + device = "cuda" if torch.cuda.is_available() else "cpu" num_training_ims = 10 data = [] custom_create_test_image_2d = partial( From 7cdf68269b0c9844f4a03e4ade16e9b5a8acafd2 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 23 Mar 2021 10:15:30 +0000 Subject: [PATCH 10/14] correct transform check Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/test_time_augmentation.py | 33 ++++++++++++++-------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index f7c14e4a32..4df2046bfa 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -27,19 +27,6 @@ __all__ = ["TestTimeAugmentation"] -def is_transform_rand_invertible(transform): - if isinstance(transform, Compose): - # call recursively for each sub-transform. 
cast to Compose shouldn't be necessary - return any(is_transform_rand_invertible(t) for t in Compose(transform).transforms) - is_random = isinstance(transform, RandomizableTransform) - is_invertible = isinstance(transform, InvertibleTransform) - if is_random and not is_invertible: - raise RuntimeError( - f"All applied random transform(s) must be invertible. Problematic transform: {type(transform).__name__}" - ) - return is_random - - class TestTimeAugmentation: """ Class for performing test time augmentations. This will pass the same image through the network multiple times. @@ -104,12 +91,24 @@ def __init__( self.return_full_data = return_full_data # check that the transform has at least one random component, and that all random transforms are invertible - if not is_transform_rand_invertible(self.transform): + self.check_transforms() + + def check_transforms(self): + """Should be at least 1 random transform, and all random transforms should be invertible.""" + ts = self.transform if not isinstance(self.transform, Compose) else self.transform.transforms + randoms = np.array([isinstance(t, RandomizableTransform) for t in ts]) + invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts]) + # check at least 1 random + if sum(randoms) == 0: raise RuntimeError( - type(self).__name__ - + " requires a `Randomizable` transform or a" - + " `Compose` containing at least one `Randomizable` transform." + "Requires a `Randomizable` transform or a `Compose` containing at least one `Randomizable` transform." ) + # check that whenever randoms is True, invertibles is also true + for r, i in zip(randoms, invertibles): + if r and not i: + raise RuntimeError( + f"All applied random transform(s) must be invertible. Problematic transform: {type(r).__name__}" + ) def __call__( self, data: Dict[str, Any], num_examples: int = 10 From 61051c374cfbd95d1c55b73d1a0495e244788804 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 23 Mar 2021 11:34:09 +0000 Subject: [PATCH 11/14] add single transform test, remove KeepLargestConnectedComponent Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/test_time_augmentation.py | 6 ++--- tests/test_testtimeaugmentation.py | 38 ++++++++++++++++------------ 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 4df2046bfa..51b95adc58 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -91,11 +91,11 @@ def __init__( self.return_full_data = return_full_data # check that the transform has at least one random component, and that all random transforms are invertible - self.check_transforms() + self._check_transforms() - def check_transforms(self): + def _check_transforms(self): """Should be at least 1 random transform, and all random transforms should be invertible.""" - ts = self.transform if not isinstance(self.transform, Compose) else self.transform.transforms + ts = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms randoms = np.array([isinstance(t, RandomizableTransform) for t in ts]) invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts]) # check at least 1 random diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index 18a7348902..87df28460d 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -28,11 +28,10 @@ Compose, 
CropForegroundd, DivisiblePadd, - KeepLargestConnectedComponent, RandAffined, ) from monai.transforms.croppad.dictionary import SpatialPadd -from monai.transforms.spatial.dictionary import Rand2DElasticd +from monai.transforms.spatial.dictionary import Rand2DElasticd, RandFlipd from monai.utils import optional_import, set_determinism if TYPE_CHECKING: @@ -46,6 +45,17 @@ class TestTestTimeAugmentation(unittest.TestCase): + @staticmethod + def get_data(num_examples, input_size): + custom_create_test_image_2d = partial( + create_test_image_2d, *input_size, rad_max=7, num_seg_classes=1, num_objs=1 + ) + data = [] + for _ in range(num_examples): + im, label = custom_create_test_image_2d() + data.append({"image": im, "label": label}) + return data[0] if num_examples == 1 else data + def setUp(self) -> None: set_determinism(seed=0) @@ -55,16 +65,10 @@ def tearDown(self) -> None: def test_test_time_augmentation(self): input_size = (20, 20) device = "cuda" if torch.cuda.is_available() else "cpu" - num_training_ims = 10 - data = [] - custom_create_test_image_2d = partial( - create_test_image_2d, *input_size, rad_max=7, num_seg_classes=1, num_objs=1 - ) keys = ["image", "label"] - - for _ in range(num_training_ims): - im, label = custom_create_test_image_2d() - data.append({"image": im, "label": label}) + num_training_ims = 10 + train_data = self.get_data(num_training_ims, input_size) + test_data = self.get_data(1, input_size) transforms = Compose( [ @@ -85,7 +89,7 @@ def test_test_time_augmentation(self): ] ) - train_ds = CacheDataset(data, transforms) + train_ds = CacheDataset(train_data, transforms) # output might be different size, so pad so that they match train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate) @@ -108,14 +112,10 @@ def test_test_time_augmentation(self): epoch_loss /= len(train_loader) - image, label = custom_create_test_image_2d() - test_data = {"image": image, "label": label} - post_trans = Compose( [ Activations(sigmoid=True), AsDiscrete(threshold_values=True), - KeepLargestConnectedComponent(applied_labels=1), ] ) @@ -141,6 +141,12 @@ def test_fail_random_but_not_invertible(self): with self.assertRaises(RuntimeError): TestTimeAugmentation(transforms, None, None, None) + def test_single_transform(self): + transforms = RandFlipd(["image", "label"]) + tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x) + tta(self.get_data(1, (20, 20))) + + if __name__ == "__main__": unittest.main() From 5c430ab5aed4ad83bf4b6149068ef4d6a96d5cbf Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 23 Mar 2021 11:35:58 +0000 Subject: [PATCH 12/14] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_testtimeaugmentation.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index 87df28460d..bee1aa4b0d 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -21,15 +21,7 @@ from monai.data.utils import pad_list_data_collate from monai.losses import DiceLoss from monai.networks.nets import UNet -from monai.transforms import ( - Activations, - AddChanneld, - AsDiscrete, - Compose, - CropForegroundd, - DivisiblePadd, - RandAffined, -) +from monai.transforms import Activations, AddChanneld, AsDiscrete, Compose, CropForegroundd, DivisiblePadd, RandAffined from monai.transforms.croppad.dictionary import 
SpatialPadd from monai.transforms.spatial.dictionary import Rand2DElasticd, RandFlipd from monai.utils import optional_import, set_determinism @@ -147,6 +139,5 @@ def test_single_transform(self): tta(self.get_data(1, (20, 20))) - if __name__ == "__main__": unittest.main() From bb2e1419eee2b2270954d2eac8cced8208dfbdb1 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 23 Mar 2021 14:08:39 +0000 Subject: [PATCH 13/14] convert error to warning Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/test_time_augmentation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 51b95adc58..94285f5e7a 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -100,13 +101,13 @@ def _check_transforms(self): invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts]) # check at least 1 random if sum(randoms) == 0: - raise RuntimeError( + warnings.warn( "Requires a `Randomizable` transform or a `Compose` containing at least one `Randomizable` transform." ) # check that whenever randoms is True, invertibles is also true for r, i in zip(randoms, invertibles): if r and not i: - raise RuntimeError( + warnings.warn( f"All applied random transform(s) must be invertible. Problematic transform: {type(r).__name__}" ) From f1b92a7ba98dd0b227ceff571a0e9a15500774af Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 23 Mar 2021 14:08:39 +0000 Subject: [PATCH 14/14] Revert "convert error to warning" This reverts commit bb2e1419eee2b2270954d2eac8cced8208dfbdb1. Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/test_time_augmentation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 94285f5e7a..51b95adc58 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -101,13 +100,13 @@ def _check_transforms(self): invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts]) # check at least 1 random if sum(randoms) == 0: - warnings.warn( + raise RuntimeError( "Requires a `Randomizable` transform or a `Compose` containing at least one `Randomizable` transform." ) # check that whenever randoms is True, invertibles is also true for r, i in zip(randoms, invertibles): if r and not i: - warnings.warn( + raise RuntimeError( f"All applied random transform(s) must be invertible. Problematic transform: {type(r).__name__}" )
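
Note (illustrative only): the short sketch below shows how the aggregate metrics produced by `TestTimeAugmentation.__call__` relate to the stacked per-realisation output that the class returns when `return_full_data=True`. It reuses the same `bincount`-based per-voxel mode, the per-voxel mean/std, and the global `std/mean` volume variation coefficient introduced in the patches above. The `fake_output` array, its `[N, C, H, W]` shape, and the random values are assumptions made for the example and are not part of the patch series.

    import numpy as np

    # Stand-in for the array returned by
    # TestTimeAugmentation(..., return_full_data=True)(data):
    # shape [N, C, H, W] with N == num_examples (values are illustrative only).
    rng = np.random.default_rng(0)
    fake_output = (rng.random((10, 1, 20, 20)) > 0.5).astype(np.float32)

    # Per-voxel mode across the N realisations (same bincount/argmax approach as in __call__).
    mode = np.apply_along_axis(
        lambda x: np.bincount(x).argmax(), axis=0, arr=fake_output.astype(np.int64)
    )

    # Per-voxel mean and standard deviation across the N realisations.
    mean = fake_output.mean(axis=0)
    std = fake_output.std(axis=0)

    # Volume variation coefficient: global std over global mean, including the N dimension.
    vvc = float(np.std(fake_output) / np.mean(fake_output))

    print(mode.shape, mean.shape, std.shape, vvc)  # (1, 20, 20) (1, 20, 20) (1, 20, 20) <float>

Recomputing the metrics this way from the full stacked data is what the `return_full_data=True` path is intended to enable, e.g. when a user wants per-voxel uncertainty maps or a different aggregation than mode/mean/std.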