Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/unet_segmentation_3d.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
" num_res_units=2,\n",
")\n",
"\n",
"loss = networks.losses.DiceLoss()\n",
"loss = networks.losses.DiceLoss(do_sigmoid=True)\n",
"opt = torch.optim.Adam(net.parameters(), lr)"
]
},
Expand Down
4 changes: 2 additions & 2 deletions examples/unet_segmentation_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_ma
num_res_units=2,
)

loss = networks.losses.DiceLoss()
loss = networks.losses.DiceLoss(do_sigmoid=True)
opt = torch.optim.Adam(net.parameters(), lr)

train_epochs = 30
Expand All @@ -108,7 +108,7 @@ def _loss_fn(i, j):
trainer = create_supervised_trainer(net, opt, _loss_fn, device, False,
output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item()])

checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, save_interval=3, require_empty=False)
checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, require_empty=False)
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net})


Expand Down
13 changes: 6 additions & 7 deletions monai/networks/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,26 +49,25 @@ def forward(self, pred, ground, smooth=1e-5):

psum = pred.float()
if self.do_sigmoid:
psum = psum.sigmoid()
if pred.shape[1] == 1: # binary dice loss, use sigmoid activation
psum = psum.sigmoid() # use sigmoid activation
if pred.shape[1] == 1:
if self.do_softmax:
raise ValueError('do_softmax is not compatible with single channel prediction.')
if not self.include_background:
raise RuntimeWarning('single channel ground truth, `include_background=False` ignored.')
raise RuntimeWarning('single channel prediction, `include_background=False` ignored.')
tsum = ground
else:
else: # multiclass dice loss
if self.do_softmax:
if self.do_sigmoid:
raise ValueError('do_sigmoid=True and do_softmax=Ture are not compatible.')
# multiclass dice loss, use softmax in the first dimension and convert target to one-hot encoding
psum = torch.softmax(pred, 1)
tsum = one_hot(ground, pred.shape[1]) # B1HW(D) -> BNHW(D)
# exclude background category so that it doesn't overwhelm the other segmentations if they are small
if not self.include_background:
tsum = tsum[:, 1:]
psum = psum[:, 1:]
assert tsum.shape == pred.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" %
(tsum.shape, pred.shape))
assert tsum.shape == psum.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" %
(tsum.shape, psum.shape))

batchsize = ground.size(0)
tsum = tsum.float().view(batchsize, -1)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_dice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

TEST_CASE_3 = [ # shape: (2, 2, 3), (2, 1, 3)
{
'include_background': True,
'include_background': False,
},
{
'pred': torch.tensor([[[1., 1., 0.], [0., 0., 1.]], [[1., 0., 1.], [0., 1., 0.]]]),
Expand Down Expand Up @@ -80,6 +80,7 @@
0.373045,
]


class TestDiceLoss(unittest.TestCase):

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
Expand Down