Skip to content
Merged
8 changes: 5 additions & 3 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ stages:
script:
- nvidia-smi
- python -m pip install --upgrade pip
- pip uninstall -y torch torchvision
- pip install -r requirements.txt
- pip list
- pip install flake8
- pip install pep8-naming
# stop the build if there are Python syntax errors or undefined names
- flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --config ./.flake8
# - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --config ./.flake8
# exit-zero treats all errors as warnings.
# - flake8 . --count --statistics --config ./.flake8
- ./runtests.sh --quick
- flake8 . --count --statistics --config ./.flake8
- ./runtests.sh --net
- echo "Done with runtests.sh"

build-ci-test:
Expand Down
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ _Contact: <monai.miccai2019@gmail.com>_

This document identifies key concepts of project MONAI at a high level, the goal is to facilitate further technical discussions of requirements,roadmap, feasibility and trade-offs.


## Vision
* Develop a community of academic, industrial and clinical researchers collaborating and working on a common foundation of standardized tools.
* Create a state-of-the-art, end-to-end training toolkit for healthcare imaging.
Expand All @@ -15,7 +14,7 @@ This document identifies key concepts of project MONAI at a high level, the goal
* Primarily focused on the healthcare researchers who develop DL models for medical imaging

## Goals
* Deliver domain-specific workflow capabilities
* Deliver domain-specific workflow capabilities
* Address the end-end “Pain points” when creating medical imaging deep learning workflows.
* Provide a robust foundation with a performance optimized system software stack that allows researchers to focus on the research and not worry about software development principles.

Expand Down Expand Up @@ -137,7 +136,7 @@ This document identifies key concepts of project MONAI at a high level, the goal
<tr>
<td>Configuration-driven workflow assembly
</td>
<td colspan="2" >Making workflow instances from configuration file
<td colspan="2" >Making workflow instances from configuration file
</td>
<td>Convenient for managing hyperparameters
</td>
Expand Down Expand Up @@ -209,7 +208,7 @@ This document identifies key concepts of project MONAI at a high level, the goal
</td>
</tr>
<tr>
<td>Compatibility with external toolkits
<td>Compatibility with external toolkits
</td>
<td colspan="2" >XNAT as data source, ITK as preprocessor
</td>
Expand Down
44 changes: 44 additions & 0 deletions monai/utils/generateddata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np

from monai.utils.arrayutils import rescale_array


def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0, num_seg_classes=5):
"""
Return a noisy 2D image with `numObj' circles and a 2D mask image. The maximum radius of the circles is given as
`radMax'. The mask will have `numSegClasses' number of classes for segmentations labeled sequentially from 1, plus a
background class represented as 0. If `noiseMax' is greater than 0 then noise will be added to the image taken from
the uniform distribution on range [0,noiseMax).
"""
image = np.zeros((width, height))

for i in range(num_objs):
x = np.random.randint(rad_max, width - rad_max)
y = np.random.randint(rad_max, height - rad_max)
rad = np.random.randint(5, rad_max)
spy, spx = np.ogrid[-x : width - x, -y : height - y]
circle = (spx * spx + spy * spy) <= rad * rad

if num_seg_classes > 1:
image[circle] = np.ceil(np.random.random() * num_seg_classes)
else:
image[circle] = np.random.random() * 0.5 + 0.5

labels = np.ceil(image).astype(np.int32)

norm = np.random.uniform(0, num_seg_classes * noise_max, size=image.shape)
noisyimage = rescale_array(np.maximum(image, norm))

return noisyimage, labels
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
torch
pytorch-ignite==0.2.1
torch>=1.4
pytorch-ignite==0.3.0
numpy
pyyaml
blinker
Expand Down
9 changes: 5 additions & 4 deletions runtests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ set -e
homedir="$( cd -P "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd $homedir

#export PYTHONPATH="$homedir:$PYTHONPATH"
export PYTHONPATH="$homedir:$PYTHONPATH"
echo $PYTHONPATH

# configuration values
doCoverage=false
Expand All @@ -16,7 +17,7 @@ doDryRun=false
doZooTests=false

# testing command to run
cmd="python"
cmd="python3"
cmdprefix=""


Expand Down Expand Up @@ -75,13 +76,13 @@ fi


# unit tests
${cmdprefix}${cmd} -m unittest
${cmdprefix}${cmd} -m unittest -v


# network training/inference/eval tests
if [ "$doNetTests" = 'true' ]
then
for i in examples/*.py
for i in tests/integration_*.py
do
echo $i
${cmdprefix}${cmd} $i
Expand Down
59 changes: 59 additions & 0 deletions tests/integration_unet2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import numpy as np
import torch
from ignite.engine import create_supervised_trainer
from torch.utils.data import DataLoader, Dataset

from monai import networks, utils


def run_test(batch_size=64, train_steps=100, device=torch.device("cuda:0")):

class _TestBatch(Dataset):

def __getitem__(self, _unused_id):
im, seg = utils.generateddata.create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1)
return im[None], seg[None].astype(np.float32)

def __len__(self):
return train_steps

net = networks.nets.UNet(
dimensions=2,
in_channels=1,
num_classes=1,
channels=(4, 8, 16, 32),
strides=(2, 2, 2),
num_res_units=2,
)

loss = networks.losses.DiceLoss()
opt = torch.optim.Adam(net.parameters(), 1e-4)
src = DataLoader(_TestBatch(), batch_size=batch_size)

def loss_fn(pred, grnd):
return loss(pred[0], grnd)

trainer = create_supervised_trainer(net, opt, loss_fn, device, False)

trainer.run(src, 1)

return trainer.state.output


if __name__ == "__main__":
result = run_test()

sys.exit(0 if result < 1 else 1)
33 changes: 2 additions & 31 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import numpy as np
import torch

from monai.utils.arrayutils import rescale_array
from monai.utils.generateddata import create_test_image_2d

quick_test_var = "QUICKTEST"

Expand All @@ -26,43 +26,14 @@ def skip_if_quick(obj):
return unittest.skipIf(is_quick, "Skipping slow tests")(obj)


def create_test_image(width, height, num_objs=12, rad_max=30, noise_max=0.0, num_seg_classes=5):
"""
Return a noisy 2D image with `numObj' circles and a 2D mask image. The maximum radius of the circles is given as
`radMax'. The mask will have `numSegClasses' number of classes for segmentations labeled sequentially from 1, plus a
background class represented as 0. If `noiseMax' is greater than 0 then noise will be added to the image taken from
the uniform distribution on range [0,noiseMax).
"""
image = np.zeros((width, height))

for i in range(num_objs):
x = np.random.randint(rad_max, width - rad_max)
y = np.random.randint(rad_max, height - rad_max)
rad = np.random.randint(5, rad_max)
spy, spx = np.ogrid[-x : width - x, -y : height - y]
circle = (spx * spx + spy * spy) <= rad * rad

if num_seg_classes > 1:
image[circle] = np.ceil(np.random.random() * num_seg_classes)
else:
image[circle] = np.random.random() * 0.5 + 0.5

labels = np.ceil(image).astype(np.int32)

norm = np.random.uniform(0, num_seg_classes * noise_max, size=image.shape)
noisyimage = rescale_array(np.maximum(image, norm))

return noisyimage, labels


class NumpyImageTestCase2D(unittest.TestCase):
im_shape = (128, 128)
input_channels = 1
output_channels = 4
num_classes = 3

def setUp(self):
im, msk = create_test_image(self.im_shape[0], self.im_shape[1], 4, 20, 0, self.num_classes)
im, msk = create_test_image_2d(self.im_shape[0], self.im_shape[1], 4, 20, 0, self.num_classes)

self.imt = im[None, None]
self.seg1 = (msk[None, None] > 0).astype(np.float32)
Expand Down