Test time augmentations #1794
Merged
22 commits
0d24a6d tta (rijobro)
5bdd7f5 Merge remote-tracking branch 'MONAI/master' into tta (rijobro)
d19c3a6 update (rijobro)
ef07560 Merge remote-tracking branch 'MONAI/master' into tta (rijobro)
33aa836 Merge remote-tracking branch 'MONAI/master' into tta (rijobro)
0eaea7f tta code review (rijobro)
9a1277a remove segmentation requirement (rijobro)
eb8f57d add to data.rst (rijobro)
799c449 Merge remote-tracking branch 'MONAI/master' into tta (rijobro)
8d79231 update w master (rijobro)
300b68d Merge remote-tracking branch 'MONAI/master' into tta (rijobro)
fec04fb explicit cast (rijobro)
e4e587a Merge branch 'master' into tta (wyli)
6b488c1 update for CommonKeys (rijobro)
e02a67c torch cuda is available (rijobro)
57f54d2 Merge branch 'master' into tta (rijobro)
7cdf682 correct transform check (rijobro)
61051c3 add single transform test, remove KeepLargestConnectedComponent (rijobro)
5c430ab code format (rijobro)
bb2e141 convert error to warning (rijobro)
c6dc1ef Merge remote-tracking branch 'MONAI/master' into tta (rijobro)
f1b92a7 Revert "convert error to warning" (rijobro)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,178 @@ | ||
| # Copyright 2020 - 2021 MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from monai.data.dataloader import DataLoader | ||
| from monai.data.dataset import Dataset | ||
| from monai.data.inverse_batch_transform import BatchInverseTransform | ||
| from monai.data.utils import list_data_collate, pad_list_data_collate | ||
| from monai.transforms.compose import Compose | ||
| from monai.transforms.inverse import InvertibleTransform | ||
| from monai.transforms.transform import RandomizableTransform | ||
| from monai.transforms.utils import allow_missing_keys_mode | ||
| from monai.utils.enums import CommonKeys, InverseKeys | ||
|
|
||
| __all__ = ["TestTimeAugmentation"] | ||
|
|
||
|
|
||
| class TestTimeAugmentation: | ||
| """ | ||
| Class for performing test time augmentations. This will pass the same image through the network multiple times. | ||
|
|
||
| The user passes transform(s) to be applied to each realisation, and provided that at least one of those transforms | ||
| is random, the network's output will vary. Provided that inverse transformations exist for all supplied spatial | ||
| transforms, the inverse can be applied to each realisation of the network's output. Once in the same spatial | ||
| reference, the results can then be combined and metrics computed. | ||
|
|
||
| Test time augmentations are a useful feature for computing network uncertainty, as well as observing the network's | ||
| dependency on the applied random transforms. | ||
|
|
||
| Reference: | ||
| Wang et al., | ||
| Aleatoric uncertainty estimation with test-time augmentation for medical image segmentation with convolutional | ||
| neural networks, | ||
| https://doi.org/10.1016/j.neucom.2019.01.103 | ||
|
|
||
| Args: | ||
| transform: transform (or composed) to be applied to each realisation. At least one transform must be of type | ||
| `RandomizableTransform`. All random transforms must be of type `InvertibleTransform`. | ||
| batch_size: number of realisations to infer at once. | ||
| num_workers: how many subprocesses to use for data loading. | ||
| inferrer_fn: function to use to perform inference. | ||
| device: device on which to perform inference. | ||
| image_key: key used to extract image from input dictionary. | ||
| label_key: key used to extract label from input dictionary. | ||
| return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` will return the | ||
| full data. Dimensions will be the same as when passing a single image through `inferrer_fn`, with the first dimension | ||
| equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`. | ||
|
|
||
| Example: | ||
| .. code-block:: python | ||
|
|
||
| transform = RandAffined(keys, ...) | ||
| post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)]) | ||
|
|
||
| tt_aug = TestTimeAugmentation( | ||
| transform, batch_size=5, num_workers=0, inferrer_fn=lambda x: post_trans(model(x)), device=device | ||
| ) | ||
| mode, mean, std, vvc = tt_aug(test_data) | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| transform: InvertibleTransform, | ||
| batch_size: int, | ||
| num_workers: int, | ||
| inferrer_fn: Callable, | ||
| device: Optional[Union[str, torch.device]] = "cuda" if torch.cuda.is_available() else "cpu", | ||
| image_key=CommonKeys.IMAGE, | ||
| label_key=CommonKeys.LABEL, | ||
| return_full_data: bool = False, | ||
| ) -> None: | ||
| self.transform = transform | ||
| self.batch_size = batch_size | ||
| self.num_workers = num_workers | ||
| self.inferrer_fn = inferrer_fn | ||
| self.device = device | ||
| self.image_key = image_key | ||
| self.label_key = label_key | ||
| self.return_full_data = return_full_data | ||
|
|
||
| # check that the transform has at least one random component, and that all random transforms are invertible | ||
| self._check_transforms() | ||
|
|
||
| def _check_transforms(self): | ||
| """Should be at least 1 random transform, and all random transforms should be invertible.""" | ||
| ts = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms | ||
| randoms = np.array([isinstance(t, RandomizableTransform) for t in ts]) | ||
| invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts]) | ||
| # check at least 1 random | ||
| if sum(randoms) == 0: | ||
| raise RuntimeError( | ||
| "Requires a `Randomizable` transform or a `Compose` containing at least one `Randomizable` transform." | ||
| ) | ||
| # check that whenever randoms is True, invertibles is also true | ||
| for t, r, i in zip(ts, randoms, invertibles): | ||
| if r and not i: | ||
| raise RuntimeError( | ||
|
Contributor
this could be a warning because the random factors may be in the model, we may want to do test aug with a randomised model... |
||
| f"All applied random transform(s) must be invertible. Problematic transform: {type(t).__name__}" | ||
| ) | ||
|
|
||
| def __call__( | ||
| self, data: Dict[str, Any], num_examples: int = 10 | ||
| ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]: | ||
| """ | ||
| Args: | ||
| data: dictionary data to be processed. | ||
| num_examples: number of realisations to be processed and results combined. | ||
|
|
||
| Returns: | ||
| - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are calculated across | ||
| `num_examples` outputs at each voxel. The volume variation coefficient (VVC) is `std/mean` across the whole output, | ||
| including `num_examples`. See original paper for clarification. | ||
| - if `return_full_data==True`: data is returned as-is after applying the `inferrer_fn` and then concatenating across | ||
| the first dimension containing `num_examples`. This allows the user to perform their own analysis if desired. | ||
| """ | ||
| d = dict(data) | ||
|
|
||
| # check num examples is multiple of batch size | ||
| if num_examples % self.batch_size != 0: | ||
| raise ValueError("num_examples should be multiple of batch size.") | ||
|
|
||
| # replicate the input `num_examples` times, then create the dataset and a dataloader with batches of size `batch_size` | ||
| data_in = [d] * num_examples | ||
| ds = Dataset(data_in, self.transform) | ||
| dl = DataLoader(ds, self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) | ||
|
|
||
| label_transform_key = self.label_key + InverseKeys.KEY_SUFFIX | ||
|
|
||
| # create inverter | ||
| inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate) | ||
|
|
||
| outputs: List[np.ndarray] = [] | ||
|
|
||
| for batch_data in dl: | ||
|
|
||
| batch_images = batch_data[self.image_key].to(self.device) | ||
|
|
||
| # do model forward pass | ||
| batch_output = self.inferrer_fn(batch_images) | ||
| if isinstance(batch_output, torch.Tensor): | ||
| batch_output = batch_output.detach().cpu() | ||
| if isinstance(batch_output, np.ndarray): | ||
| batch_output = torch.Tensor(batch_output) | ||
|
|
||
| # create a dictionary containing the inferred batch and their transforms | ||
| inferred_dict = {self.label_key: batch_output, label_transform_key: batch_data[label_transform_key]} | ||
|
|
||
| # do inverse transformation (allow missing keys as only inverting label) | ||
| with allow_missing_keys_mode(self.transform): # type: ignore | ||
| inv_batch = inverter(inferred_dict) | ||
|
|
||
| # append | ||
| outputs.append(inv_batch[self.label_key]) | ||
|
|
||
| # output | ||
| output: np.ndarray = np.concatenate(outputs) | ||
|
|
||
| if self.return_full_data: | ||
| return output | ||
|
|
||
| # calculate metrics | ||
| mode: np.ndarray = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=output.astype(np.int64)) | ||
| mean: np.ndarray = np.mean(output, axis=0) # type: ignore | ||
| std: np.ndarray = np.std(output, axis=0) # type: ignore | ||
| vvc: float = (np.std(output) / np.mean(output)).item() | ||
| return mode, mean, std, vvc | ||
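The augment, infer, invert, aggregate flow of `__call__` can be illustrated without MONAI or NumPy. The sketch below is a simplified stand-in under stated assumptions, not the PR's implementation: `random_flip`, `invert_flip`, `tta_outputs`, `summarise` and `threshold_left_neighbour` are hypothetical names, the "image" is a 1D list, the only augmentation is a random flip (which is its own inverse), and `statistics.pstdev` plays the role of `np.std` (population standard deviation).

```python
import random
import statistics
from collections import Counter


def random_flip(signal, rng):
    """Randomly reverse a 1D signal; return the result and whether it was flipped."""
    flipped = rng.random() < 0.5
    return (signal[::-1] if flipped else list(signal)), flipped


def invert_flip(signal, flipped):
    """Undo random_flip: a flip is its own inverse."""
    return signal[::-1] if flipped else list(signal)


def tta_outputs(signal, inferrer_fn, num_examples, seed=0):
    """Pass num_examples augmented copies through inferrer_fn and map each output back."""
    rng = random.Random(seed)
    outputs = []
    for _ in range(num_examples):
        augmented, flipped = random_flip(signal, rng)
        prediction = inferrer_fn(augmented)
        outputs.append(invert_flip(prediction, flipped))
    return outputs  # N rows, one per realisation


def summarise(outputs):
    """Per-position mode, mean and std across realisations, plus the global VVC."""
    columns = list(zip(*outputs))
    mode = []
    for col in columns:
        counts = Counter(int(v) for v in col)
        # smallest value wins ties, mirroring np.bincount(x).argmax()
        mode.append(max(sorted(counts), key=counts.get))
    mean = [statistics.fmean(col) for col in columns]
    std = [statistics.pstdev(col) for col in columns]
    flat = [v for row in outputs for v in row]
    vvc = statistics.pstdev(flat) / statistics.fmean(flat)  # undefined if the mean is 0
    return mode, mean, std, vvc


def threshold_left_neighbour(signal):
    """Toy 'model' that is not flip-equivariant: 1 where the previous element exceeds 0.5."""
    return [1.0 if i > 0 and signal[i - 1] > 0.5 else 0.0 for i in range(len(signal))]


signal = [0.0, 0.9, 0.9, 0.0, 0.0]
outputs = tta_outputs(signal, threshold_left_neighbour, num_examples=10)
mode, mean, std, vvc = summarise(outputs)
```

Because the toy model depends on direction, flipped and unflipped realisations disagree after inversion, which is what produces a non-zero per-position `std`. With an identity `inferrer_fn`, every realisation maps back to the input and the spread collapses to zero.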
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,143 @@ | ||
| # Copyright 2020 - 2021 MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import unittest | ||
| from functools import partial | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from monai.data import CacheDataset, DataLoader, create_test_image_2d | ||
| from monai.data.test_time_augmentation import TestTimeAugmentation | ||
| from monai.data.utils import pad_list_data_collate | ||
| from monai.losses import DiceLoss | ||
| from monai.networks.nets import UNet | ||
| from monai.transforms import Activations, AddChanneld, AsDiscrete, Compose, CropForegroundd, DivisiblePadd, RandAffined | ||
| from monai.transforms.croppad.dictionary import SpatialPadd | ||
| from monai.transforms.spatial.dictionary import Rand2DElasticd, RandFlipd | ||
| from monai.utils import optional_import, set_determinism | ||
|
|
||
| if TYPE_CHECKING: | ||
| import tqdm | ||
|
|
||
| has_tqdm = True | ||
| else: | ||
| tqdm, has_tqdm = optional_import("tqdm") | ||
|
|
||
| trange = partial(tqdm.trange, desc="training") if has_tqdm else range | ||
|
|
||
|
|
||
| class TestTestTimeAugmentation(unittest.TestCase): | ||
| @staticmethod | ||
| def get_data(num_examples, input_size): | ||
| custom_create_test_image_2d = partial( | ||
| create_test_image_2d, *input_size, rad_max=7, num_seg_classes=1, num_objs=1 | ||
| ) | ||
| data = [] | ||
| for _ in range(num_examples): | ||
| im, label = custom_create_test_image_2d() | ||
| data.append({"image": im, "label": label}) | ||
| return data[0] if num_examples == 1 else data | ||
|
|
||
| def setUp(self) -> None: | ||
| set_determinism(seed=0) | ||
|
|
||
| def tearDown(self) -> None: | ||
| set_determinism(None) | ||
|
|
||
| def test_test_time_augmentation(self): | ||
| input_size = (20, 20) | ||
| device = "cuda" if torch.cuda.is_available() else "cpu" | ||
| keys = ["image", "label"] | ||
| num_training_ims = 10 | ||
| train_data = self.get_data(num_training_ims, input_size) | ||
| test_data = self.get_data(1, input_size) | ||
|
|
||
| transforms = Compose( | ||
| [ | ||
| AddChanneld(keys), | ||
| RandAffined( | ||
| keys, | ||
| prob=1.0, | ||
| spatial_size=(30, 30), | ||
| rotate_range=(np.pi / 3, np.pi / 3), | ||
| translate_range=(3, 3), | ||
| scale_range=((0.8, 1), (0.8, 1)), | ||
| padding_mode="zeros", | ||
| mode=("bilinear", "nearest"), | ||
| as_tensor_output=False, | ||
| ), | ||
| CropForegroundd(keys, source_key="image"), | ||
| DivisiblePadd(keys, 4), | ||
| ] | ||
| ) | ||
|
|
||
| train_ds = CacheDataset(train_data, transforms) | ||
| # output might be different size, so pad so that they match | ||
| train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate) | ||
|
|
||
| model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device) | ||
| loss_function = DiceLoss(sigmoid=True) | ||
| optimizer = torch.optim.Adam(model.parameters(), 1e-3) | ||
|
|
||
| num_epochs = 10 | ||
| for _ in trange(num_epochs): | ||
| epoch_loss = 0 | ||
|
|
||
| for batch_data in train_loader: | ||
| inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device) | ||
| optimizer.zero_grad() | ||
| outputs = model(inputs) | ||
| loss = loss_function(outputs, labels) | ||
| loss.backward() | ||
| optimizer.step() | ||
| epoch_loss += loss.item() | ||
|
|
||
| epoch_loss /= len(train_loader) | ||
|
|
||
| post_trans = Compose( | ||
| [ | ||
| Activations(sigmoid=True), | ||
| AsDiscrete(threshold_values=True), | ||
| ] | ||
| ) | ||
|
|
||
| def inferrer_fn(x): | ||
| return post_trans(model(x)) | ||
|
|
||
| tt_aug = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=inferrer_fn, device=device) | ||
| mode, mean, std, vvc = tt_aug(test_data) | ||
| self.assertEqual(mode.shape, (1,) + input_size) | ||
| self.assertEqual(mean.shape, (1,) + input_size) | ||
| self.assertTrue(all(np.unique(mode) == (0, 1))) | ||
| self.assertEqual((mean.min(), mean.max()), (0.0, 1.0)) | ||
| self.assertEqual(std.shape, (1,) + input_size) | ||
| self.assertIsInstance(vvc, float) | ||
|
|
||
| def test_fail_non_random(self): | ||
| transforms = Compose([AddChanneld("im"), SpatialPadd("im", 1)]) | ||
| with self.assertRaises(RuntimeError): | ||
| TestTimeAugmentation(transforms, None, None, None) | ||
|
|
||
| def test_fail_random_but_not_invertible(self): | ||
| transforms = Compose([AddChanneld("im"), Rand2DElasticd("im", None, None)]) | ||
| with self.assertRaises(RuntimeError): | ||
| TestTimeAugmentation(transforms, None, None, None) | ||
|
|
||
| def test_single_transform(self): | ||
| transforms = RandFlipd(["image", "label"]) | ||
| tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x) | ||
| tta(self.get_data(1, (20, 20))) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
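The two failure tests above exercise the transform check in `TestTimeAugmentation`. As a rough illustration of that check, and of the reviewer's suggestion that it could warn rather than raise, here is a minimal sketch. Everything in it is an assumption made for illustration: the base classes `Randomizable` and `Invertible`, the toy transforms, and the `strict` flag are hypothetical stand-ins and not MONAI's API. The merged PR raises `RuntimeError`, since the "convert error to warning" commit was reverted.

```python
import warnings


class Randomizable:
    """Stand-in for a random transform base class (hypothetical, not MONAI's)."""


class Invertible:
    """Stand-in for an invertible transform base class (hypothetical, not MONAI's)."""


class RandFlip(Randomizable, Invertible):
    """Random and invertible: acceptable for test time augmentation."""


class RandElastic(Randomizable):
    """Random but not invertible: should be rejected."""


class Pad(Invertible):
    """Deterministic and invertible: fine, but not sufficient on its own."""


def check_transforms(transforms, strict=True):
    """Require at least one random transform, and that every random transform is invertible.

    With strict=False (a hypothetical option), problems are reported as warnings.
    """

    def report(message):
        if strict:
            raise RuntimeError(message)
        warnings.warn(message)

    if not any(isinstance(t, Randomizable) for t in transforms):
        report("Requires at least one random transform.")
    for t in transforms:
        if isinstance(t, Randomizable) and not isinstance(t, Invertible):
            # name the offending transform itself, not the boolean flag
            report(f"Random transform is not invertible: {type(t).__name__}")


check_transforms([Pad(), RandFlip()])  # passes silently
```

A warning-only mode would allow the case the reviewer describes, where the randomness lives in the model rather than in the transforms, at the cost of letting a non-invertible pipeline fail later during inversion.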