From ac00686ea58f5f8f7a19b02fd7f4d371cf3c188f Mon Sep 17 00:00:00 2001
From: Eric Kerfoot
Date: Fri, 31 Jan 2020 15:40:12 +0000
Subject: [PATCH 1/6] Adding adaptation to use the default Ignite supervised
 training function with multiple GPUs using data parallelism

---
 examples/multi_gpu_test.ipynb                 | 140 ++++++++++++++++++
 examples/unet_segmentation_3d.ipynb           |   1 -
 .../engine/multi_gpu_supervised_trainer.py    |  95 ++++++++++++
 3 files changed, 235 insertions(+), 1 deletion(-)
 create mode 100644 examples/multi_gpu_test.ipynb
 create mode 100644 monai/application/engine/multi_gpu_supervised_trainer.py

diff --git a/examples/multi_gpu_test.ipynb b/examples/multi_gpu_test.ipynb
new file mode 100644
index 0000000000..8ceb30af67
--- /dev/null
+++ b/examples/multi_gpu_test.ipynb
@@ -0,0 +1,140 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "MONAI version: 0.0.1\n",
+      "Python version: 3.7.3 (default, Mar 27 2019, 22:11:17) [GCC 7.3.0]\n",
+      "Numpy version: 1.16.4\n",
+      "Pytorch version: 1.4.0\n",
+      "Ignite version: 0.3.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sys\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from torch.utils.data import DataLoader\n",
+    "import torchvision.transforms as transforms\n",
+    "\n",
+    "import numpy as np\n",
+    "\n",
+    "from ignite.engine import create_supervised_trainer\n",
+    "from ignite.engine.engine import Events\n",
+    "from ignite.handlers import ModelCheckpoint\n",
+    "\n",
+    "# assumes the framework is found here, change as necessary\n",
+    "sys.path.append(\"..\")\n",
+    "\n",
+    "from monai import application, data, networks, utils\n",
+    "\n",
+    "\n",
+    "application.config.print_config()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/localek10/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py:26: UserWarning: \n",
+      "    There is an imbalance between your GPUs. You may want to exclude GPU 1 which\n",
+      "    has less than 75% of the memory or cores of GPU 0. You can do so by setting\n",
+      "    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES\n",
+      "    environment variable.\n",
+      "  warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "State:\n",
+       "\titeration: 4\n",
+       "\tepoch: 2\n",
+       "\tepoch_length: 2\n",
+       "\tmax_epochs: 2\n",
+       "\toutput: 20912.578125\n",
+       "\tbatch: \n",
+       "\tmetrics: \n",
+       "\tdataloader: \n",
+       "\tseed: 12"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "lr = 1e-3\n",
+    "\n",
+    "net = networks.nets.UNet(\n",
+    "    dimensions=2,\n",
+    "    in_channels=1,\n",
+    "    num_classes=1,\n",
+    "    channels=(16, 32, 64, 128, 256),\n",
+    "    strides=(2, 2, 2, 2),\n",
+    "    num_res_units=2,\n",
+    ")\n",
+    "\n",
+    "\n",
+    "def fake_loss(y_pred,y):\n",
+    "    return (y_pred[0]+y).sum()\n",
+    "\n",
+    "\n",
+    "def fake_data_stream():\n",
+    "    while True:\n",
+    "        yield torch.rand((10,1,64,64)),torch.rand((10,1,64,64))\n",
+    "    \n",
+    "    \n",
+    "# 1 GPU\n",
+    "opt = torch.optim.Adam(net.parameters(), lr)\n",
+    "trainer=application.engine.create_multigpu_supervised_trainer(net,opt,fake_loss,[torch.device('cuda:0')])\n",
+    "trainer.run(fake_data_stream(),2,2)\n",
+    "\n",
+    "# all GPUs\n",
+    "opt = torch.optim.Adam(net.parameters(), lr)\n",
+    "trainer=application.engine.create_multigpu_supervised_trainer(net,opt,fake_loss,None)\n",
+    "trainer.run(fake_data_stream(),2,2)\n",
+    "\n",
+    "# CPU\n",
+    "opt = torch.optim.Adam(net.parameters(), lr)\n",
+    "trainer=application.engine.create_multigpu_supervised_trainer(net,opt,fake_loss,[])\n",
+    "trainer.run(fake_data_stream(),2,2)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/examples/unet_segmentation_3d.ipynb b/examples/unet_segmentation_3d.ipynb
index 81ee6ce8b2..6af33fe27b 100644
--- a/examples/unet_segmentation_3d.ipynb
+++ b/examples/unet_segmentation_3d.ipynb
@@ -24,7 +24,6 @@
     "import sys\n",
     "import tempfile\n",
     "from glob import glob\n",
-    "from functools import partial\n",
     "\n",
     "import torch\n",
     "import torch.nn as nn\n",
diff --git a/monai/application/engine/multi_gpu_supervised_trainer.py b/monai/application/engine/multi_gpu_supervised_trainer.py
new file mode 100644
index 0000000000..912d4cc9ad
--- /dev/null
+++ b/monai/application/engine/multi_gpu_supervised_trainer.py
@@ -0,0 +1,95 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+from ignite.engine import create_supervised_trainer, create_supervised_evaluator, _prepare_batch
+
+import monai
+
+
+def _default_transform(x, y, y_pred, loss):
+    return loss.item()
+
+def _default_eval_transform(x, y, y_pred):
+    return y_pred, y
+
+@monai.utils.export("monai.application.engine")
+def create_multigpu_supervised_trainer(net, optimizer, loss_fn, devices=None, non_blocking=False,
+                                       prepare_batch=_prepare_batch, output_transform=_default_transform):
+    """
+    Derived from `create_supervised_trainer` in Ignite.
+
+    Factory function for creating a trainer for supervised models.
+    Args:
+        net (`torch.nn.Module`): the network to train.
+        optimizer (`torch.optim.Optimizer`): the optimizer to use.
+        loss_fn (torch.nn loss function): the loss function to use.
+        devices (list, optional): device(s) to use (default: None). Applies to both the model and the batches.
+            None means all available GPU devices are used; an empty list means CPU only.
+        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
+            with respect to the host. For other cases, this argument has no effect.
+        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
+            tuple of tensors `(batch_x, batch_y)`.
+        output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
+            to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
+    Note: `engine.state.output` for this engine is defined by the `output_transform` parameter and is the loss
+        of the processed batch by default.
+    Returns:
+        Engine: a trainer engine with supervised update function.
+    """
+
+    if devices is None:
+        devices = [torch.device('cuda:%i' % d) for d in range(torch.cuda.device_count())]
+    elif len(devices) == 0:
+        devices = [torch.device("cpu")]
+
+    if len(devices) > 1:
+        net = torch.nn.parallel.DataParallel(net)
+
+    return create_supervised_trainer(net, optimizer, loss_fn, devices[0], non_blocking, prepare_batch, output_transform)
+
+
+@monai.utils.export("monai.application.engine")
+def create_multigpu_supervised_evaluator(net, metrics=None, devices=None, non_blocking=False,
+                                         prepare_batch=_prepare_batch, output_transform=_default_eval_transform):
+    """
+    Derived from `create_supervised_evaluator` in Ignite.
+
+    Factory function for creating an evaluator for supervised models.
+    Args:
+        net (`torch.nn.Module`): the model to evaluate.
+        metrics (dict of str - :class:`~ignite.metrics.Metric`): a map of metric names to Metrics.
+        devices (list, optional): device(s) to use (default: None). Applies to both the model and the batches.
+            None means all available GPU devices are used; an empty list means CPU only.
+        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
+            with respect to the host. For other cases, this argument has no effect.
+        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
+            tuple of tensors `(batch_x, batch_y)`.
+        output_transform (callable, optional): function that receives 'x', 'y', 'y_pred' and returns value
+            to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
+            output expected by metrics. If you change it you should use `output_transform` in metrics.
+    Note: `engine.state.output` for this engine is defined by the `output_transform` parameter and is
+        a tuple of `(batch_pred, batch_y)` by default.
+    Returns:
+        Engine: an evaluator engine with supervised inference function.
+    """
+
+    if devices is None:
+        devices = [torch.device('cuda:%i' % d) for d in range(torch.cuda.device_count())]
+    elif len(devices) == 0:
+        devices = [torch.device("cpu")]
+
+    if len(devices) > 1:
+        net = torch.nn.parallel.DataParallel(net)
+
+    return create_supervised_evaluator(net, metrics, devices[0], non_blocking, prepare_batch, output_transform)

From 5ca872279784439cd654bea49be0f57efe742be3 Mon Sep 17 00:00:00 2001
From: Eric Kerfoot
Date: Fri, 31 Jan 2020 15:40:22 +0000
Subject: [PATCH 2/6] Update test_parallel_execution.py

---
 tests/test_parallel_execution.py | 49 ++++++++++++++++++++++++++++++++
 1 file changed, 49 insertions(+)
 create mode 100644 tests/test_parallel_execution.py

diff --git a/tests/test_parallel_execution.py b/tests/test_parallel_execution.py
new file mode 100644
index 0000000000..02985c4d47
--- /dev/null
+++ b/tests/test_parallel_execution.py
@@ -0,0 +1,49 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from monai.application.engine import create_multigpu_supervised_trainer
+
+
+def fake_loss(y_pred, y):
+    return (y_pred[0] + y).sum()
+
+
+def fake_data_stream():
+    while True:
+        yield torch.rand((10, 1, 64, 64)), torch.rand((10, 1, 64, 64))
+
+
+class TestParallelExecution(unittest.TestCase):
+    """
+    Tests single GPU, multi GPU, and CPU execution with the Ignite supervised trainer.
+    """
+
+    def test_single_gpu(self):
+        net = torch.nn.Conv2d(1, 1, 3, padding=1)
+        opt = torch.optim.Adam(net.parameters(), 1e-3)
+        trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [0])
+        trainer.run(fake_data_stream(), 2, 2)
+
+    def test_multi_gpu(self):
+        net = torch.nn.Conv2d(1, 1, 3, padding=1)
+        opt = torch.optim.Adam(net.parameters(), 1e-3)
+        trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, None)
+        trainer.run(fake_data_stream(), 2, 2)
+
+    def test_cpu(self):
+        net = torch.nn.Conv2d(1, 1, 3, padding=1)
+        opt = torch.optim.Adam(net.parameters(), 1e-3)
+        trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [])
+        trainer.run(fake_data_stream(), 2, 2)

From 81116592c2fdca3fe004c37ca252296209396847 Mon Sep 17 00:00:00 2001
From: Eric Kerfoot
Date: Fri, 31 Jan 2020 15:51:19 +0000
Subject: [PATCH 3/6] Ignite version change to 0.3.0.
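
For reference, a minimal sketch of how the new factory is driven against the pinned Ignite
release (it mirrors examples/multi_gpu_test.ipynb and tests/test_parallel_execution.py; the toy
convolution, loss, and data stream are placeholders, not a real training setup):

    import torch
    from monai.application.engine import create_multigpu_supervised_trainer

    net = torch.nn.Conv2d(1, 1, 3, padding=1)       # toy model standing in for a real network
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    def fake_loss(y_pred, y):                       # placeholder loss, as in the tests
        return (y_pred[0] + y).sum()

    def fake_data_stream():                         # endless stream of random image/label pairs
        while True:
            yield torch.rand((10, 1, 64, 64)), torch.rand((10, 1, 64, 64))

    # None selects every visible CUDA device (wrapping the net in DataParallel when
    # more than one is found); [] would run on the CPU instead.
    trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, None)
    trainer.run(fake_data_stream(), max_epochs=2, epoch_length=2)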
---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 6999c4db50..d4d6c8ca49 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 torch
-pytorch-ignite==0.2.1
+pytorch-ignite==0.3.0
 numpy
 pyyaml
 blinker

From 94b7949a75a4abb2949355b1073e3bfc48a56594 Mon Sep 17 00:00:00 2001
From: Eric Kerfoot
Date: Wed, 5 Feb 2020 11:30:16 +0000
Subject: [PATCH 4/6] Slight tweak to parallel code

---
 .gitignore                                    |  2 ++
 .../engine/multi_gpu_supervised_trainer.py    | 36 ++++++++++++++-----
 tests/test_parallel_execution.py              |  8 ++++-
 3 files changed, 37 insertions(+), 9 deletions(-)

diff --git a/.gitignore b/.gitignore
index c30f242fd2..b15c95db2f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -104,3 +104,5 @@ venv.bak/
 .mypy_cache/
 examples/scd_lvsegs.npz
 .idea/
+
+*~
diff --git a/monai/application/engine/multi_gpu_supervised_trainer.py b/monai/application/engine/multi_gpu_supervised_trainer.py
index 912d4cc9ad..51e7d4ca7c 100644
--- a/monai/application/engine/multi_gpu_supervised_trainer.py
+++ b/monai/application/engine/multi_gpu_supervised_trainer.py
@@ -17,12 +17,38 @@ import monai
 
 
+def get_devices_spec(devices=None):
+    """
+    Get a valid specification for one or more devices. If `devices` is None, a device is returned for each
+    CUDA device available. If `devices` is a zero-length structure, a single CPU compute device is returned.
+    In any other case `devices` is returned unchanged.
+
+    Args:
+        devices (list, optional): list of devices to request, None for all GPU devices, [] for CPU.
+
+    Returns:
+        list of torch.device: list of devices.
+    """
+    if devices is None:
+        devices = [torch.device('cuda:%i' % d) for d in range(torch.cuda.device_count())]
+
+        if len(devices) == 0:
+            raise ValueError("No GPU devices available")
+
+    elif len(devices) == 0:
+        devices = [torch.device("cpu")]
+
+    return devices
+
+
 def _default_transform(x, y, y_pred, loss):
     return loss.item()
 
+
 def _default_eval_transform(x, y, y_pred):
     return y_pred, y
 
+
 @monai.utils.export("monai.application.engine")
 def create_multigpu_supervised_trainer(net, optimizer, loss_fn, devices=None, non_blocking=False,
                                        prepare_batch=_prepare_batch, output_transform=_default_transform):
@@ -48,10 +74,7 @@ def create_multigpu_supervised_trainer(net, optimizer, loss_fn, devices=None, no
         Engine: a trainer engine with supervised update function.
     """
 
-    if devices is None:
-        devices = [torch.device('cuda:%i' % d) for d in range(torch.cuda.device_count())]
-    elif len(devices) == 0:
-        devices = [torch.device("cpu")]
+    devices = get_devices_spec(devices)
 
     if len(devices) > 1:
         net = torch.nn.parallel.DataParallel(net)
@@ -84,10 +107,7 @@ def create_multigpu_supervised_evaluator(net, metrics=None, devices=None, non_bl
         Engine: an evaluator engine with supervised inference function.
     """
 
-    if devices is None:
-        devices = [torch.device('cuda:%i' % d) for d in range(torch.cuda.device_count())]
-    elif len(devices) == 0:
-        devices = [torch.device("cpu")]
+    devices = get_devices_spec(devices)
 
     if len(devices) > 1:
         net = torch.nn.parallel.DataParallel(net)
diff --git a/tests/test_parallel_execution.py b/tests/test_parallel_execution.py
index 02985c4d47..c9562c7b23 100644
--- a/tests/test_parallel_execution.py
+++ b/tests/test_parallel_execution.py
@@ -10,6 +10,7 @@
 # limitations under the License.
 
 import unittest
+import warnings
 
 import torch
 
 from monai.application.engine import create_multigpu_supervised_trainer
@@ -39,7 +40,12 @@ def test_single_gpu(self):
     def test_multi_gpu(self):
         net = torch.nn.Conv2d(1, 1, 3, padding=1)
         opt = torch.optim.Adam(net.parameters(), 1e-3)
-        trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, None)
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")  # ignore warnings about imbalanced GPU memory
+
+            trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, None)
+
         trainer.run(fake_data_stream(), 2, 2)
 
     def test_cpu(self):

From 34fd93a6617ed5bacedf9d942930adc80b79d6a9 Mon Sep 17 00:00:00 2001
From: Eric Kerfoot
Date: Thu, 6 Feb 2020 15:20:07 +0000
Subject: [PATCH 5/6] Update test_parallel_execution.py

---
 tests/test_parallel_execution.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_parallel_execution.py b/tests/test_parallel_execution.py
index c9562c7b23..f8e0e07d4a 100644
--- a/tests/test_parallel_execution.py
+++ b/tests/test_parallel_execution.py
@@ -42,7 +42,7 @@ def test_multi_gpu(self):
         opt = torch.optim.Adam(net.parameters(), 1e-3)
 
         with warnings.catch_warnings():
-            warnings.simplefilter("ignore") # ignore warnings about imbalanced GPU memory
+            warnings.simplefilter("ignore")  # ignore warnings about imbalanced GPU memory
 
             trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, None)
 

From 463affea2e9ca1dbb82812e17af1d6959f39ffbb Mon Sep 17 00:00:00 2001
From: Mohammad Adil
Date: Fri, 7 Feb 2020 01:00:27 -0800
Subject: [PATCH 6/6] Add decorator to expect failure if no GPUs. (#51)

---
 tests/test_parallel_execution.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/tests/test_parallel_execution.py b/tests/test_parallel_execution.py
index f8e0e07d4a..0ef0dccd2a 100644
--- a/tests/test_parallel_execution.py
+++ b/tests/test_parallel_execution.py
@@ -25,18 +25,26 @@ def fake_data_stream():
     while True:
         yield torch.rand((10, 1, 64, 64)), torch.rand((10, 1, 64, 64))
 
 
+def expect_failure_if_no_gpu(test):
+    if not torch.cuda.is_available():
+        return unittest.expectedFailure(test)
+    else:
+        return test
+
 
 class TestParallelExecution(unittest.TestCase):
     """
     Tests single GPU, multi GPU, and CPU execution with the Ignite supervised trainer.
     """
 
+    @expect_failure_if_no_gpu
     def test_single_gpu(self):
         net = torch.nn.Conv2d(1, 1, 3, padding=1)
         opt = torch.optim.Adam(net.parameters(), 1e-3)
-        trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [0])
+        trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [torch.device("cuda:0")])
         trainer.run(fake_data_stream(), 2, 2)
 
+    @expect_failure_if_no_gpu
     def test_multi_gpu(self):
         net = torch.nn.Conv2d(1, 1, 3, padding=1)
         opt = torch.optim.Adam(net.parameters(), 1e-3)