diff --git a/docs/source/highlights.md b/docs/source/highlights.md
index fadb8df622..f7895bf65d 100644
--- a/docs/source/highlights.md
+++ b/docs/source/highlights.md
@@ -47,9 +47,27 @@ affine = Affine(
 # convert the image using interpolation mode
 new_img = affine(image, spatial_size=(300, 400), mode='bilinear')
 ```
+
 ### 4. Randomly crop out batch images based on positive/negative ratio
 Medical image data volumes may be too large to fit into GPU memory. A widely-used approach is to randomly draw small-size data samples during training. MONAI currently provides a uniform random sampling strategy as well as class-balanced fixed-ratio sampling, which may help stabilize the patch-based training process.
+### 5. Deterministic training for reproducibility
+Deterministic training support is necessary and important in deep learning research, especially when sharing reproducible work with others. Users can easily set a random seed for all the random transforms in MONAI locally; this does not affect the other non-deterministic modules in the user's program.
+Example code:
+```py
+# define a transform chain for pre-processing
+train_transforms = monai.transforms.compose.Compose([
+    LoadNiftid(keys=['image', 'label']),
+    ... ...
+])
+# set determinism for reproducibility
+train_transforms.set_random_state(seed=0)
+np.random.seed(0)
+torch.manual_seed(0)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+```
+
 ## Losses
 There are domain-specific loss functions in the medical research area which are different from the generic computer vision ones. As an important module of MONAI, these loss functions are implemented in PyTorch, such as Dice loss and generalized Dice loss.
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 9c48b9f08b..de5bb45975 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -123,9 +123,9 @@ Vanilla Transforms
   :members:
   :special-members: __call__
 
-`UniformRandomPatch`
+`RandUniformPatch`
 ~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: UniformRandomPatch
+.. autoclass:: RandUniformPatch
   :members:
   :special-members: __call__
 
@@ -280,9 +280,9 @@ Dictionary-based Composables
   :members:
   :special-members: __call__
 
-`UniformRandomPatchd`
+`RandUniformPatchd`
 ~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: UniformRandomPatchd
+.. autoclass:: RandUniformPatchd
   :members:
   :special-members: __call__
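The class-balanced pos/neg cropping highlighted above can be exercised directly; here is a minimal sketch (not part of the patch), assuming `RandCropByPosNegLabeld` is importable from `monai.transforms` and using the parameter names shown in the notebook and test changes below:

```py
import numpy as np
from monai.transforms import RandCropByPosNegLabeld  # import path assumed

# channel-first image and a binary label containing both classes
data = {
    'image': np.random.rand(1, 64, 64, 64),
    'label': (np.random.rand(1, 64, 64, 64) > 0.5).astype(np.int64),
}
cropper = RandCropByPosNegLabeld(
    keys=['image', 'label'], label_key='label', size=(32, 32, 32),
    pos=1, neg=1,               # 1:1 ratio of foreground- to background-centred patches
    num_samples=4, image_key=None, image_threshold=0)

samples = cropper(data)         # a list of `num_samples` cropped dictionaries
print(len(samples), samples[0]['image'].shape)  # 4 (1, 32, 32, 32)
```

With a label that contains only one class (as in the new `TEST_CASE_2` below), the updated sampler warns and draws all patches from the available class instead of failing.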
diff --git a/examples/integrate_to_spleen_program.ipynb b/examples/integrate_to_spleen_program.ipynb
index b43cda61b1..887e48d7f9 100644
--- a/examples/integrate_to_spleen_program.ipynb
+++ b/examples/integrate_to_spleen_program.ipynb
@@ -33,7 +33,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [
     {
@@ -121,8 +121,8 @@
    "    RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', size=(96, 96, 96), pos=1,\n",
    "                           neg=1, num_samples=4, image_key='image', image_threshold=0),\n",
    "    # user can also add other random transforms\n",
-    "    # RandAffined(keys=['image', 'label'], spatial_size=(96, 96, 96), prob=0.5,\n",
-    "    #             rotate_range=(np.pi / 15, np.pi / 15, np.pi / 15), scale_range=(0.1, 0.1, 0.1))\n",
+    "    # RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=1.0, spatial_size=(96, 96, 96),\n",
+    "    #             rotate_range=(0, 0, np.pi/15), scale_range=(0.1, 0.1, 0.1))\n",
    "])\n",
    "val_transforms = transforms.Compose([\n",
    "    LoadNiftid(keys=['image', 'label']),\n",
@@ -145,7 +145,6 @@
    "outputs": [],
    "source": [
     "train_transforms.set_random_state(seed=0)\n",
-    "np.random.seed(0)\n",
     "torch.manual_seed(0)\n",
     "torch.backends.cudnn.deterministic = True\n",
     "torch.backends.cudnn.benchmark = False"
diff --git a/examples/nifti_read_example.ipynb b/examples/nifti_read_example.ipynb
index c83ae70429..f50f838156 100644
--- a/examples/nifti_read_example.ipynb
+++ b/examples/nifti_read_example.ipynb
@@ -41,7 +41,7 @@
    "\n",
    "from monai.transforms.utils import rescale_array\n",
    "from monai.data.nifti_reader import NiftiDataset\n",
-    "from monai.transforms import AddChannel, Transpose, Rescale, ToTensor, UniformRandomPatch\n",
+    "from monai.transforms import AddChannel, Transpose, Rescale, ToTensor, RandUniformPatch\n",
    "from monai.data.grid_dataset import GridPatchDataset\n",
    "\n",
    "monai.config.print_config()"
@@ -135,13 +135,13 @@
    "imtrans=transforms.Compose([\n",
    "    Rescale(),\n",
    "    AddChannel(),\n",
-    "    UniformRandomPatch((64, 64, 64)),\n",
+    "    RandUniformPatch((64, 64, 64)),\n",
    "    ToTensor()\n",
    "]) \n",
    "\n",
    "segtrans=transforms.Compose([\n",
    "    AddChannel(),\n",
-    "    UniformRandomPatch((64, 64, 64)),\n",
+    "    RandUniformPatch((64, 64, 64)),\n",
    "    ToTensor()\n",
    "]) \n",
    " \n",
diff --git a/examples/segmentation_3d/unet_training_array.py b/examples/segmentation_3d/unet_training_array.py
index 92e00ec13b..fd700b23f3 100644
--- a/examples/segmentation_3d/unet_training_array.py
+++ b/examples/segmentation_3d/unet_training_array.py
@@ -25,7 +25,7 @@
 import monai.transforms.compose as transforms
 
 from monai.data.nifti_reader import NiftiDataset
-from monai.transforms import AddChannel, Rescale, UniformRandomPatch, Resize
+from monai.transforms import AddChannel, Rescale, RandUniformPatch, Resize
 from monai.handlers.stats_handler import StatsHandler
 from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler, TensorBoardImageHandler
 from monai.handlers.mean_dice import MeanDice
@@ -55,11 +55,11 @@
 train_imtrans = transforms.Compose([
     Rescale(),
     AddChannel(),
-    UniformRandomPatch((96, 96, 96))
+    RandUniformPatch((96, 96, 96))
 ])
 train_segtrans = transforms.Compose([
     AddChannel(),
-    UniformRandomPatch((96, 96, 96))
+    RandUniformPatch((96, 96, 96))
 ])
 val_imtrans = transforms.Compose([
     Rescale(),
diff --git a/examples/transform_speed.ipynb b/examples/transform_speed.ipynb
index 2f4974d527..6d486d1850 100644
--- a/examples/transform_speed.ipynb
+++ b/examples/transform_speed.ipynb
@@ -50,7 +50,7 @@
    "from monai.transforms.compose import Compose\n",
    "from monai.data.nifti_reader import NiftiDataset\n",
    "from monai.transforms import (AddChannel, Rescale, ToTensor, \n",
-    "                              UniformRandomPatch, Rotate, RandAffine)\n",
+    "                              RandUniformPatch, Rotate, RandAffine)\n",
    "\n",
    "monai.config.print_config()"
   ]
@@ -180,14 +180,14 @@
    "    Rescale(),\n",
    "    AddChannel(),\n",
    "    Rotate(angle=45.),\n",
-    "    UniformRandomPatch((64, 64, 64)),\n",
+    "    RandUniformPatch((64, 64, 64)),\n",
    "    ToTensor()\n",
    "]) \n",
    "\n",
    "segtrans = Compose([\n",
    "    AddChannel(),\n",
    "    Rotate(angle=45.),\n",
-    "    UniformRandomPatch((64, 64, 64)),\n",
+    "    RandUniformPatch((64, 64, 64)),\n",
    "    ToTensor()\n",
    "]) \n",
    " \n",
diff --git a/examples/unet_segmentation_3d.ipynb b/examples/unet_segmentation_3d.ipynb
index 522250f419..00dceb150a 100644
--- a/examples/unet_segmentation_3d.ipynb
+++ b/examples/unet_segmentation_3d.ipynb
@@ -43,7 +43,7 @@
    "import monai.transforms.compose as transforms\n",
    "\n",
    "from monai.data.nifti_reader import NiftiDataset\n",
-    "from monai.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch)\n",
+    "from monai.transforms import (AddChannel, Rescale, ToTensor, RandUniformPatch)\n",
    "from monai.handlers.stats_handler import StatsHandler\n",
    "from monai.handlers.mean_dice import MeanDice\n",
    "from monai.visualize import img2tensorboard\n",
@@ -107,12 +107,12 @@
    "imtrans = transforms.Compose([\n",
    "    Rescale(), \n",
    "    AddChannel(), \n",
-    "    UniformRandomPatch((96, 96, 96)), \n",
+    "    RandUniformPatch((96, 96, 96)), \n",
    "    ToTensor()\n",
    "])\n",
    "segtrans = transforms.Compose([\n",
    "    AddChannel(), \n",
-    "    UniformRandomPatch((96, 96, 96)), \n",
+    "    RandUniformPatch((96, 96, 96)), \n",
    "    ToTensor()\n",
    "])\n",
    "\n",
diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py
index c84c81653b..6990df5f56 100644
--- a/monai/transforms/composables.py
+++ b/monai/transforms/composables.py
@@ -278,8 +278,8 @@ def __call__(self, data):
 
 
 @export
-@alias('UniformRandomPatchD', 'UniformRandomPatchDict')
-class UniformRandomPatchd(Randomizable, MapTransform):
+@alias('RandUniformPatchD', 'RandUniformPatchDict')
+class RandUniformPatchd(Randomizable, MapTransform):
     """
     Selects a patch of the given size chosen at a uniformly random position in the image.
diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py
index b1708c3e3a..f0ee4a9322 100644
--- a/monai/transforms/transforms.py
+++ b/monai/transforms/transforms.py
@@ -493,7 +493,7 @@ def __call__(self, img):
 
 
 @export
-class UniformRandomPatch(Randomizable):
+class RandUniformPatch(Randomizable):
     """
     Selects a patch of the given size chosen at a uniformly random position in the image.
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index b8ce70d854..f7a2f24501 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -9,6 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 import random
 
 import numpy as np
 
@@ -202,12 +203,16 @@ def generate_pos_neg_label_crop_centers(label, size, num_samples, pos_ratio, ima
     else:
         bg_indicies = np.nonzero(~label_flat)[0]
 
+    if not len(fg_indicies) or not len(bg_indicies):
+        if not len(fg_indicies) and not len(bg_indicies):
+            raise ValueError('no sampling location available.')
+        warnings.warn('N foreground {}, N background {}, unable to generate class balanced samples.'.format(
+            len(fg_indicies), len(bg_indicies)))
+        pos_ratio = 0 if not len(fg_indicies) else 1
+
     centers = []
     for _ in range(num_samples):
-        if rand_state.rand() < pos_ratio:
-            indicies_to_use = fg_indicies
-        else:
-            indicies_to_use = bg_indicies
+        indicies_to_use = fg_indicies if rand_state.rand() < pos_ratio else bg_indicies
         random_int = rand_state.randint(len(indicies_to_use))
         center = np.unravel_index(indicies_to_use[random_int], label.shape)
         center = center[1:]
diff --git a/tests/integration_determinism.py b/tests/integration_determinism.py
new file mode 100644
index 0000000000..5e31d10fab
--- /dev/null
+++ b/tests/integration_determinism.py
@@ -0,0 +1,88 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, Dataset
+from monai.transforms import AddChannel, Rescale, RandUniformPatch, RandRotate90
+import monai.transforms.compose as transforms
+from monai.data.synthetic import create_test_image_2d
+from monai.losses.dice import DiceLoss
+from monai.networks.nets.unet import UNet
+
+
+def run_test(batch_size=64, train_steps=100, device=torch.device("cuda:0")):
+
+    class _TestBatch(Dataset):
+
+        def __init__(self, transforms):
+            self.transforms = transforms
+
+        def __getitem__(self, _unused_id):
+            im, seg = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1)
+            seed = np.random.randint(2147483647)
+            self.transforms.set_random_state(seed=seed)
+            im = self.transforms(im)
+            self.transforms.set_random_state(seed=seed)
+            seg = self.transforms(seg)
+            return im, seg
+
+        def __len__(self):
+            return train_steps
+
+    net = UNet(
+        dimensions=2,
+        in_channels=1,
+        out_channels=1,
+        channels=(4, 8, 16, 32),
+        strides=(2, 2, 2),
+        num_res_units=2,
+    ).to(device)
+
+    loss = DiceLoss(do_sigmoid=True)
+    opt = torch.optim.Adam(net.parameters(), 1e-2)
+    train_transforms = transforms.Compose([
+        AddChannel(),
+        Rescale(),
+        RandUniformPatch((96, 96)),
+        RandRotate90()
+    ])
+
+    src = DataLoader(_TestBatch(train_transforms), batch_size=batch_size)
+
+    net.train()
+    epoch_loss = 0
+    step = 0
+    for img, seg in src:
+        step += 1
+        opt.zero_grad()
+        output = net(img.to(device))
+        step_loss = loss(output, seg.to(device))
+        step_loss.backward()
+        opt.step()
+        epoch_loss += step_loss.item()
+    epoch_loss /= step
+
+    print('Loss:', epoch_loss)
+    result = np.allclose(epoch_loss, 0.578675)
+    if result is False:
+        print('Loss value is wrong, expect to be 0.578675.')
+    return result
+
+
+if __name__ == "__main__":
+    np.random.seed(0)
+    torch.manual_seed(0)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    sys.exit(0 if run_test() is True else 1)
diff --git a/tests/integration_sliding_window.py b/tests/integration_sliding_window.py
index b99bb5c681..91fd2994be 100644
--- a/tests/integration_sliding_window.py
+++ b/tests/integration_sliding_window.py
@@ -28,7 +28,7 @@
 from tests.utils import make_nifti_image
 
 
-def run_test(batch_size=2, device=torch.device("cpu:0")):
+def run_test(batch_size=2, device=torch.device("cuda:0")):
     im, seg = create_test_image_3d(25, 28, 63, rad_max=10, noise_max=1, num_objs=4, num_seg_classes=1)
 
     input_shape = im.shape
@@ -44,7 +44,7 @@
         channels=(4, 8, 16, 32),
         strides=(2, 2, 2),
         num_res_units=2,
-    )
+    ).to(device)
 
     roi_size = (16, 32, 48)
     sw_batch_size = batch_size
diff --git a/tests/integration_unet2d.py b/tests/integration_unet2d.py
index d437258c10..819be91e4d 100644
--- a/tests/integration_unet2d.py
+++ b/tests/integration_unet2d.py
@@ -39,7 +39,7 @@ def __len__(self):
         channels=(4, 8, 16, 32),
         strides=(2, 2, 2),
         num_res_units=2,
-    )
+    ).to(device)
 
     loss = DiceLoss(do_sigmoid=True)
     opt = torch.optim.Adam(net.parameters(), 1e-4)
diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py
index 2292dfed9f..f83d737873 100644
--- a/tests/test_rand_crop_by_pos_neg_labeld.py
+++ b/tests/test_rand_crop_by_pos_neg_labeld.py
@@ -36,10 +36,32 @@
     (3, 2, 2, 2),
 ]
 
+TEST_CASE_2 = [
+    {
+        'keys': ['image', 'extral', 'label'],
+        'label_key': 'label',
+        'size': [2, 2, 2],
+        'pos': 1,
+        'neg': 1,
+        'num_samples': 2,
+        'image_key': None,
+        'image_threshold': 0
+    },
+    {
+        'image': np.zeros([3, 3, 3, 3]) - 1,
+        'extral': np.zeros([3, 3, 3, 3]),
+        'label': np.ones([3, 3, 3, 3]),
+        'affine': np.eye(3),
+        'shape': 'CHWD'
+    },
+    list,
+    (3, 2, 2, 2),
+]
+
 
 class TestRandCropByPosNegLabeld(unittest.TestCase):
 
-    @parameterized.expand([TEST_CASE_1])
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
     def test_type_shape(self, input_param, input_data, expected_type, expected_shape):
         result = RandCropByPosNegLabeld(**input_param)(input_data)
         self.assertIsInstance(result, expected_type)
diff --git a/tests/test_uniform_rand_patch.py b/tests/test_rand_uniform_patch.py
similarity index 82%
rename from tests/test_uniform_rand_patch.py
rename to tests/test_rand_uniform_patch.py
index 32b1077ef7..923f302c97 100644
--- a/tests/test_uniform_rand_patch.py
+++ b/tests/test_rand_uniform_patch.py
@@ -13,15 +13,15 @@
 
 import numpy as np
 
-from monai.transforms.transforms import UniformRandomPatch
+from monai.transforms.transforms import RandUniformPatch
 from tests.utils import NumpyImageTestCase2D
 
 
-class TestUniformRandomPatch(NumpyImageTestCase2D):
+class TestRandUniformPatch(NumpyImageTestCase2D):
 
     def test_2d(self):
         patch_spatial_size = (10, 10)
-        patch_transform = UniformRandomPatch(patch_spatial_size=patch_spatial_size)
+        patch_transform = RandUniformPatch(patch_spatial_size=patch_spatial_size)
         patch = patch_transform(self.imt[0])
         self.assertTrue(np.allclose(patch.shape[1:], patch_spatial_size))
diff --git a/tests/test_uniform_rand_patchd.py b/tests/test_rand_uniform_patchd.py
similarity index 79%
rename from tests/test_uniform_rand_patchd.py
rename to tests/test_rand_uniform_patchd.py
index a87438acce..b68d555178 100644
--- a/tests/test_uniform_rand_patchd.py
+++ b/tests/test_rand_uniform_patchd.py
@@ -13,16 +13,16 @@
 
 import numpy as np
 
-from monai.transforms.composables import UniformRandomPatchd
+from monai.transforms.composables import RandUniformPatchd
 from tests.utils import NumpyImageTestCase2D
 
 
-class TestUniformRandomPatchd(NumpyImageTestCase2D):
+class TestRandUniformPatchd(NumpyImageTestCase2D):
 
     def test_2d(self):
         patch_spatial_size = (10, 10)
         key = 'test'
-        patch_transform = UniformRandomPatchd(keys='test', patch_spatial_size=patch_spatial_size)
+        patch_transform = RandUniformPatchd(keys='test', patch_spatial_size=patch_spatial_size)
         patch = patch_transform({key: self.imt[0]})
         self.assertTrue(np.allclose(patch[key].shape[1:], patch_spatial_size))
@@ -30,7 +30,7 @@ def test_sync(self):
         patch_spatial_size = (4, 4)
         key_1, key_2 = 'foo', 'bar'
         rand_image = np.random.rand(3, 10, 10)
-        patch_transform = UniformRandomPatchd(keys=(key_1, key_2), patch_spatial_size=patch_spatial_size)
+        patch_transform = RandUniformPatchd(keys=(key_1, key_2), patch_spatial_size=patch_spatial_size)
         patch = patch_transform({key_1: rand_image, key_2: rand_image})
         self.assertTrue(np.allclose(patch[key_1], patch[key_2]))
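To tie the rename and the determinism support together, here is a minimal sketch (not part of the patch) of the renamed `RandUniformPatch` in a seeded pre-processing chain, mirroring the example scripts and the new `tests/integration_determinism.py` above; the input volume is illustrative:

```py
import numpy as np
import torch
import monai.transforms.compose as transforms
from monai.transforms import AddChannel, Rescale, RandUniformPatch, ToTensor

imtrans = transforms.Compose([
    Rescale(),                        # scale intensities into [0, 1]
    AddChannel(),                     # prepend a channel dimension
    RandUniformPatch((96, 96, 96)),   # formerly UniformRandomPatch
    ToTensor()
])

# seed MONAI's random transforms and the other sources of randomness
imtrans.set_random_state(seed=0)
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

patch = imtrans(np.random.rand(128, 128, 128))  # a synthetic HWD volume
print(patch.shape)  # torch.Size([1, 96, 96, 96])
```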