18 changes: 18 additions & 0 deletions docs/source/highlights.md
@@ -47,9 +47,27 @@ affine = Affine(
# convert the image using interpolation mode
new_img = affine(image, spatial_size=(300, 400), mode='bilinear')
```

### 4. Randomly crop out batch images based on positive/negative ratio
Medical image volumes are often too large to fit into GPU memory. A widely used approach is to randomly draw small data samples during training. MONAI currently provides a uniform random sampling strategy as well as class-balanced fixed-ratio sampling, which can help stabilize the patch-based training process.
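For illustration only (this block is not part of the PR's diff), a minimal sketch of class-balanced patch sampling with `RandCropByPosNegLabeld`, using the same arguments that appear in the spleen notebook changes further down; the import path and the toy input shapes are assumptions:
```py
import numpy as np
from monai.transforms.composables import RandCropByPosNegLabeld  # assumed import path

# toy channel-first 3D volume and binary label map (assumed shapes)
data = {
    'image': np.random.rand(1, 128, 128, 128).astype(np.float32),
    'label': (np.random.rand(1, 128, 128, 128) > 0.7).astype(np.float32),
}

# pos=1, neg=1 -> a 1:1 ratio of foreground-centred to background-centred patches;
# num_samples=4 -> each call returns a list of 4 cropped dictionaries
crop = RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', size=(96, 96, 96),
                              pos=1, neg=1, num_samples=4, image_key='image', image_threshold=0)
patches = crop(data)
print(len(patches), patches[0]['image'].shape)  # 4 (1, 96, 96, 96)
```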

### 5. Deterministic training for reproducibility
Deterministic training support is necessary and important for reproducible deep learning research, especially when sharing work with others. Users can easily set a random seed on all the random transforms in MONAI locally, without affecting other non-deterministic modules elsewhere in their program.
Example code:
```py
# define a transform chain for pre-processing
train_transforms = monai.transforms.compose.Compose([
LoadNiftid(keys=['image', 'label']),
... ...
])
# set determinism for reproducibility
train_transforms.set_random_state(seed=0)
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```

## Losses
The medical imaging research area has domain-specific loss functions that differ from the generic computer vision ones. As an important module of MONAI, these loss functions are implemented in PyTorch, for example the Dice loss and the generalized Dice loss.
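Not part of the diff: a minimal sketch of computing the Dice loss on dummy logits and binary labels, following the `DiceLoss(do_sigmoid=True)` construction and import path used in `tests/integration_determinism.py` below; the tensor shapes here are assumptions:
```py
import torch
from monai.losses.dice import DiceLoss

pred = torch.randn(2, 1, 64, 64)                      # raw network logits (assumed shape)
target = torch.randint(0, 2, (2, 1, 64, 64)).float()  # binary ground truth

loss_fn = DiceLoss(do_sigmoid=True)  # sigmoid is applied to the logits inside the loss
value = loss_fn(pred, target)
print(value.item())
```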

8 changes: 4 additions & 4 deletions docs/source/transforms.rst
@@ -123,9 +123,9 @@ Vanilla Transforms
:members:
:special-members: __call__

`UniformRandomPatch`
`RandUniformPatch`
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: UniformRandomPatch
.. autoclass:: RandUniformPatch
:members:
:special-members: __call__

@@ -280,9 +280,9 @@ Dictionary-based Composables
:members:
:special-members: __call__

`UniformRandomPatchd`
`RandUniformPatchd`
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: UniformRandomPatchd
.. autoclass:: RandUniformPatchd
:members:
:special-members: __call__

7 changes: 3 additions & 4 deletions examples/integrate_to_spleen_program.ipynb
@@ -33,7 +33,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"metadata": {},
"outputs": [
{
@@ -121,8 +121,8 @@
" RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', size=(96, 96, 96), pos=1,\n",
" neg=1, num_samples=4, image_key='image', image_threshold=0),\n",
" # user can also add other random transforms\n",
" # RandAffined(keys=['image', 'label'], spatial_size=(96, 96, 96), prob=0.5,\n",
" # rotate_range=(np.pi / 15, np.pi / 15, np.pi / 15), scale_range=(0.1, 0.1, 0.1))\n",
" # RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=1.0, spatial_size=(96, 96, 96),\n",
" # rotate_range=(0, 0, np.pi/15), scale_range=(0.1, 0.1, 0.1))\n",
"])\n",
"val_transforms = transforms.Compose([\n",
" LoadNiftid(keys=['image', 'label']),\n",
@@ -145,7 +145,6 @@
"outputs": [],
"source": [
"train_transforms.set_random_state(seed=0)\n",
"np.random.seed(0)\n",
"torch.manual_seed(0)\n",
"torch.backends.cudnn.deterministic = True\n",
"torch.backends.cudnn.benchmark = False"
6 changes: 3 additions & 3 deletions examples/nifti_read_example.ipynb
@@ -41,7 +41,7 @@
"\n",
"from monai.transforms.utils import rescale_array\n",
"from monai.data.nifti_reader import NiftiDataset\n",
"from monai.transforms import AddChannel, Transpose, Rescale, ToTensor, UniformRandomPatch\n",
"from monai.transforms import AddChannel, Transpose, Rescale, ToTensor, RandUniformPatch\n",
"from monai.data.grid_dataset import GridPatchDataset\n",
"\n",
"monai.config.print_config()"
@@ -135,13 +135,13 @@
"imtrans=transforms.Compose([\n",
" Rescale(),\n",
" AddChannel(),\n",
" UniformRandomPatch((64, 64, 64)),\n",
" RandUniformPatch((64, 64, 64)),\n",
" ToTensor()\n",
"]) \n",
"\n",
"segtrans=transforms.Compose([\n",
" AddChannel(),\n",
" UniformRandomPatch((64, 64, 64)),\n",
" RandUniformPatch((64, 64, 64)),\n",
" ToTensor()\n",
"]) \n",
" \n",
6 changes: 3 additions & 3 deletions examples/segmentation_3d/unet_training_array.py
@@ -25,7 +25,7 @@
import monai.transforms.compose as transforms

from monai.data.nifti_reader import NiftiDataset
from monai.transforms import AddChannel, Rescale, UniformRandomPatch, Resize
from monai.transforms import AddChannel, Rescale, RandUniformPatch, Resize
from monai.handlers.stats_handler import StatsHandler
from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler, TensorBoardImageHandler
from monai.handlers.mean_dice import MeanDice
@@ -55,11 +55,11 @@
train_imtrans = transforms.Compose([
Rescale(),
AddChannel(),
UniformRandomPatch((96, 96, 96))
RandUniformPatch((96, 96, 96))
])
train_segtrans = transforms.Compose([
AddChannel(),
UniformRandomPatch((96, 96, 96))
RandUniformPatch((96, 96, 96))
])
val_imtrans = transforms.Compose([
Rescale(),
6 changes: 3 additions & 3 deletions examples/transform_speed.ipynb
@@ -50,7 +50,7 @@
"from monai.transforms.compose import Compose\n",
"from monai.data.nifti_reader import NiftiDataset\n",
"from monai.transforms import (AddChannel, Rescale, ToTensor, \n",
" UniformRandomPatch, Rotate, RandAffine)\n",
" RandUniformPatch, Rotate, RandAffine)\n",
"\n",
"monai.config.print_config()"
]
@@ -180,14 +180,14 @@
" Rescale(),\n",
" AddChannel(),\n",
" Rotate(angle=45.),\n",
" UniformRandomPatch((64, 64, 64)),\n",
" RandUniformPatch((64, 64, 64)),\n",
" ToTensor()\n",
"]) \n",
"\n",
"segtrans = Compose([\n",
" AddChannel(),\n",
" Rotate(angle=45.),\n",
" UniformRandomPatch((64, 64, 64)),\n",
" RandUniformPatch((64, 64, 64)),\n",
" ToTensor()\n",
"]) \n",
" \n",
6 changes: 3 additions & 3 deletions examples/unet_segmentation_3d.ipynb
@@ -43,7 +43,7 @@
"import monai.transforms.compose as transforms\n",
"\n",
"from monai.data.nifti_reader import NiftiDataset\n",
"from monai.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch)\n",
"from monai.transforms import (AddChannel, Rescale, ToTensor, RandUniformPatch)\n",
"from monai.handlers.stats_handler import StatsHandler\n",
"from monai.handlers.mean_dice import MeanDice\n",
"from monai.visualize import img2tensorboard\n",
@@ -107,12 +107,12 @@
"imtrans = transforms.Compose([\n",
" Rescale(), \n",
" AddChannel(), \n",
" UniformRandomPatch((96, 96, 96)), \n",
" RandUniformPatch((96, 96, 96)), \n",
" ToTensor()\n",
"])\n",
"segtrans = transforms.Compose([\n",
" AddChannel(), \n",
" UniformRandomPatch((96, 96, 96)), \n",
" RandUniformPatch((96, 96, 96)), \n",
" ToTensor()\n",
"])\n",
"\n",
4 changes: 2 additions & 2 deletions monai/transforms/composables.py
@@ -278,8 +278,8 @@ def __call__(self, data):


@export
@alias('UniformRandomPatchD', 'UniformRandomPatchDict')
class UniformRandomPatchd(Randomizable, MapTransform):
@alias('RandUniformPatchD', 'RandUniformPatchDict')
class RandUniformPatchd(Randomizable, MapTransform):
"""
Selects a patch of the given size chosen at a uniformly random position in the image.

2 changes: 1 addition & 1 deletion monai/transforms/transforms.py
@@ -493,7 +493,7 @@ def __call__(self, img):


@export
class UniformRandomPatch(Randomizable):
class RandUniformPatch(Randomizable):
"""
Selects a patch of the given size chosen at a uniformly random position in the image.

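As a usage note (not part of the diff): the renamed `RandUniformPatch` transform described above can be applied to a channel-first array in the same way the example scripts in this PR compose it; the input shape below is an assumption:
```py
import numpy as np
from monai.transforms import AddChannel, Rescale, RandUniformPatch

img = np.random.rand(128, 128, 128).astype(np.float32)  # assumed 3D volume
img = Rescale()(img)                                     # scale intensities to [0, 1]
img = AddChannel()(img)                                  # -> (1, 128, 128, 128)
patch = RandUniformPatch((64, 64, 64))(img)              # 64^3 patch at a uniformly random position
print(patch.shape)                                       # (1, 64, 64, 64), channel dim preserved
```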
13 changes: 9 additions & 4 deletions monai/transforms/utils.py
@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import random

import numpy as np
@@ -202,12 +202,16 @@ def generate_pos_neg_label_crop_centers(label, size, num_samples, pos_ratio, ima
else:
bg_indicies = np.nonzero(~label_flat)[0]

if not len(fg_indicies) or not len(bg_indicies):
if not len(fg_indicies) and not len(bg_indicies):
raise ValueError('no sampling location available.')
warnings.warn('N foreground {}, N background {}, unable to generate class balanced samples.'.format(
len(fg_indicies), len(bg_indicies)))
pos_ratio = 0 if not len(fg_indicies) else 1

centers = []
for _ in range(num_samples):
if rand_state.rand() < pos_ratio:
indicies_to_use = fg_indicies
else:
indicies_to_use = bg_indicies
indicies_to_use = fg_indicies if rand_state.rand() < pos_ratio else bg_indicies
random_int = rand_state.randint(len(indicies_to_use))
center = np.unravel_index(indicies_to_use[random_int], label.shape)
center = center[1:]
88 changes: 88 additions & 0 deletions tests/integration_determinism.py
@@ -0,0 +1,88 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from monai.transforms import AddChannel, Rescale, RandUniformPatch, RandRotate90
import monai.transforms.compose as transforms
from monai.data.synthetic import create_test_image_2d
from monai.losses.dice import DiceLoss
from monai.networks.nets.unet import UNet


def run_test(batch_size=64, train_steps=100, device=torch.device("cuda:0")):

class _TestBatch(Dataset):

def __init__(self, transforms):
self.transforms = transforms

def __getitem__(self, _unused_id):
im, seg = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1)
seed = np.random.randint(2147483647)
self.transforms.set_random_state(seed=seed)
im = self.transforms(im)
self.transforms.set_random_state(seed=seed)
seg = self.transforms(seg)
return im, seg

def __len__(self):
return train_steps

net = UNet(
dimensions=2,
in_channels=1,
out_channels=1,
channels=(4, 8, 16, 32),
strides=(2, 2, 2),
num_res_units=2,
).to(device)

loss = DiceLoss(do_sigmoid=True)
opt = torch.optim.Adam(net.parameters(), 1e-2)
train_transforms = transforms.Compose([
AddChannel(),
Rescale(),
RandUniformPatch((96, 96)),
RandRotate90()
])

src = DataLoader(_TestBatch(train_transforms), batch_size=batch_size)

net.train()
epoch_loss = 0
step = 0
for img, seg in src:
step += 1
opt.zero_grad()
output = net(img.to(device))
step_loss = loss(output, seg.to(device))
step_loss.backward()
opt.step()
epoch_loss += step_loss.item()
epoch_loss /= step

print('Loss:', epoch_loss)
result = np.allclose(epoch_loss, 0.578675)
if result is False:
print('Loss value is wrong, expect to be 0.578675.')
return result


if __name__ == "__main__":
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
sys.exit(0 if run_test() is True else 1)
4 changes: 2 additions & 2 deletions tests/integration_sliding_window.py
@@ -28,7 +28,7 @@
from tests.utils import make_nifti_image


def run_test(batch_size=2, device=torch.device("cpu:0")):
def run_test(batch_size=2, device=torch.device("cuda:0")):

im, seg = create_test_image_3d(25, 28, 63, rad_max=10, noise_max=1, num_objs=4, num_seg_classes=1)
input_shape = im.shape
@@ -44,7 +44,7 @@ def run_test(batch_size=2, device=torch.device("cpu:0")):
channels=(4, 8, 16, 32),
strides=(2, 2, 2),
num_res_units=2,
)
).to(device)
roi_size = (16, 32, 48)
sw_batch_size = batch_size

2 changes: 1 addition & 1 deletion tests/integration_unet2d.py
@@ -39,7 +39,7 @@ def __len__(self):
channels=(4, 8, 16, 32),
strides=(2, 2, 2),
num_res_units=2,
)
).to(device)

loss = DiceLoss(do_sigmoid=True)
opt = torch.optim.Adam(net.parameters(), 1e-4)
24 changes: 23 additions & 1 deletion tests/test_rand_crop_by_pos_neg_labeld.py
@@ -36,10 +36,32 @@
(3, 2, 2, 2),
]

TEST_CASE_2 = [
{
'keys': ['image', 'extral', 'label'],
'label_key': 'label',
'size': [2, 2, 2],
'pos': 1,
'neg': 1,
'num_samples': 2,
'image_key': None,
'image_threshold': 0
},
{
'image': np.zeros([3, 3, 3, 3]) - 1,
'extral': np.zeros([3, 3, 3, 3]),
'label': np.ones([3, 3, 3, 3]),
'affine': np.eye(3),
'shape': 'CHWD'
},
list,
(3, 2, 2, 2),
]


class TestRandCropByPosNegLabeld(unittest.TestCase):

@parameterized.expand([TEST_CASE_1])
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_type_shape(self, input_param, input_data, expected_type, expected_shape):
result = RandCropByPosNegLabeld(**input_param)(input_data)
self.assertIsInstance(result, expected_type)
@@ -13,15 +13,15 @@

import numpy as np

from monai.transforms.transforms import UniformRandomPatch
from monai.transforms.transforms import RandUniformPatch
from tests.utils import NumpyImageTestCase2D


class TestUniformRandomPatch(NumpyImageTestCase2D):
class TestRandUniformPatch(NumpyImageTestCase2D):

def test_2d(self):
patch_spatial_size = (10, 10)
patch_transform = UniformRandomPatch(patch_spatial_size=patch_spatial_size)
patch_transform = RandUniformPatch(patch_spatial_size=patch_spatial_size)
patch = patch_transform(self.imt[0])
self.assertTrue(np.allclose(patch.shape[1:], patch_spatial_size))
