Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,5 @@ venv.bak/
.mypy_cache/
examples/scd_lvsegs.npz
.idea/

*~
140 changes: 140 additions & 0 deletions examples/multi_gpu_test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MONAI version: 0.0.1\n",
"Python version: 3.7.3 (default, Mar 27 2019, 22:11:17) [GCC 7.3.0]\n",
"Numpy version: 1.16.4\n",
"Pytorch version: 1.4.0\n",
"Ignite version: 0.3.0\n"
]
}
],
"source": [
"import sys\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader\n",
"import torchvision.transforms as transforms\n",
"\n",
"import numpy as np\n",
"\n",
"from ignite.engine import create_supervised_trainer\n",
"from ignite.engine.engine import Events\n",
"from ignite.handlers import ModelCheckpoint\n",
"\n",
"# assumes the framework is found here, change as necessary\n",
"sys.path.append(\"..\")\n",
"\n",
"from monai import application, data, networks, utils\n",
"\n",
"\n",
"application.config.print_config()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/localek10/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py:26: UserWarning: \n",
" There is an imbalance between your GPUs. You may want to exclude GPU 1 which\n",
" has less than 75% of the memory or cores of GPU 0. You can do so by setting\n",
" the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES\n",
" environment variable.\n",
" warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))\n"
]
},
{
"data": {
"text/plain": [
"State:\n",
"\titeration: 4\n",
"\tepoch: 2\n",
"\tepoch_length: 2\n",
"\tmax_epochs: 2\n",
"\toutput: 20912.578125\n",
"\tbatch: <class 'tuple'>\n",
"\tmetrics: <class 'dict'>\n",
"\tdataloader: <class 'generator'>\n",
"\tseed: 12"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lr = 1e-3\n",
"\n",
"net = networks.nets.UNet(\n",
" dimensions=2,\n",
" in_channels=1,\n",
" num_classes=1,\n",
" channels=(16, 32, 64, 128, 256),\n",
" strides=(2, 2, 2, 2),\n",
" num_res_units=2,\n",
")\n",
"\n",
"\n",
"def fake_loss(y_pred,y):\n",
" return (y_pred[0]+y).sum()\n",
"\n",
"\n",
"def fake_data_stream():\n",
" while True:\n",
" yield torch.rand((10,1,64,64)),torch.rand((10,1,64,64))\n",
" \n",
" \n",
"# 1 GPU\n",
"opt = torch.optim.Adam(net.parameters(), lr)\n",
"trainer=application.engine.create_multigpu_supervised_trainer(net,opt,fake_loss,[torch.device('cuda:0')])\n",
"trainer.run(fake_data_stream(),2,2)\n",
"\n",
"# all GPUs\n",
"opt = torch.optim.Adam(net.parameters(), lr)\n",
"trainer=application.engine.create_multigpu_supervised_trainer(net,opt,fake_loss,None)\n",
"trainer.run(fake_data_stream(),2,2)\n",
"\n",
"# CPU\n",
"opt = torch.optim.Adam(net.parameters(), lr)\n",
"trainer=application.engine.create_multigpu_supervised_trainer(net,opt,fake_loss,[])\n",
"trainer.run(fake_data_stream(),2,2)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
1 change: 0 additions & 1 deletion examples/unet_segmentation_3d.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"import sys\n",
"import tempfile\n",
"from glob import glob\n",
"from functools import partial\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
Expand Down
115 changes: 115 additions & 0 deletions monai/application/engine/multi_gpu_supervised_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch

from ignite.engine import create_supervised_trainer, create_supervised_evaluator, _prepare_batch

import monai


def get_devices_spec(devices=None):
"""
Get a valid specification for one or more devices. If `devices` is None get devices for all CUDA devices available.
If `devices` is and zero-length structure a single CPU compute device is returned. In any other cases `devices` is
returned unchanged.

Args:
devices (list, optional): list of devices to request, None for all GPU devices, [] for CPU.

Returns:
list of torch.device: list of devices.
"""
if devices is None:
devices = [torch.device('cuda:%i' % d) for d in range(torch.cuda.device_count())]

if len(devices) == 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If no GPU is found, we should default to CPU.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I mentioned elsewhere, defaulting to CPU like that will cause silent errors when people expect to use GPUs. If people want GPUs they should get a loud and clear error that they can't get them, otherwise they'll think everything is find just super slow.

raise ValueError("No GPU devices available")

elif len(devices) == 0:
devices = [torch.device("cpu")]
Copy link
Contributor

@Nic-Ma Nic-Ma Feb 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just suggest to print a warning here if use CPU instead.
Because this code file is for "multi_gpu_trainer", what do you think?
People may don't know that "devices = empty list" is "CPU device".
Others look good to me.
Thanks.


return devices


def _default_transform(x, y, y_pred, loss):
return loss.item()


def _default_eval_transform(x, y, y_pred):
return y_pred, y


@monai.utils.export("monai.application.engine")
def create_multigpu_supervised_trainer(net, optimizer, loss_fn, devices=None, non_blocking=False,
prepare_batch=_prepare_batch, output_transform=_default_transform):
"""
***Derived from `create_supervised_trainer` in Ignite.

Factory function for creating a trainer for supervised models.
Args:
net (`torch.nn.Module`): the network to train.
optimizer (`torch.optim.Optimizer`): the optimizer to use.
loss_fn (torch.nn loss function): the loss function to use.
devices (list, optional): device(s) type specification (default: None).
Applies to both model and batches. None is all devices used, empty list is CPU only.
non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
Note: `engine.state.output` for this engine is defind by `output_transform` parameter and is the loss
of the processed batch by default.
Returns:
Engine: a trainer engine with supervised update function.
"""

devices = get_devices_spec(devices)

if len(devices) > 1:
net = torch.nn.parallel.DataParallel(net)

return create_supervised_trainer(net, optimizer, loss_fn, devices[0], non_blocking, prepare_batch, output_transform)


@monai.utils.export("monai.application.engine")
def create_multigpu_supervised_evaluator(net, metrics=None, device=None, non_blocking=False,
prepare_batch=_prepare_batch, output_transform=_default_eval_transform):
"""
***Derived from `create_supervised_evaluator` in Ignite.

Factory function for creating an evaluator for supervised models.
Args:
net (`torch.nn.Module`): the model to train.
metrics (dict of str - :class:`~ignite.metrics.Metric`): a map of metric names to Metrics.
devices (list, optional): device(s) type specification (default: None).
Applies to both model and batches. None is all devices used, empty list is CPU only.
non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
output_transform (callable, optional): function that receives 'x', 'y', 'y_pred' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
output expected by metrics. If you change it you should use `output_transform` in metrics.
Note: `engine.state.output` for this engine is defind by `output_transform` parameter and is
a tuple of `(batch_pred, batch_y)` by default.
Returns:
Engine: an evaluator engine with supervised inference function.
"""

devices = get_devices_spec(devices)

if len(devices) > 1:
net = torch.nn.parallel.DataParallel(net)

return create_supervised_evaluator(net, metrics, devices[0], non_blocking, prepare_batch, output_transform)
63 changes: 63 additions & 0 deletions tests/test_parallel_execution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import warnings

import torch

from monai.application.engine import create_multigpu_supervised_trainer


def fake_loss(y_pred, y):
return (y_pred[0] + y).sum()


def fake_data_stream():
while True:
yield torch.rand((10, 1, 64, 64)), torch.rand((10, 1, 64, 64))

def expect_failure_if_no_gpu(test):
if not torch.cuda.is_available():
return unittest.expectedFailure(test)
else:
return test


class TestParallelExecution(unittest.TestCase):
"""
Tests single GPU, multi GPU, and CPU execution with the Ignite supervised trainer.
"""

@expect_failure_if_no_gpu
def test_single_gpu(self):
net = torch.nn.Conv2d(1, 1, 3, padding=1)
opt = torch.optim.Adam(net.parameters(), 1e-3)
trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [torch.device("cuda:0")])
trainer.run(fake_data_stream(), 2, 2)

@expect_failure_if_no_gpu
def test_multi_gpu(self):
net = torch.nn.Conv2d(1, 1, 3, padding=1)
opt = torch.optim.Adam(net.parameters(), 1e-3)

with warnings.catch_warnings():
warnings.simplefilter("ignore") # ignore warnings about imbalanced GPU memory

trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, None)

trainer.run(fake_data_stream(), 2, 2)

def test_cpu(self):
net = torch.nn.Conv2d(1, 1, 3, padding=1)
opt = torch.optim.Adam(net.parameters(), 1e-3)
trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [])
trainer.run(fake_data_stream(), 2, 2)