diff --git a/examples/unet_segmentation_3d.ipynb b/examples/unet_segmentation_3d.ipynb index a5e614199a..b1dc49a8b1 100644 --- a/examples/unet_segmentation_3d.ipynb +++ b/examples/unet_segmentation_3d.ipynb @@ -2,19 +2,13 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", - "text": [ - "MONAI version: 0.0.1\n", - "Python version: 3.7.3 (default, Mar 27 2019, 22:11:17) [GCC 7.3.0]\n", - "Numpy version: 1.16.4\n", - "Pytorch version: 1.3.1\n", - "Ignite version: 0.2.1\n" - ] + "text": "MONAI version: 0.0.1\nPython version: 3.8.1 (default, Jan 8 2020, 22:29:32) [GCC 7.3.0]\nNumpy version: 1.18.1\nPytorch version: 1.4.0\nIgnite version: 0.3.0\n" } ], "source": [ @@ -28,21 +22,24 @@ "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader\n", - "import monai.data.transforms.compose as transforms\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import nibabel as nib\n", "\n", - "from ignite.engine import Events, create_supervised_trainer\n", - "from ignite.handlers import ModelCheckpoint\n", + "from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator\n", + "from ignite.handlers import ModelCheckpoint, EarlyStopping\n", "\n", "# assumes the framework is found here, change as necessary\n", "sys.path.append(\"..\")\n", "\n", + "\n", + "import monai.data.transforms.compose as transforms\n", "from monai import application, data, networks, utils\n", "from monai.data.readers import NiftiDataset\n", "from monai.data.transforms import AddChannel, Transpose, Rescale, ToTensor, UniformRandomPatch, GridPatchDataset\n", + "from monai.networks.metrics.mean_dice import MeanDice\n", + "from monai.utils.stopperutils import stopping_fn_from_metric\n", "\n", "\n", "application.config.print_config()" @@ -50,7 +47,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -81,7 +78,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -99,15 +96,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", - "text": [ - "torch.Size([10, 1, 64, 64, 64]) torch.Size([10, 1, 64, 64, 64])\n" - ] + "text": "torch.Size([10, 1, 64, 64, 64]) torch.Size([10, 1, 64, 64, 64])\n" } ], "source": [ @@ -136,7 +131,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -157,46 +152,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 14, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1 Loss: 0.8619852662086487\n", - "Epoch 2 Loss: 0.8307779431343079\n", - "Epoch 3 Loss: 0.8064168691635132\n", - "Epoch 4 Loss: 0.7981672883033752\n", - "Epoch 5 Loss: 0.7950631976127625\n", - "Epoch 6 Loss: 0.7949732542037964\n", - "Epoch 7 Loss: 0.7963427901268005\n", - "Epoch 8 Loss: 0.7939450144767761\n", - "Epoch 9 Loss: 0.7926643490791321\n", - "Epoch 10 Loss: 0.7911991477012634\n", - "Epoch 11 Loss: 0.7886414527893066\n", - "Epoch 12 Loss: 0.7867528796195984\n", - "Epoch 13 Loss: 0.7857398390769958\n", - "Epoch 14 Loss: 0.7833380699157715\n", - "Epoch 15 Loss: 0.7791398763656616\n", - "Epoch 16 Loss: 0.7720394730567932\n", - "Epoch 17 Loss: 0.7671006917953491\n", - "Epoch 18 Loss: 0.7646064758300781\n", - "Epoch 19 Loss: 0.7672612071037292\n", - "Epoch 20 Loss: 0.7600041627883911\n", - "Epoch 21 Loss: 0.7583478689193726\n", - "Epoch 22 Loss: 0.7571365833282471\n", - "Epoch 23 Loss: 0.7545363306999207\n", - "Epoch 24 Loss: 0.7499511241912842\n", - "Epoch 25 Loss: 0.7481640577316284\n", - "Epoch 26 Loss: 0.7469437122344971\n", - "Epoch 27 Loss: 0.7460543513298035\n", - "Epoch 28 Loss: 0.74577796459198\n", - "Epoch 29 Loss: 0.7429620027542114\n", - "Epoch 30 Loss: 0.7424858808517456\n" - ] - } - ], + "outputs": [], "source": [ "trainEpochs = 30\n", "\n", @@ -218,16 +176,60 @@ "\n", "\n", "loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())\n", - " \n", + "val_loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())\n", + " \n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "validation_every_n_epochs = 1\n", + "\n", + "val_metrics = {'Mean Dice': MeanDice(add_sigmoid=True)}\n", + "evaluator = create_supervised_evaluator(net, val_metrics, device, True,\n", + " output_transform=lambda x, y, y_pred: (y_pred[0], y))\n", + "\n", + "\n", + "early_stopper = EarlyStopping(patience=4, \n", + " score_function=stopping_fn_from_metric('Mean Dice'),\n", + " trainer=trainer)\n", + "evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)\n", + "\n", + "@evaluator.on(Events.EPOCH_COMPLETED)\n", + "def log_validation_metrics(engine):\n", + " for name, value in engine.state.metrics.items():\n", + " print(\"Validation --\", name, \":\", value)\n", + "\n", + "@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))\n", + "def run_validation(engine):\n", + " evaluator.run(val_loader)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": "Epoch 1 Loss: 0.8975554704666138\nValidation -- Mean Dice : 0.11846490800380707\nEpoch 2 Loss: 0.8451039791107178\nValidation -- Mean Dice : 0.12091563045978546\nEpoch 3 Loss: 0.9355515241622925\nValidation -- Mean Dice : 0.12139833569526673\nEpoch 4 Loss: 0.843208909034729\nValidation -- Mean Dice : 0.12108306288719177\nEpoch 5 Loss: 0.8225834965705872\nValidation -- Mean Dice : 0.12179622799158096\nEpoch 6 Loss: 0.957372784614563\nValidation -- Mean Dice : 0.12193384170532226\nEpoch 7 Loss: 0.9011092782020569\nValidation -- Mean Dice : 0.1230143740773201\nEpoch 8 Loss: 0.8651387691497803\nValidation -- Mean Dice : 0.1254110112786293\nEpoch 9 Loss: 0.8767974972724915\nValidation -- Mean Dice : 0.12633273899555206\nEpoch 10 Loss: 0.8193061947822571\nValidation -- Mean Dice : 0.12657881826162337\nEpoch 11 Loss: 0.9466649293899536\nValidation -- Mean Dice : 0.12699378579854964\nEpoch 12 Loss: 0.8258659243583679\nValidation -- Mean Dice : 0.12790720015764237\nEpoch 13 Loss: 0.8661612868309021\nValidation -- Mean Dice : 0.12980296313762665\nEpoch 14 Loss: 0.8039132356643677\nValidation -- Mean Dice : 0.1311295285820961\nEpoch 15 Loss: 0.8050084114074707\nValidation -- Mean Dice : 0.13225494623184203\nEpoch 16 Loss: 0.9048625230789185\nValidation -- Mean Dice : 0.1330576255917549\nEpoch 17 Loss: 0.9179995656013489\nValidation -- Mean Dice : 0.13361359685659407\nEpoch 18 Loss: 0.8956605195999146\nValidation -- Mean Dice : 0.13432369381189346\nEpoch 19 Loss: 0.8029189705848694\nValidation -- Mean Dice : 0.13532216250896453\nEpoch 20 Loss: 0.8359838128089905\nValidation -- Mean Dice : 0.13622953295707702\nEpoch 21 Loss: 0.9225850105285645\nValidation -- Mean Dice : 0.13677610754966735\nEpoch 22 Loss: 0.7023072242736816\nValidation -- Mean Dice : 0.13693425357341765\nEpoch 23 Loss: 0.8776397705078125\nValidation -- Mean Dice : 0.13710424304008484\nEpoch 24 Loss: 0.9571539163589478\nValidation -- Mean Dice : 0.1370883911848068\nEpoch 25 Loss: 0.8877002596855164\nValidation -- Mean Dice : 0.13701471388339997\nEpoch 26 Loss: 0.817417562007904\nValidation -- Mean Dice : 0.13696834743022918\nEpoch 27 Loss: 0.8971314430236816\nValidation -- Mean Dice : 0.1371448516845703\nEpoch 28 Loss: 0.9443905353546143\nValidation -- Mean Dice : 0.13739778995513915\nEpoch 29 Loss: 0.7578094005584717\nValidation -- Mean Dice : 0.137495020031929\nEpoch 30 Loss: 0.7037953734397888\nValidation -- Mean Dice : 0.13759489357471466\n" + } + ], + "source": [ "state = trainer.run(loader, trainEpochs)" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3.7.5 64-bit ('pytorch': conda)", "language": "python", - "name": "python3" + "name": "python37564bitpytorchconda9e7dd2186ac2430b947ee08d8eff35b4" }, "language_info": { "codemirror_mode": { @@ -239,9 +241,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3" + "version": "3.8.1-final" } }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/examples/unet_segmentation_3d.py b/examples/unet_segmentation_3d.py index 86fcaef5e7..edf332f7ee 100644 --- a/examples/unet_segmentation_3d.py +++ b/examples/unet_segmentation_3d.py @@ -20,8 +20,8 @@ import torch import monai.transforms.compose as transforms from torch.utils.tensorboard import SummaryWriter -from ignite.engine import Events, create_supervised_trainer -from ignite.handlers import ModelCheckpoint +from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator +from ignite.handlers import ModelCheckpoint, EarlyStopping from torch.utils.data import DataLoader import monai @@ -32,12 +32,14 @@ from monai.handlers.mean_dice import MeanDice from monai.visualize import img2tensorboard from monai.data.synthetic import create_test_image_3d +from monai.handlers.utils import stopping_fn_from_metric # assumes the framework is found here, change as necessary sys.path.append("..") config.print_config() +# Create a temporary directory and 50 random image, mask paris tempdir = tempfile.mkdtemp() for i in range(50): @@ -52,18 +54,20 @@ images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) +# Define transforms for image and segmentation imtrans = transforms.Compose([Rescale(), AddChannel(), UniformRandomPatch((64, 64, 64)), ToTensor()]) - segtrans = transforms.Compose([AddChannel(), UniformRandomPatch((64, 64, 64)), ToTensor()]) +# Define nifti dataset, dataloader. ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans) - loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) im, seg = monai.utils.misc.first(loader) print(im.shape, seg.shape) + lr = 1e-3 +# Create UNet, DiceLoss and Adam optimizer. net = monai.networks.nets.UNet( dimensions=3, in_channels=1, @@ -78,13 +82,12 @@ train_epochs = 3 - +# Since network outputs logits and segmentation, we need a custom function. def _loss_fn(i, j): return loss(i[0], j) - +# Create trainer device = torch.device("cuda:0") - trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y]) @@ -133,6 +136,28 @@ def log_training_loss(engine): loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available()) +val_loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available()) writer = SummaryWriter() +# Define mean dice metric and Evaluator. +validation_every_n_epochs = 1 + +val_metrics = {'Mean Dice': MeanDice(add_sigmoid=True)} +evaluator = create_supervised_evaluator(net, val_metrics, device, True, + output_transform=lambda x, y, y_pred: (y_pred[0], y)) + +val_stats_handler = StatsHandler() +val_stats_handler.attach(evaluator) + +# Add early stopping handler to evaluator. +early_stopper = EarlyStopping(patience=4, + score_function=stopping_fn_from_metric('Mean Dice'), + trainer=trainer) +evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) + +@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) +def run_validation(engine): + evaluator.run(val_loader) + + state = trainer.run(loader, train_epochs) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py new file mode 100644 index 0000000000..377d4d0073 --- /dev/null +++ b/monai/handlers/utils.py @@ -0,0 +1,24 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def stopping_fn_from_metric(metric_name): + """Returns a stopping function for ignite.handlers.EarlyStopping using the given metric name.""" + def stopping_fn(engine): + return engine.state.metrics[metric_name] + return stopping_fn + + +def stopping_fn_from_loss(): + """Returns a stopping function for ignite.handlers.EarlyStopping using the loss value.""" + def stopping_fn(engine): + return -engine.state.output + return stopping_fn