Description
I am training a 2D UNet to segment fetal MR images using MONAI, and I have been observing some instability in training when using the MONAI Dice loss formulation. After some iterations, the loss jumps up and the network stops learning, as the gradients drop to zero. Here is an example (orange: loss on the training set, computed over 2D slices; blue: loss on the validation set, computed over 3D volumes).

After investigating several aspects (using the same deterministic seed), I've narrowed down the issue to the presence of the smooth term in both the numerator and denominator of the Dice Loss:
f = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)
When using the formulation:
f = 1.0 - (2.0 * intersection) / (denominator + smooth)
without the smooth term in the numerator, the training was stable and no longer showed unexpected behaviour:

[Note: this experiment was trained for much longer to make sure the jump would not appear later in the training]
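For reference, this is a minimal sketch of the two formulations for a binary target (not the exact MONAI implementation; the smooth_in_numerator flag is a hypothetical switch added purely for illustration):

```python
import torch

def binary_dice_loss(pred, target, smooth=1e-5, smooth_in_numerator=True):
    # pred and target: (batch, ...) tensors with values in [0, 1]
    pred = pred.flatten(1)
    target = target.flatten(1)
    intersection = (pred * target).sum(dim=1)
    denominator = pred.sum(dim=1) + target.sum(dim=1)
    if smooth_in_numerator:
        # current behaviour: smooth term in both numerator and denominator
        dice = (2.0 * intersection + smooth) / (denominator + smooth)
    else:
        # proposed variant: smooth term only in the denominator
        dice = (2.0 * intersection) / (denominator + smooth)
    return 1.0 - dice.mean()
```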
The same pattern was also observed for the Tversky loss, so it could be worth investigating the stability of these losses to identify the best default formulation.
Software version
MONAI version: 0.1.0+84.ga683c4e.dirty
Python version: 3.7.4 (default, Jul 9 2019, 03:52:42) [GCC 5.4.0 20160609]
Numpy version: 1.18.2
Pytorch version: 1.4.0
Ignite version: 0.3.0
Training information
- Using MONAI PersistentCache
- 2D UNet (as default in MONAI)
- Adam optimiser, LR = 1e-3, no LR decay
- Batch size: 10
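For completeness, a rough sketch of this setup (the UNet channels/strides and the tensor shapes are assumptions, and the argument names follow more recent MONAI releases, so they may differ from the 0.1.0 API):

```python
import torch
from monai.losses import DiceLoss
from monai.networks.nets import UNet

# 2D UNet roughly following the MONAI defaults; channels/strides are illustrative
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
)
loss_fn = DiceLoss(sigmoid=True)  # keyword may be named differently in MONAI 0.1.0
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # no LR decay

# one training step on a batch of 10 slices (shapes are illustrative)
images = torch.rand(10, 1, 256, 256)
labels = (torch.rand(10, 1, 256, 256) > 0.5).float()
optimizer.zero_grad()
loss = loss_fn(model(images), labels)
loss.backward()
optimizer.step()
```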
Other tests
The following aspects were investigated but did not solve the instability issue:
- Gradient clipping
- Different optimisers (SGD, SGD + Momentum)
- Transforming the binary segmentations to a two-channel approach ([background segmentation, foreground segmentation]); see the sketch after this list
- Choosing smooth = 1.0 as the default, as done in nnU-Net (https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/training/loss_functions/dice_loss.py). However, this made the behaviour even more severe and the jump happened sooner in the training.
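As an aside, this is roughly what the two-channel conversion mentioned above looks like (a minimal sketch, not the exact transform used in the experiments):

```python
import torch

def to_two_channels(binary_mask):
    # binary_mask: (batch, 1, H, W) tensor with values in {0, 1}
    foreground = binary_mask.float()
    background = 1.0 - foreground
    # channel 0 = background, channel 1 = foreground
    return torch.cat([background, foreground], dim=1)

mask = (torch.rand(10, 1, 256, 256) > 0.5).float()
two_channel = to_two_channels(mask)  # shape (10, 2, 256, 256)
```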
The following losses were also investigated (a sketch of the combined Dice + BCE loss follows this list):
- Binary Cross Entropy --> stable
- Dice Loss + Binary Cross Entropy --> unstable
- Dice Loss (no smooth at numerator) + Binary Cross Entropy --> stable
- Tversky Loss --> unstable
- Tversky Loss (no smooth at numerator) --> stable
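For reference, a sketch of the stable Dice + BCE combination from the list above (soft Dice without the smooth term in the numerator, plus binary cross entropy on the logits; the equal weighting of the two terms is an assumption):

```python
import torch
import torch.nn.functional as F

def dice_bce_loss(logits, target, smooth=1e-5):
    # soft Dice without the smooth term in the numerator (the stable variant)
    probs = torch.sigmoid(logits).flatten(1)
    target_flat = target.flatten(1)
    intersection = (probs * target_flat).sum(dim=1)
    denominator = probs.sum(dim=1) + target_flat.sum(dim=1)
    dice = 1.0 - ((2.0 * intersection) / (denominator + smooth)).mean()
    # binary cross entropy computed on the raw logits for numerical stability
    bce = F.binary_cross_entropy_with_logits(logits, target)
    return dice + bce
```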