Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,7 @@ ThreadBuffer
BatchInverseTransform
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: monai.data.BatchInverseTransform

TestTimeAugmentation
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: monai.data.TestTimeAugmentation
1 change: 1 addition & 0 deletions monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .png_writer import write_png
from .samplers import DistributedSampler, DistributedWeightedRandomSampler
from .synthetic import create_test_image_2d, create_test_image_3d
from .test_time_augmentation import TestTimeAugmentation
from .thread_buffer import ThreadBuffer, ThreadDataLoader
from .utils import (
compute_importance_map,
Expand Down
178 changes: 178 additions & 0 deletions monai/data/test_time_augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from monai.data.dataloader import DataLoader
from monai.data.dataset import Dataset
from monai.data.inverse_batch_transform import BatchInverseTransform
from monai.data.utils import list_data_collate, pad_list_data_collate
from monai.transforms.compose import Compose
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import RandomizableTransform
from monai.transforms.utils import allow_missing_keys_mode
from monai.utils.enums import CommonKeys, InverseKeys

__all__ = ["TestTimeAugmentation"]


class TestTimeAugmentation:
    """
    Class for performing test time augmentations. This will pass the same image through the network multiple times.

    The user passes transform(s) to be applied to each realisation, and provided that at least one of those transforms
    is random, the network's output will vary. Provided that inverse transformations exist for all supplied spatial
    transforms, the inverse can be applied to each realisation of the network's output. Once in the same spatial
    reference, the results can then be combined and metrics computed.

    Test time augmentations are a useful feature for computing network uncertainty, as well as observing the network's
    dependency on the applied random transforms.

    Reference:
        Wang et al.,
        Aleatoric uncertainty estimation with test-time augmentation for medical image segmentation with convolutional
        neural networks,
        https://doi.org/10.1016/j.neucom.2019.01.103

    Args:
        transform: transform (or composed) to be applied to each realisation. At least one transform must be of type
            `RandomizableTransform`. All random transforms must be of type `InvertibleTransform`.
        batch_size: number of realisations to infer at once.
        num_workers: how many subprocesses to use for data.
        inferrer_fn: function to use to perform inference.
        device: device on which to perform inference.
        image_key: key used to extract image from input dictionary.
        label_key: key used to extract label from input dictionary.
        return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` will
            return the full data. Dimensions will be same size as when passing a single image through `inferrer_fn`,
            with a leading dimension equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.

    Raises:
        RuntimeError: if `transform` contains no random transform, or if any random transform is not invertible.

    Example:
        .. code-block:: python

            transform = RandAffined(keys, ...)
            post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

            tt_aug = TestTimeAugmentation(
                transform, batch_size=5, num_workers=0, inferrer_fn=lambda x: post_trans(model(x)), device=device
            )
            mode, mean, std, vvc = tt_aug(test_data)
    """

    def __init__(
        self,
        transform: InvertibleTransform,
        batch_size: int,
        num_workers: int,
        inferrer_fn: Callable,
        device: Optional[Union[str, torch.device]] = "cuda" if torch.cuda.is_available() else "cpu",
        image_key=CommonKeys.IMAGE,
        label_key=CommonKeys.LABEL,
        return_full_data: bool = False,
    ) -> None:
        self.transform = transform
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.inferrer_fn = inferrer_fn
        self.device = device
        self.image_key = image_key
        self.label_key = label_key
        self.return_full_data = return_full_data

        # check that the transform has at least one random component, and that all random transforms are invertible
        self._check_transforms()

    def _check_transforms(self):
        """Should be at least 1 random transform, and all random transforms should be invertible.

        Raises:
            RuntimeError: if there is no random transform, or if a random transform is not invertible.
        """
        ts = self.transform.transforms if isinstance(self.transform, Compose) else [self.transform]
        random_ts = [t for t in ts if isinstance(t, RandomizableTransform)]

        # check at least 1 random
        if not random_ts:
            raise RuntimeError(
                "Requires a `Randomizable` transform or a `Compose` containing at least one `Randomizable` transform."
            )

        # check that every random transform is also invertible
        # NOTE(review, @wyli, Mar 23, 2021): this could be a warning because the random factors may be in the
        # model, we may want to do test aug with a randomised model...
        for t in random_ts:
            if not isinstance(t, InvertibleTransform):
                # report the offending transform itself (not the boolean flag derived from it)
                raise RuntimeError(
                    f"All applied random transform(s) must be invertible. Problematic transform: {type(t).__name__}"
                )

    def __call__(
        self, data: Dict[str, Any], num_examples: int = 10
    ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]:
        """
        Args:
            data: dictionary data to be processed.
            num_examples: number of realisations to be processed and results combined.

        Returns:
            - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are calculated
              across `num_examples` outputs at each voxel. The volume variation coefficient (VVC) is `std/mean` across
              the whole output, including `num_examples`. See original paper for clarification.
            - if `return_full_data==True`: data is returned as-is after applying the `inferrer_fn` and then
              concatenating across the first dimension containing `num_examples`. This allows the user to perform
              their own analysis if desired.

        Raises:
            ValueError: if `num_examples` is not positive or is not a multiple of `batch_size`.
        """
        d = dict(data)

        # an empty set of realisations would otherwise only fail later, when concatenating zero outputs
        if num_examples < 1:
            raise ValueError("num_examples should be a positive integer.")

        # check num examples is multiple of batch size
        if num_examples % self.batch_size != 0:
            raise ValueError("num_examples should be multiple of batch size.")

        # generate batch of data of size == batch_size, dataset and dataloader
        data_in = [d] * num_examples
        ds = Dataset(data_in, self.transform)
        dl = DataLoader(ds, self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate)

        label_transform_key = self.label_key + InverseKeys.KEY_SUFFIX

        # create inverter
        inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate)

        outputs: List[np.ndarray] = []

        for batch_data in dl:

            batch_images = batch_data[self.image_key].to(self.device)

            # do model forward pass
            batch_output = self.inferrer_fn(batch_images)
            if isinstance(batch_output, torch.Tensor):
                batch_output = batch_output.detach().cpu()
            if isinstance(batch_output, np.ndarray):
                batch_output = torch.Tensor(batch_output)

            # create a dictionary containing the inferred batch and their transforms
            inferred_dict = {self.label_key: batch_output, label_transform_key: batch_data[label_transform_key]}

            # do inverse transformation (allow missing keys as only inverting label)
            with allow_missing_keys_mode(self.transform):  # type: ignore
                inv_batch = inverter(inferred_dict)

            # append
            outputs.append(inv_batch[self.label_key])

        # output: all realisations stacked along the first axis
        output: np.ndarray = np.concatenate(outputs)

        if self.return_full_data:
            return output

        # calculate metrics
        # the mode is taken over the realisation axis after casting to int64 (np.bincount needs non-negative integers)
        mode: np.ndarray = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=output.astype(np.int64))
        mean: np.ndarray = np.mean(output, axis=0)  # type: ignore
        std: np.ndarray = np.std(output, axis=0)  # type: ignore
        # NOTE: the ratio is nan/inf when the overall mean of the output is zero
        vvc: float = (np.std(output) / np.mean(output)).item()
        return mode, mean, std, vvc
143 changes: 143 additions & 0 deletions tests/test_testtimeaugmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from functools import partial
from typing import TYPE_CHECKING

import numpy as np
import torch

from monai.data import CacheDataset, DataLoader, create_test_image_2d
from monai.data.test_time_augmentation import TestTimeAugmentation
from monai.data.utils import pad_list_data_collate
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.transforms import Activations, AddChanneld, AsDiscrete, Compose, CropForegroundd, DivisiblePadd, RandAffined
from monai.transforms.croppad.dictionary import SpatialPadd
from monai.transforms.spatial.dictionary import Rand2DElasticd, RandFlipd
from monai.utils import optional_import, set_determinism

# tqdm is an optional dependency: import it for real only while type checking,
# otherwise resolve it lazily through `optional_import`.
if TYPE_CHECKING:
    import tqdm

    has_tqdm = True
else:
    tqdm, has_tqdm = optional_import("tqdm")

# Use a progress bar for the training loop when tqdm is available, else a plain range.
if has_tqdm:
    trange = partial(tqdm.trange, desc="training")
else:
    trange = range


class TestTestTimeAugmentation(unittest.TestCase):
    """Tests for `TestTimeAugmentation`: an end-to-end run plus constructor validation."""

    @staticmethod
    def get_data(num_examples, input_size):
        """Create synthetic 2D image/label dictionaries.

        Returns a single dictionary when `num_examples` is 1, otherwise a list of dictionaries.
        """

        def make_sample():
            image, seg = create_test_image_2d(*input_size, rad_max=7, num_seg_classes=1, num_objs=1)
            return {"image": image, "label": seg}

        samples = [make_sample() for _ in range(num_examples)]
        if num_examples == 1:
            return samples[0]
        return samples

    def setUp(self) -> None:
        # fix the seeds so the synthetic data, augmentations and training are reproducible
        set_determinism(seed=0)

    def tearDown(self) -> None:
        # restore non-deterministic behaviour for other tests
        set_determinism(None)

    def test_test_time_augmentation(self):
        """Train a small UNet on synthetic data, then check the shapes and ranges of the TTA metrics."""
        input_size = (20, 20)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        keys = ["image", "label"]
        num_training_ims = 10
        train_data = self.get_data(num_training_ims, input_size)
        test_data = self.get_data(1, input_size)

        random_affine = RandAffined(
            keys,
            prob=1.0,
            spatial_size=(30, 30),
            rotate_range=(np.pi / 3, np.pi / 3),
            translate_range=(3, 3),
            scale_range=((0.8, 1), (0.8, 1)),
            padding_mode="zeros",
            mode=("bilinear", "nearest"),
            as_tensor_output=False,
        )
        augment = Compose(
            [
                AddChanneld(keys),
                random_affine,
                CropForegroundd(keys, source_key="image"),
                DivisiblePadd(keys, 4),
            ]
        )

        cached_train = CacheDataset(train_data, augment)
        # output might be different size, so pad so that they match
        train_loader = DataLoader(cached_train, batch_size=2, collate_fn=pad_list_data_collate)

        net = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
        loss_function = DiceLoss(sigmoid=True)
        optimizer = torch.optim.Adam(net.parameters(), 1e-3)

        num_epochs = 10
        for _ in trange(num_epochs):
            running_loss = 0

            for batch in train_loader:
                inputs = batch["image"].to(device)
                labels = batch["label"].to(device)
                optimizer.zero_grad()
                loss = loss_function(net(inputs), labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            running_loss /= len(train_loader)

        post_trans = Compose(
            [
                Activations(sigmoid=True),
                AsDiscrete(threshold_values=True),
            ]
        )

        tt_aug = TestTimeAugmentation(
            augment,
            batch_size=5,
            num_workers=0,
            inferrer_fn=lambda x: post_trans(net(x)),
            device=device,
        )
        mode, mean, std, vvc = tt_aug(test_data)

        expected_shape = (1,) + input_size
        self.assertEqual(mode.shape, expected_shape)
        self.assertEqual(mean.shape, expected_shape)
        self.assertTrue(all(np.unique(mode) == (0, 1)))
        self.assertEqual((mean.min(), mean.max()), (0.0, 1.0))
        self.assertEqual(std.shape, expected_shape)
        self.assertIsInstance(vvc, float)

    def test_fail_non_random(self):
        """A pipeline without any random transform must be rejected."""
        deterministic_only = Compose([AddChanneld("im"), SpatialPadd("im", 1)])
        self.assertRaises(RuntimeError, TestTimeAugmentation, deterministic_only, None, None, None)

    def test_fail_random_but_not_invertible(self):
        """A pipeline with a random but non-invertible transform must be rejected."""
        not_invertible = Compose([AddChanneld("im"), Rand2DElasticd("im", None, None)])
        self.assertRaises(RuntimeError, TestTimeAugmentation, not_invertible, None, None, None)

    def test_single_transform(self):
        """A single (non-composed) random invertible transform is accepted and can be run."""
        flip = RandFlipd(["image", "label"])
        tta = TestTimeAugmentation(flip, batch_size=5, num_workers=0, inferrer_fn=lambda x: x)
        tta(self.get_data(1, (20, 20)))


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()