From 2d420d1e039356ba8574bb778cdb00b7df59bd44 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 20 Mar 2020 11:38:10 +0800 Subject: [PATCH 1/6] [DLMED] add determinism integration test and fix bugs --- examples/integrate_to_spleen_program.ipynb | 7 +-- tests/integration_determinism.py | 73 ++++++++++++++++++++++ tests/integration_sliding_window.py | 4 +- tests/integration_unet2d.py | 2 +- 4 files changed, 79 insertions(+), 7 deletions(-) create mode 100644 tests/integration_determinism.py diff --git a/examples/integrate_to_spleen_program.ipynb b/examples/integrate_to_spleen_program.ipynb index b43cda61b1..887e48d7f9 100644 --- a/examples/integrate_to_spleen_program.ipynb +++ b/examples/integrate_to_spleen_program.ipynb @@ -33,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -121,8 +121,8 @@ " RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', size=(96, 96, 96), pos=1,\n", " neg=1, num_samples=4, image_key='image', image_threshold=0),\n", " # user can also add other random transforms\n", - " # RandAffined(keys=['image', 'label'], spatial_size=(96, 96, 96), prob=0.5,\n", - " # rotate_range=(np.pi / 15, np.pi / 15, np.pi / 15), scale_range=(0.1, 0.1, 0.1))\n", + " # RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=1.0, spatial_size=(96, 96, 96),\n", + " # rotate_range=(0, 0, np.pi/15), scale_range=(0.1, 0.1, 0.1))\n", "])\n", "val_transforms = transforms.Compose([\n", " LoadNiftid(keys=['image', 'label']),\n", @@ -145,7 +145,6 @@ "outputs": [], "source": [ "train_transforms.set_random_state(seed=0)\n", - "np.random.seed(0)\n", "torch.manual_seed(0)\n", "torch.backends.cudnn.deterministic = True\n", "torch.backends.cudnn.benchmark = False" diff --git a/tests/integration_determinism.py b/tests/integration_determinism.py new file mode 100644 index 0000000000..a1675da0cc --- /dev/null +++ b/tests/integration_determinism.py @@ -0,0 +1,73 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset + +from monai.data.synthetic import create_test_image_2d +from monai.losses.dice import DiceLoss +from monai.networks.nets.unet import UNet + + +def run_test(batch_size=64, train_steps=100, device=torch.device("cuda:0")): + + class _TestBatch(Dataset): + + def __getitem__(self, _unused_id): + im, seg = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1) + return im[None], seg[None].astype(np.float32) + + def __len__(self): + return train_steps + + net = UNet( + dimensions=2, + in_channels=1, + out_channels=1, + channels=(4, 8, 16, 32), + strides=(2, 2, 2), + num_res_units=2, + ).to(device) + + loss = DiceLoss(do_sigmoid=True) + opt = torch.optim.Adam(net.parameters(), 1e-4) + src = DataLoader(_TestBatch(), batch_size=batch_size) + + net.train() + epoch_loss = 0 + step = 0 + for img, seg in src: + step += 1 + opt.zero_grad() + output = net(img.to(device)) + step_loss = loss(output, seg.to(device)) + step_loss.backward() + opt.step() + epoch_loss += step_loss.item() + epoch_loss /= step + + print('Loss:', epoch_loss) + delta = abs(epoch_loss - 0.70605) + if delta > 0.0001: + print('Loss value is wrong, expect to be 0.70605.') + return delta + + +if __name__ == "__main__": + np.random.seed(0) + torch.manual_seed(0) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + result = run_test() + sys.exit(1 if result > 0.0001 else 0) diff --git a/tests/integration_sliding_window.py b/tests/integration_sliding_window.py index b99bb5c681..91fd2994be 100644 --- a/tests/integration_sliding_window.py +++ b/tests/integration_sliding_window.py @@ -28,7 +28,7 @@ from tests.utils import make_nifti_image -def run_test(batch_size=2, device=torch.device("cpu:0")): +def run_test(batch_size=2, device=torch.device("cuda:0")): im, seg = create_test_image_3d(25, 28, 63, rad_max=10, noise_max=1, num_objs=4, num_seg_classes=1) input_shape = im.shape @@ -44,7 +44,7 @@ def run_test(batch_size=2, device=torch.device("cpu:0")): channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2, - ) + ).to(device) roi_size = (16, 32, 48) sw_batch_size = batch_size diff --git a/tests/integration_unet2d.py b/tests/integration_unet2d.py index d437258c10..819be91e4d 100644 --- a/tests/integration_unet2d.py +++ b/tests/integration_unet2d.py @@ -39,7 +39,7 @@ def __len__(self): channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2, - ) + ).to(device) loss = DiceLoss(do_sigmoid=True) opt = torch.optim.Adam(net.parameters(), 1e-4) From ecf7c1cf2f6becef06dda3a8dc9b9f4c56d8fe0e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 20 Mar 2020 18:01:07 +0800 Subject: [PATCH 2/6] [DLMED] add determinism support to highlights --- docs/source/highlights.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/source/highlights.md b/docs/source/highlights.md index fadb8df622..6571ae640d 100644 --- a/docs/source/highlights.md +++ b/docs/source/highlights.md @@ -47,9 +47,27 @@ affine = Affine( # convert the image using interpolation mode new_img = affine(image, spatial_size=(300, 400), mode='bilinear') ``` + ### 4. Randomly crop out batch images based on positive/negative ratio Medical image data volume may be too large to fit into GPU memory. A widely-used approach is to randomly draw small size data samples during training. MONAI currrently provides uniform random sampling strategy as well as class-balanced fixed ratio sampling which may help stabilize the patch-based training process. +### 5. Determinism training for reproducibility +Determinism training support is necessary and important in DL research area, especially when sharing reproducible work with others. Users can easily set random seed to all the transforms in MONAI and will not affect any other random logic in the user's program. +Example code: +```py +# define a transform chain for pre-processing +train_transforms = monai.transforms.compose.Compose([ + LoadNiftid(keys=['image', 'label']), + ... ... +]) +# set determinism for reproducibility +train_transforms.set_random_state(seed=0) +np.random.seed(0) +torch.manual_seed(0) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +``` + ## Losses There are domain-specific loss functions in the medical research area which are different from the generic computer vision ones. As an important module of MONAI, these loss functions are implemented in PyTorch, such as Dice loss and generalized Dice loss. From 5e5ba19bceb3084a6c7f2b5e47023e5a08462fc6 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 20 Mar 2020 20:51:41 +0800 Subject: [PATCH 3/6] [DLMED] add transforms to the test --- tests/integration_determinism.py | 35 +++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/tests/integration_determinism.py b/tests/integration_determinism.py index a1675da0cc..02cc742ba9 100644 --- a/tests/integration_determinism.py +++ b/tests/integration_determinism.py @@ -14,7 +14,8 @@ import numpy as np import torch from torch.utils.data import DataLoader, Dataset - +from monai.transforms import AddChannel, Rescale, UniformRandomPatch, RandRotate90 +import monai.transforms.compose as transforms from monai.data.synthetic import create_test_image_2d from monai.losses.dice import DiceLoss from monai.networks.nets.unet import UNet @@ -24,9 +25,17 @@ def run_test(batch_size=64, train_steps=100, device=torch.device("cuda:0")): class _TestBatch(Dataset): + def __init__(self, transforms): + self.transforms = transforms + def __getitem__(self, _unused_id): im, seg = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1) - return im[None], seg[None].astype(np.float32) + seed = np.random.randint(2147483647) + self.transforms.set_random_state(seed=seed) + im = self.transforms(im) + self.transforms.set_random_state(seed=seed) + seg = self.transforms(seg) + return im, seg def __len__(self): return train_steps @@ -41,8 +50,15 @@ def __len__(self): ).to(device) loss = DiceLoss(do_sigmoid=True) - opt = torch.optim.Adam(net.parameters(), 1e-4) - src = DataLoader(_TestBatch(), batch_size=batch_size) + opt = torch.optim.Adam(net.parameters(), 1e-2) + train_transforms = transforms.Compose([ + AddChannel(), + Rescale(), + UniformRandomPatch((96, 96)), + RandRotate90() + ]) + + src = DataLoader(_TestBatch(train_transforms), batch_size=batch_size) net.train() epoch_loss = 0 @@ -58,10 +74,10 @@ def __len__(self): epoch_loss /= step print('Loss:', epoch_loss) - delta = abs(epoch_loss - 0.70605) - if delta > 0.0001: - print('Loss value is wrong, expect to be 0.70605.') - return delta + result = np.allclose(epoch_loss, 0.578675) + if result is False: + print('Loss value is wrong, expect to be 0.578675.') + return result if __name__ == "__main__": @@ -69,5 +85,4 @@ def __len__(self): torch.manual_seed(0) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False - result = run_test() - sys.exit(1 if result > 0.0001 else 0) + sys.exit(0 if run_test() is True else 1) From adab6798c822055e98511247b0ca9282d7d41498 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 20 Mar 2020 20:59:46 +0800 Subject: [PATCH 4/6] [DLMED] update to RandUniformPatch --- docs/source/transforms.rst | 8 ++++---- examples/nifti_read_example.ipynb | 6 +++--- examples/segmentation_3d/unet_training_array.py | 6 +++--- examples/transform_speed.ipynb | 6 +++--- examples/unet_segmentation_3d.ipynb | 6 +++--- monai/transforms/composables.py | 4 ++-- monai/transforms/transforms.py | 2 +- tests/integration_determinism.py | 4 ++-- ...t_uniform_rand_patch.py => test_rand_uniform_patch.py} | 6 +++--- ...uniform_rand_patchd.py => test_rand_uniform_patchd.py} | 8 ++++---- 10 files changed, 28 insertions(+), 28 deletions(-) rename tests/{test_uniform_rand_patch.py => test_rand_uniform_patch.py} (82%) rename tests/{test_uniform_rand_patchd.py => test_rand_uniform_patchd.py} (79%) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 9c48b9f08b..de5bb45975 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -123,9 +123,9 @@ Vanilla Transforms :members: :special-members: __call__ -`UniformRandomPatch` +`RandUniformPatch` ~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: UniformRandomPatch +.. autoclass:: RandUniformPatch :members: :special-members: __call__ @@ -280,9 +280,9 @@ Dictionary-based Composables :members: :special-members: __call__ -`UniformRandomPatchd` +`RandUniformPatchd` ~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: UniformRandomPatchd +.. autoclass:: RandUniformPatchd :members: :special-members: __call__ diff --git a/examples/nifti_read_example.ipynb b/examples/nifti_read_example.ipynb index c83ae70429..f50f838156 100644 --- a/examples/nifti_read_example.ipynb +++ b/examples/nifti_read_example.ipynb @@ -41,7 +41,7 @@ "\n", "from monai.transforms.utils import rescale_array\n", "from monai.data.nifti_reader import NiftiDataset\n", - "from monai.transforms import AddChannel, Transpose, Rescale, ToTensor, UniformRandomPatch\n", + "from monai.transforms import AddChannel, Transpose, Rescale, ToTensor, RandUniformPatch\n", "from monai.data.grid_dataset import GridPatchDataset\n", "\n", "monai.config.print_config()" @@ -135,13 +135,13 @@ "imtrans=transforms.Compose([\n", " Rescale(),\n", " AddChannel(),\n", - " UniformRandomPatch((64, 64, 64)),\n", + " RandUniformPatch((64, 64, 64)),\n", " ToTensor()\n", "]) \n", "\n", "segtrans=transforms.Compose([\n", " AddChannel(),\n", - " UniformRandomPatch((64, 64, 64)),\n", + " RandUniformPatch((64, 64, 64)),\n", " ToTensor()\n", "]) \n", " \n", diff --git a/examples/segmentation_3d/unet_training_array.py b/examples/segmentation_3d/unet_training_array.py index 92e00ec13b..fd700b23f3 100644 --- a/examples/segmentation_3d/unet_training_array.py +++ b/examples/segmentation_3d/unet_training_array.py @@ -25,7 +25,7 @@ import monai.transforms.compose as transforms from monai.data.nifti_reader import NiftiDataset -from monai.transforms import AddChannel, Rescale, UniformRandomPatch, Resize +from monai.transforms import AddChannel, Rescale, RandUniformPatch, Resize from monai.handlers.stats_handler import StatsHandler from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler, TensorBoardImageHandler from monai.handlers.mean_dice import MeanDice @@ -55,11 +55,11 @@ train_imtrans = transforms.Compose([ Rescale(), AddChannel(), - UniformRandomPatch((96, 96, 96)) + RandUniformPatch((96, 96, 96)) ]) train_segtrans = transforms.Compose([ AddChannel(), - UniformRandomPatch((96, 96, 96)) + RandUniformPatch((96, 96, 96)) ]) val_imtrans = transforms.Compose([ Rescale(), diff --git a/examples/transform_speed.ipynb b/examples/transform_speed.ipynb index 2f4974d527..6d486d1850 100644 --- a/examples/transform_speed.ipynb +++ b/examples/transform_speed.ipynb @@ -50,7 +50,7 @@ "from monai.transforms.compose import Compose\n", "from monai.data.nifti_reader import NiftiDataset\n", "from monai.transforms import (AddChannel, Rescale, ToTensor, \n", - " UniformRandomPatch, Rotate, RandAffine)\n", + " RandUniformPatch, Rotate, RandAffine)\n", "\n", "monai.config.print_config()" ] @@ -180,14 +180,14 @@ " Rescale(),\n", " AddChannel(),\n", " Rotate(angle=45.),\n", - " UniformRandomPatch((64, 64, 64)),\n", + " RandUniformPatch((64, 64, 64)),\n", " ToTensor()\n", "]) \n", "\n", "segtrans = Compose([\n", " AddChannel(),\n", " Rotate(angle=45.),\n", - " UniformRandomPatch((64, 64, 64)),\n", + " RandUniformPatch((64, 64, 64)),\n", " ToTensor()\n", "]) \n", " \n", diff --git a/examples/unet_segmentation_3d.ipynb b/examples/unet_segmentation_3d.ipynb index 522250f419..00dceb150a 100644 --- a/examples/unet_segmentation_3d.ipynb +++ b/examples/unet_segmentation_3d.ipynb @@ -43,7 +43,7 @@ "import monai.transforms.compose as transforms\n", "\n", "from monai.data.nifti_reader import NiftiDataset\n", - "from monai.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch)\n", + "from monai.transforms import (AddChannel, Rescale, ToTensor, RandUniformPatch)\n", "from monai.handlers.stats_handler import StatsHandler\n", "from monai.handlers.mean_dice import MeanDice\n", "from monai.visualize import img2tensorboard\n", @@ -107,12 +107,12 @@ "imtrans = transforms.Compose([\n", " Rescale(), \n", " AddChannel(), \n", - " UniformRandomPatch((96, 96, 96)), \n", + " RandUniformPatch((96, 96, 96)), \n", " ToTensor()\n", "])\n", "segtrans = transforms.Compose([\n", " AddChannel(), \n", - " UniformRandomPatch((96, 96, 96)), \n", + " RandUniformPatch((96, 96, 96)), \n", " ToTensor()\n", "])\n", "\n", diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py index c84c81653b..6990df5f56 100644 --- a/monai/transforms/composables.py +++ b/monai/transforms/composables.py @@ -278,8 +278,8 @@ def __call__(self, data): @export -@alias('UniformRandomPatchD', 'UniformRandomPatchDict') -class UniformRandomPatchd(Randomizable, MapTransform): +@alias('RandUniformPatchD', 'RandUniformPatchDict') +class RandUniformPatchd(Randomizable, MapTransform): """ Selects a patch of the given size chosen at a uniformly random position in the image. diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index b1708c3e3a..f0ee4a9322 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -493,7 +493,7 @@ def __call__(self, img): @export -class UniformRandomPatch(Randomizable): +class RandUniformPatch(Randomizable): """ Selects a patch of the given size chosen at a uniformly random position in the image. diff --git a/tests/integration_determinism.py b/tests/integration_determinism.py index 02cc742ba9..5e31d10fab 100644 --- a/tests/integration_determinism.py +++ b/tests/integration_determinism.py @@ -14,7 +14,7 @@ import numpy as np import torch from torch.utils.data import DataLoader, Dataset -from monai.transforms import AddChannel, Rescale, UniformRandomPatch, RandRotate90 +from monai.transforms import AddChannel, Rescale, RandUniformPatch, RandRotate90 import monai.transforms.compose as transforms from monai.data.synthetic import create_test_image_2d from monai.losses.dice import DiceLoss @@ -54,7 +54,7 @@ def __len__(self): train_transforms = transforms.Compose([ AddChannel(), Rescale(), - UniformRandomPatch((96, 96)), + RandUniformPatch((96, 96)), RandRotate90() ]) diff --git a/tests/test_uniform_rand_patch.py b/tests/test_rand_uniform_patch.py similarity index 82% rename from tests/test_uniform_rand_patch.py rename to tests/test_rand_uniform_patch.py index 32b1077ef7..923f302c97 100644 --- a/tests/test_uniform_rand_patch.py +++ b/tests/test_rand_uniform_patch.py @@ -13,15 +13,15 @@ import numpy as np -from monai.transforms.transforms import UniformRandomPatch +from monai.transforms.transforms import RandUniformPatch from tests.utils import NumpyImageTestCase2D -class TestUniformRandomPatch(NumpyImageTestCase2D): +class TestRandUniformPatch(NumpyImageTestCase2D): def test_2d(self): patch_spatial_size = (10, 10) - patch_transform = UniformRandomPatch(patch_spatial_size=patch_spatial_size) + patch_transform = RandUniformPatch(patch_spatial_size=patch_spatial_size) patch = patch_transform(self.imt[0]) self.assertTrue(np.allclose(patch.shape[1:], patch_spatial_size)) diff --git a/tests/test_uniform_rand_patchd.py b/tests/test_rand_uniform_patchd.py similarity index 79% rename from tests/test_uniform_rand_patchd.py rename to tests/test_rand_uniform_patchd.py index a87438acce..b68d555178 100644 --- a/tests/test_uniform_rand_patchd.py +++ b/tests/test_rand_uniform_patchd.py @@ -13,16 +13,16 @@ import numpy as np -from monai.transforms.composables import UniformRandomPatchd +from monai.transforms.composables import RandUniformPatchd from tests.utils import NumpyImageTestCase2D -class TestUniformRandomPatchd(NumpyImageTestCase2D): +class TestRandUniformPatchd(NumpyImageTestCase2D): def test_2d(self): patch_spatial_size = (10, 10) key = 'test' - patch_transform = UniformRandomPatchd(keys='test', patch_spatial_size=patch_spatial_size) + patch_transform = RandUniformPatchd(keys='test', patch_spatial_size=patch_spatial_size) patch = patch_transform({key: self.imt[0]}) self.assertTrue(np.allclose(patch[key].shape[1:], patch_spatial_size)) @@ -30,7 +30,7 @@ def test_sync(self): patch_spatial_size = (4, 4) key_1, key_2 = 'foo', 'bar' rand_image = np.random.rand(3, 10, 10) - patch_transform = UniformRandomPatchd(keys=(key_1, key_2), patch_spatial_size=patch_spatial_size) + patch_transform = RandUniformPatchd(keys=(key_1, key_2), patch_spatial_size=patch_spatial_size) patch = patch_transform({key_1: rand_image, key_2: rand_image}) self.assertTrue(np.allclose(patch[key_1], patch[key_2])) From 0a0d8f0eef28c25c2a3b214bcd81a14dba593399 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 20 Mar 2020 15:10:19 +0000 Subject: [PATCH 5/6] fixes typos --- docs/source/highlights.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/highlights.md b/docs/source/highlights.md index 6571ae640d..f7895bf65d 100644 --- a/docs/source/highlights.md +++ b/docs/source/highlights.md @@ -51,9 +51,9 @@ new_img = affine(image, spatial_size=(300, 400), mode='bilinear') ### 4. Randomly crop out batch images based on positive/negative ratio Medical image data volume may be too large to fit into GPU memory. A widely-used approach is to randomly draw small size data samples during training. MONAI currrently provides uniform random sampling strategy as well as class-balanced fixed ratio sampling which may help stabilize the patch-based training process. -### 5. Determinism training for reproducibility -Determinism training support is necessary and important in DL research area, especially when sharing reproducible work with others. Users can easily set random seed to all the transforms in MONAI and will not affect any other random logic in the user's program. -Example code: +### 5. Deterministic training for reproducibility +Deterministic training support is necessary and important in DL research area, especially when sharing reproducible work with others. Users can easily set random seed to all the transforms in MONAI locally and will not affect other non-deterministic modules in the user's program. +Example code: ```py # define a transform chain for pre-processing train_transforms = monai.transforms.compose.Compose([ From 1e5057bbc7630726717c022210fd633aff518368 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 20 Mar 2020 17:42:27 +0000 Subject: [PATCH 6/6] fixes empty indicies issue in balanced sampling --- monai/transforms/utils.py | 13 ++++++++---- tests/test_rand_crop_by_pos_neg_labeld.py | 24 ++++++++++++++++++++++- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index b8ce70d854..f7a2f24501 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings import random import numpy as np @@ -202,12 +203,16 @@ def generate_pos_neg_label_crop_centers(label, size, num_samples, pos_ratio, ima else: bg_indicies = np.nonzero(~label_flat)[0] + if not len(fg_indicies) or not len(bg_indicies): + if not len(fg_indicies) and not len(bg_indicies): + raise ValueError('no sampling location available.') + warnings.warn('N foreground {}, N background {}, unable to generate class balanced samples.'.format( + len(fg_indicies), len(bg_indicies))) + pos_ratio = 0 if not len(fg_indicies) else 1 + centers = [] for _ in range(num_samples): - if rand_state.rand() < pos_ratio: - indicies_to_use = fg_indicies - else: - indicies_to_use = bg_indicies + indicies_to_use = fg_indicies if rand_state.rand() < pos_ratio else bg_indicies random_int = rand_state.randint(len(indicies_to_use)) center = np.unravel_index(indicies_to_use[random_int], label.shape) center = center[1:] diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index 2292dfed9f..f83d737873 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -36,10 +36,32 @@ (3, 2, 2, 2), ] +TEST_CASE_2 = [ + { + 'keys': ['image', 'extral', 'label'], + 'label_key': 'label', + 'size': [2, 2, 2], + 'pos': 1, + 'neg': 1, + 'num_samples': 2, + 'image_key': None, + 'image_threshold': 0 + }, + { + 'image': np.zeros([3, 3, 3, 3]) - 1, + 'extral': np.zeros([3, 3, 3, 3]), + 'label': np.ones([3, 3, 3, 3]), + 'affine': np.eye(3), + 'shape': 'CHWD' + }, + list, + (3, 2, 2, 2), +] + class TestRandCropByPosNegLabeld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_type_shape(self, input_param, input_data, expected_type, expected_shape): result = RandCropByPosNegLabeld(**input_param)(input_data) self.assertIsInstance(result, expected_type)