diff --git a/examples/unet_segmentation_3d.ipynb b/examples/unet_segmentation_3d.ipynb index 6af33fe27b..3a87b3224e 100644 --- a/examples/unet_segmentation_3d.ipynb +++ b/examples/unet_segmentation_3d.ipynb @@ -151,7 +151,7 @@ " num_res_units=2,\n", ")\n", "\n", - "loss = networks.losses.DiceLoss()\n", + "loss = networks.losses.DiceLoss(do_sigmoid=True)\n", "opt = torch.optim.Adam(net.parameters(), lr)" ] }, diff --git a/examples/unet_segmentation_3d.py b/examples/unet_segmentation_3d.py index b1f921e6b5..364cdaab7f 100644 --- a/examples/unet_segmentation_3d.py +++ b/examples/unet_segmentation_3d.py @@ -93,7 +93,7 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_ma num_res_units=2, ) -loss = networks.losses.DiceLoss() +loss = networks.losses.DiceLoss(do_sigmoid=True) opt = torch.optim.Adam(net.parameters(), lr) train_epochs = 30 @@ -108,7 +108,7 @@ def _loss_fn(i, j): trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item()]) -checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, save_interval=3, require_empty=False) +checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net}) diff --git a/monai/networks/losses/dice.py b/monai/networks/losses/dice.py index c193dbd3f8..c98f503034 100644 --- a/monai/networks/losses/dice.py +++ b/monai/networks/losses/dice.py @@ -49,26 +49,25 @@ def forward(self, pred, ground, smooth=1e-5): psum = pred.float() if self.do_sigmoid: - psum = psum.sigmoid() - if pred.shape[1] == 1: # binary dice loss, use sigmoid activation + psum = psum.sigmoid() # use sigmoid activation + if pred.shape[1] == 1: if self.do_softmax: raise ValueError('do_softmax is not compatible with single channel prediction.') if not self.include_background: - raise RuntimeWarning('single channel ground truth, `include_background=False` ignored.') + raise RuntimeWarning('single channel prediction, `include_background=False` ignored.') tsum = ground - else: + else: # multiclass dice loss if self.do_softmax: if self.do_sigmoid: raise ValueError('do_sigmoid=True and do_softmax=Ture are not compatible.') - # multiclass dice loss, use softmax in the first dimension and convert target to one-hot encoding psum = torch.softmax(pred, 1) tsum = one_hot(ground, pred.shape[1]) # B1HW(D) -> BNHW(D) # exclude background category so that it doesn't overwhelm the other segmentations if they are small if not self.include_background: tsum = tsum[:, 1:] psum = psum[:, 1:] - assert tsum.shape == pred.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" % - (tsum.shape, pred.shape)) + assert tsum.shape == psum.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" % + (tsum.shape, psum.shape)) batchsize = ground.size(0) tsum = tsum.float().view(batchsize, -1) diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index bbaa49c8ac..dac0039945 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -44,7 +44,7 @@ TEST_CASE_3 = [ # shape: (2, 2, 3), (2, 1, 3) { - 'include_background': True, + 'include_background': False, }, { 'pred': torch.tensor([[[1., 1., 0.], [0., 0., 1.]], [[1., 0., 1.], [0., 1., 0.]]]), @@ -80,6 +80,7 @@ 0.373045, ] + class TestDiceLoss(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])