103 commits
18cc88c
move Transfrom out of compose file
rijobro Jan 22, 2021
18b53f5
add transform file
rijobro Jan 22, 2021
827dd84
inverse compose and spatialpadd
rijobro Jan 28, 2021
7851e4f
autofixes
rijobro Jan 28, 2021
b173914
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Jan 28, 2021
9d08f26
extra test
rijobro Jan 28, 2021
1810be3
Merge branch 'inverse_transforms' of https://github.com/rijobro/MONAI…
rijobro Jan 28, 2021
a54869b
serialisation of transform
rijobro Jan 28, 2021
b310c19
add rotate
rijobro Jan 28, 2021
f2edc9c
rotate w/ keep_size=True
rijobro Jan 28, 2021
41bec26
rotate w/ keep_size=False
rijobro Jan 29, 2021
75ad1fc
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Jan 29, 2021
8768d57
autofix
rijobro Jan 29, 2021
3b91c44
randrotated
rijobro Jan 29, 2021
e569b82
randrotated
rijobro Jan 29, 2021
6cff048
3d rotation not working
rijobro Jan 29, 2021
c10b7e3
works for dataloader
rijobro Jan 29, 2021
1800c6a
throw error for 3d rotation inverse
rijobro Jan 29, 2021
be53888
dataloader
rijobro Jan 29, 2021
58fdcf4
autofixes
rijobro Jan 29, 2021
5bd4d1c
add constructors for Randomizable class
rijobro Jan 29, 2021
50e7238
testing
rijobro Jan 29, 2021
9241cf7
start adding spatialcropd
rijobro Jan 29, 2021
46471a6
update tests
rijobro Feb 1, 2021
901f56b
crop
rijobro Feb 1, 2021
e898188
crop finished
rijobro Feb 1, 2021
31ebc76
RandSpatialCropd
rijobro Feb 1, 2021
3ad7fa8
Merge branch 'master' into inverse_transforms
rijobro Feb 1, 2021
40bc788
BorderPadd
rijobro Feb 1, 2021
695ea7b
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 2, 2021
7405c36
update after git merge
rijobro Feb 2, 2021
1ff5729
DivisiblePadd
rijobro Feb 2, 2021
b73df2d
Flipd
rijobro Feb 2, 2021
667273a
start adding orientationd
rijobro Feb 2, 2021
b1e077c
tidy rotate3d
rijobro Feb 2, 2021
898f2d0
Orientationd
rijobro Feb 2, 2021
b8c1b8a
Rotate90d
rijobro Feb 2, 2021
3020fd4
Zoomd
rijobro Feb 2, 2021
9034d7b
CenterSpatialCropd
rijobro Feb 2, 2021
69927ea
CropForegroundd
rijobro Feb 2, 2021
2f48ad7
Spacingd
rijobro Feb 3, 2021
43155c4
Resized
rijobro Feb 3, 2021
1bf0015
ResizeWithPadOrCropd
rijobro Feb 3, 2021
a2e13d4
Rotated and RandRotated 3d
rijobro Feb 3, 2021
749d925
RandZoomd
rijobro Feb 3, 2021
944e4f1
RandFlipd
rijobro Feb 3, 2021
2b5450a
RandRotate90d
rijobro Feb 3, 2021
c798e2c
RandAffined to call correct constructor
rijobro Feb 3, 2021
9b51e6b
RandAffined (not finished)
rijobro Feb 3, 2021
521184b
RandAffined
rijobro Feb 4, 2021
5a0712f
correct returning of affine for RandAffined
rijobro Feb 4, 2021
936eaf5
RandElastic
rijobro Feb 5, 2021
c0decd5
cpg -> disp
rijobro Feb 8, 2021
31ddf65
SimpleITK for inverse nonrigid
rijobro Feb 8, 2021
eee8004
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 8, 2021
7ec4596
code format
rijobro Feb 8, 2021
0a1e73b
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 8, 2021
d03afd7
decollate batch
rijobro Feb 9, 2021
e200789
decollate2
rijobro Feb 9, 2021
c8aaf8c
need to remove all init_args
rijobro Feb 10, 2021
cbc5b7e
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 10, 2021
f1ebd7a
remove init_args
rijobro Feb 10, 2021
48fbff2
working with data loader
rijobro Feb 10, 2021
1cbc8f7
code format
rijobro Feb 10, 2021
57c3e49
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 10, 2021
b60d206
RandCropByPosNegLabeld to call correct parent constructor
rijobro Feb 11, 2021
61e56be
vtk attempt
rijobro Feb 11, 2021
83502e5
more vtk progress
rijobro Feb 11, 2021
3284a07
no error
rijobro Feb 11, 2021
884a407
batchdataset inherits from monai dataset
rijobro Feb 11, 2021
0834b06
inverse with vtk or sitk
rijobro Feb 16, 2021
cf238d2
option for no collation after batch inverse
rijobro Feb 16, 2021
0047275
create same keys whether random transform used or not
rijobro Feb 17, 2021
a3055bf
pad collate
rijobro Feb 19, 2021
ebee3c5
batch inverse improvement
rijobro Feb 19, 2021
0f955d3
bug fix in inverse RandAffined
rijobro Feb 19, 2021
917bb17
test for pad_collation
rijobro Feb 19, 2021
8de158f
Revert "test for pad_collation"
rijobro Feb 19, 2021
a3194bd
Revert "pad collate"
rijobro Feb 19, 2021
e4d6f00
pad_collation
rijobro Feb 19, 2021
639bb32
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 19, 2021
15bbf9a
codeformate
rijobro Feb 19, 2021
595119f
Compose len
rijobro Feb 19, 2021
e2d63ae
inverse batch and fixes
rijobro Feb 19, 2021
e515a78
codeformat
rijobro Feb 19, 2021
3d6fdba
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 19, 2021
b083067
inverse pad collation
rijobro Feb 22, 2021
c4b81e5
update test threshold
rijobro Feb 22, 2021
0964562
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 22, 2021
e9ab4d3
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 22, 2021
70e18ca
update transform location after deepgrow merge
rijobro Feb 23, 2021
1b7beba
Merge branch 'master' into inverse_transforms
rijobro Feb 23, 2021
ad7e9ff
TTA
rijobro Feb 23, 2021
d7a9c34
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 23, 2021
19cf1d0
code format changes
rijobro Feb 23, 2021
ed789d1
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 23, 2021
4668ca2
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 23, 2021
b5747d9
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 23, 2021
e0aa4a4
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 23, 2021
700a520
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 24, 2021
e13c902
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 25, 2021
cfa402a
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 25, 2021
83724d2
Merge remote-tracking branch 'MONAI/master' into inverse_transforms
rijobro Feb 25, 2021
2 changes: 2 additions & 0 deletions monai/data/__init__.py
@@ -25,12 +25,14 @@
from .grid_dataset import GridPatchDataset, PatchDataset
from .image_dataset import ImageDataset
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader
from .inverse_batch_transform import BatchInverseTransform
from .iterable_dataset import IterableDataset
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
from .png_saver import PNGSaver
from .png_writer import write_png
from .synthetic import create_test_image_2d, create_test_image_3d
from .test_time_augmentation import TestTimeAugmentation
from .thread_buffer import ThreadBuffer
from .utils import (
DistributedSampler,
12 changes: 5 additions & 7 deletions monai/data/image_reader.py
@@ -17,12 +17,10 @@
import numpy as np
from torch.utils.data._utils.collate import np_str_obj_array_pattern

import monai.data.utils
from monai.config import DtypeLike, KeysCollection
from monai.data.utils import correct_nifti_header_if_necessary
from monai.utils import ensure_tuple, optional_import

from .utils import is_supported_format

if TYPE_CHECKING:
import itk # type: ignore
import nibabel as nib
@@ -322,7 +320,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool:

"""
suffixes: Sequence[str] = ["nii", "nii.gz"]
return has_nib and is_supported_format(filename, suffixes)
return has_nib and monai.data.is_supported_format(filename, suffixes)

def read(self, data: Union[Sequence[str], str], **kwargs):
"""
@@ -343,7 +341,7 @@ def read(self, data: Union[Sequence[str], str], **kwargs):
kwargs_.update(kwargs)
for name in filenames:
img = nib.load(name, **kwargs_)
img = correct_nifti_header_if_necessary(img)
img = monai.data.utils.correct_nifti_header_if_necessary(img)
img_.append(img)
return img_ if len(filenames) > 1 else img_[0]

@@ -453,7 +451,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool:
if a list of files, verify all the suffixes.
"""
suffixes: Sequence[str] = ["npz", "npy"]
return is_supported_format(filename, suffixes)
return monai.data.is_supported_format(filename, suffixes)

def read(self, data: Union[Sequence[str], str], **kwargs):
"""
@@ -537,7 +535,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool:
if a list of files, verify all the suffixes.
"""
suffixes: Sequence[str] = ["png", "jpg", "jpeg", "bmp"]
return has_pil and is_supported_format(filename, suffixes)
return has_pil and monai.data.is_supported_format(filename, suffixes)

def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs):
"""
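The `image_reader.py` changes above (and the similar edits to `nifti_saver.py`, `nifti_writer.py` and `png_saver.py` below) replace direct `from monai.data.utils import ...` imports with a single `import monai.data.utils` and module-qualified calls. This looks like the standard way of breaking a circular import now that `monai.data`'s `__init__` also pulls in the new inverse/TTA modules, which themselves depend on `monai.data.utils`. A minimal sketch of the pattern, assuming MONAI is installed (the helper name `looks_like_nifti` is made up for illustration):

```python
# Sketch of the deferred-attribute-lookup pattern used in these diffs.
import monai.data.utils  # executes the module, but resolves no names here


def looks_like_nifti(filename: str) -> bool:
    # the attribute is looked up at call time, after all modules have finished
    # importing, so monai.data can import this module without a circular-import error
    return monai.data.utils.is_supported_format(filename, ["nii", "nii.gz"])


if __name__ == "__main__":
    print(looks_like_nifti("scan.nii.gz"))  # True
```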
90 changes: 90 additions & 0 deletions monai/data/inverse_batch_transform.py
@@ -0,0 +1,90 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, Hashable, Optional, Tuple

import numpy as np
from torch.utils.data.dataloader import DataLoader as TorchDataLoader

from monai.data.dataloader import DataLoader
from monai.data.dataset import Dataset
from monai.data.utils import decollate_batch, pad_list_data_collate
from monai.transforms.croppad.array import CenterSpatialCrop
from monai.transforms.inverse_transform import InvertibleTransform
from monai.utils import first
from monai.utils.misc import ensure_tuple

__all__ = ["BatchInverseTransform"]


class _BatchInverseDataset(Dataset):
def __init__(
self,
data: Dict[str, Any],
transform: InvertibleTransform,
keys: Optional[Tuple[Hashable, ...]],
pad_collation_used: bool,
) -> None:
self.data = decollate_batch(data)
self.invertible_transform = transform
self.keys = ensure_tuple(keys) if keys else None
self.pad_collation_used = pad_collation_used

def __getitem__(self, index: int) -> Dict[Hashable, np.ndarray]:
data = dict(self.data[index])
# If pad collation was used, then we need to undo this first
if self.pad_collation_used:
keys = self.keys or [key for key in data.keys() if str(key) + "_transforms" in data.keys()]
for key in keys:
transform_key = str(key) + "_transforms"
transform = data[transform_key][-1]
if transform["class"] == "SpatialPadd":
data[key] = CenterSpatialCrop(transform["orig_size"])(data[key])
# remove transform
data[transform_key].pop()

return self.invertible_transform.inverse(data, self.keys)


class BatchInverseTransform:
"""something"""

def __init__(
self, transform: InvertibleTransform, loader: TorchDataLoader, collate_fn: Optional[Callable] = None
) -> None:
"""
Args:
transform: a callable data transform on input data.
loader: data loader used to generate the batch of data.
collate_fn: how to collate data after inverse transformations. Default will use the DataLoader's default
collation method. If returning images of different sizes, this will likely create an error (since the
collation will concatenate arrays, requiring them to be the same size). In this case, using
`collate_fn=lambda x: x` might solve the problem.
"""
self.transform = transform
self.batch_size = loader.batch_size
self.num_workers = loader.num_workers
self.collate_fn = collate_fn
self.pad_collation_used = loader.collate_fn == pad_list_data_collate

def __call__(self, data: Dict[str, Any], keys: Optional[Tuple[Hashable, ...]] = None) -> Any:

inv_ds = _BatchInverseDataset(data, self.transform, keys, self.pad_collation_used)
inv_loader = DataLoader(
inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn
)
try:
return first(inv_loader)
except RuntimeError as re:
re_str = str(re)
if "stack expects each tensor to be equal size" in re_str:
re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`."
raise RuntimeError(re_str)
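A rough usage sketch of `BatchInverseTransform` based on the code in this PR; the input shapes, the choice of `Flipd` and the exact behaviour of `pad_list_data_collate` are illustrative assumptions, not a definitive API:

```python
import numpy as np
from monai.data import BatchInverseTransform, DataLoader, Dataset
from monai.data.utils import pad_list_data_collate
from monai.transforms import Compose, Flipd

# two items of different spatial sizes, so pad collation is needed to batch them
data = [
    {"image": np.random.rand(1, 25, 27).astype(np.float32)},
    {"image": np.random.rand(1, 30, 20).astype(np.float32)},
]
transform = Compose([Flipd(keys="image", spatial_axis=1)])  # invertible, size-preserving

ds = Dataset(data, transform)
loader = DataLoader(ds, batch_size=2, num_workers=0, collate_fn=pad_list_data_collate)
batch = next(iter(loader))  # items padded to a common size and stacked

# invert per item; collate_fn=lambda x: x avoids re-stacking items that may no
# longer share a common size after the inverse (the hint from the docstring above)
inverter = BatchInverseTransform(transform, loader, collate_fn=lambda x: x)
inverted = inverter(batch)
print([item["image"].shape for item in inverted])
```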
4 changes: 2 additions & 2 deletions monai/data/nifti_saver.py
@@ -14,9 +14,9 @@
import numpy as np
import torch

import monai.data.utils
from monai.config import DtypeLike
from monai.data.nifti_writer import write_nifti
from monai.data.utils import create_file_basename
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import ImageMetaKey as Key

@@ -104,7 +104,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()

filename = create_file_basename(self.output_postfix, filename, self.output_dir)
filename = monai.data.utils.create_file_basename(self.output_postfix, filename, self.output_dir)
filename = f"{filename}{self.output_ext}"
# change data shape to be (channel, h, w, d)
while len(data.shape) < 4:
14 changes: 7 additions & 7 deletions monai/data/nifti_writer.py
@@ -14,8 +14,8 @@
import numpy as np
import torch

import monai.data.utils
from monai.config import DtypeLike
from monai.data.utils import compute_shape_offset, to_affine_nd
from monai.networks.layers import AffineTransform
from monai.utils import GridSampleMode, GridSamplePadMode, optional_import

@@ -95,15 +95,15 @@ def write_nifti(
sr = min(data.ndim, 3)
if affine is None:
affine = np.eye(4, dtype=np.float64)
affine = to_affine_nd(sr, affine)
affine = monai.data.utils.to_affine_nd(sr, affine)

if target_affine is None:
target_affine = affine
target_affine = to_affine_nd(sr, target_affine)
target_affine = monai.data.utils.to_affine_nd(sr, target_affine)

if np.allclose(affine, target_affine, atol=1e-3):
# no affine changes, save (data, affine)
results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, target_affine))
results_img = nib.Nifti1Image(data.astype(output_dtype), monai.data.utils.to_affine_nd(3, target_affine))
nib.save(results_img, file_name)
return

@@ -115,7 +115,7 @@ def write_nifti(
data = nib.orientations.apply_orientation(data, ornt_transform)
_affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape)
if np.allclose(_affine, target_affine, atol=1e-3) or not resample:
results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, target_affine))
results_img = nib.Nifti1Image(data.astype(output_dtype), monai.data.utils.to_affine_nd(3, target_affine))
nib.save(results_img, file_name)
return

@@ -125,7 +125,7 @@ def write_nifti(
)
transform = np.linalg.inv(_affine) @ target_affine
if output_spatial_shape is None:
output_spatial_shape, _ = compute_shape_offset(data.shape, _affine, target_affine)
output_spatial_shape, _ = monai.data.utils.compute_shape_offset(data.shape, _affine, target_affine)
output_spatial_shape_ = list(output_spatial_shape) if output_spatial_shape is not None else []
if data.ndim > 3: # multi channel, resampling each channel
while len(output_spatial_shape_) < 3:
@@ -151,6 +151,6 @@ def write_nifti(
)
data_np = data_torch.squeeze(0).squeeze(0).detach().cpu().numpy()

results_img = nib.Nifti1Image(data_np.astype(output_dtype), to_affine_nd(3, target_affine))
results_img = nib.Nifti1Image(data_np.astype(output_dtype), monai.data.utils.to_affine_nd(3, target_affine))
nib.save(results_img, file_name)
return
4 changes: 2 additions & 2 deletions monai/data/png_saver.py
@@ -14,8 +14,8 @@
import numpy as np
import torch

import monai.data.utils
from monai.data.png_writer import write_png
from monai.data.utils import create_file_basename
from monai.utils import ImageMetaKey as Key
from monai.utils import InterpolateMode

@@ -90,7 +90,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()

filename = create_file_basename(self.output_postfix, filename, self.output_dir)
filename = monai.data.utils.create_file_basename(self.output_postfix, filename, self.output_dir)
filename = f"{filename}{self.output_ext}"

if data.shape[0] == 1:
116 changes: 116 additions & 0 deletions monai/data/test_time_augmentation.py
@@ -0,0 +1,116 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict

import numpy as np
import torch

from monai.data.dataloader import DataLoader
from monai.data.dataset import Dataset
from monai.data.inverse_batch_transform import BatchInverseTransform
from monai.data.utils import pad_list_data_collate
from monai.transforms.compose import Compose
from monai.transforms.inverse_transform import InvertibleTransform
from monai.transforms.transform import Randomizable

__all__ = ["TestTimeAugmentation"]


def is_transform_rand(transform):
if not isinstance(transform, Compose):
return isinstance(transform, Randomizable)
# call recursively for each sub-transform
return any(is_transform_rand(t) for t in transform.transforms)


class TestTimeAugmentation:
def __init__(
self,
transform: InvertibleTransform,
batch_size,
num_workers,
inferrer_fn,
device,
) -> None:
self.transform = transform
self.batch_size = batch_size
self.num_workers = num_workers
self.inferrer_fn = inferrer_fn
self.device = device

# check that the transform has at least one random component
if not is_transform_rand(self.transform):
raise RuntimeError(
type(self).__name__
+ " requires a `Randomizable` transform or a"
+ " `Compose` containing at least one `Randomizable` transform."
)

def __call__(
self, data: Dict[str, Any], num_examples=10, image_key="image", label_key="label", return_full_data=False
):
d = dict(data)

# check num examples is multiple of batch size
if num_examples % self.batch_size != 0:
raise ValueError("num_examples should be multiple of batch size.")

# generate batch of data of size == batch_size, dataset and dataloader
data_in = [d for _ in range(num_examples)]
ds = Dataset(data_in, self.transform)
dl = DataLoader(ds, self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate)

label_transform_key = label_key + "_transforms"

# create inverter
inverter = BatchInverseTransform(self.transform, dl)

outputs = []

for batch_data in dl:

batch_images = batch_data[image_key].to(self.device)

# do model forward pass
batch_output = self.inferrer_fn(batch_images)
if isinstance(batch_output, torch.Tensor):
batch_output = batch_output.detach().cpu()
if isinstance(batch_output, np.ndarray):
batch_output = torch.Tensor(batch_output)

# check binary labels are extracted
if not all(torch.unique(batch_output.int()) == torch.Tensor([0, 1])):
raise RuntimeError(
"Test-time augmentation requires binary channels. If this is "
"not binary segmentation, then you should one-hot your output."
)

# create a dictionary containing the inferred batch and their transforms
inferred_dict = {label_key: batch_output, label_transform_key: batch_data[label_transform_key]}

# do inverse transformation (only for the label key)
inv_batch = inverter(inferred_dict, label_key)

# append
outputs.append(inv_batch)

# calculate mean and standard deviation
output = np.concatenate(outputs)

if return_full_data:
return output

mode = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=output.astype(np.int64))
mean = np.mean(output, axis=0)
std = np.std(output, axis=0)
vvc = np.std(output) / np.mean(output)
return mode, mean, std, vvc
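A minimal usage sketch of `TestTimeAugmentation` as defined in this PR; the dummy `inferrer_fn`, the 2D shapes and the `AddChanneld`/`RandFlipd` pipeline are illustrative assumptions, and a real use would pass a trained network's forward pass:

```python
import numpy as np
import torch
from monai.data import TestTimeAugmentation
from monai.transforms import AddChanneld, Compose, RandFlipd


# dummy "model": threshold the input so the output is already binary, as required above
def inferrer_fn(x: torch.Tensor) -> torch.Tensor:
    return (x > 0.5).float()


transform = Compose(
    [
        AddChanneld(keys=["image", "label"]),
        # at least one Randomizable (and invertible) transform is required
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    ]
)

tta = TestTimeAugmentation(transform, batch_size=5, num_workers=0, inferrer_fn=inferrer_fn, device="cpu")

data = {
    "image": np.random.rand(20, 20).astype(np.float32),
    "label": np.random.randint(0, 2, (20, 20)).astype(np.float32),
}
# num_examples must be a multiple of batch_size; returns per-voxel statistics
mode, mean, std, vvc = tta(data, num_examples=10)
print(mode.shape, mean.shape, std.shape, vvc)
```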