Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions monai/handlers/mean_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,29 @@ class MeanDice(Metric):
def __init__(
self,
include_background=True,
to_onehot_y=False,
logit_thresh=0.5,
to_onehot_y=True,
logit_thresh=None,
add_sigmoid=False,
mutually_exclusive=False,
mutually_exclusive=True,
output_transform: Callable = lambda x: x,
device: Optional[Union[str, torch.device]] = None
):
"""

Args:
include_background (Bool): whether to include dice computation on the first channel of the predicted output.
to_onehot_y (Bool): whether to convert the output prediction into the one-hot format.
logit_thresh (Float): the threshold value to round value to 0.0 and 1.0, default is 0.5.
Defaults to True.
to_onehot_y (Bool): whether to convert the output prediction into the one-hot format. Defaults to True.
logit_thresh (Float): the threshold value to round value to 0.0 and 1.0. Defaults to None (no thresholding).
add_sigmoid (Bool): whether to add sigmoid function to the output prediction before computing Dice.
Defaults to False.
mutually_exclusive (Bool): if True, the output prediction will be converted into a binary matrix using
a combination of argmax and to_onehot.
a combination of argmax and to_onehot. Defaults to True.
output_transform (Callable): transform the ignite.engine.state.output into [y_pred, y] pair.
device (torch.device): device specification in case of distributed computation usage.

See also:
monai.metrics.compute_meandice.compute_meandice
"""
super(MeanDice, self).__init__(output_transform, device=device)
self.include_background = include_background
Expand Down
9 changes: 7 additions & 2 deletions monai/handlers/segmentation_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class SegmentationSaver:
"""

def __init__(self, output_path='./', dtype='float32', output_postfix='seg', output_ext='.nii.gz',
output_transform=lambda x: x):
output_transform=lambda x: x, name=None):
"""
Args:
output_path (str): output image directory.
Expand All @@ -34,14 +34,19 @@ def __init__(self, output_path='./', dtype='float32', output_postfix='seg', outp
ignite.engine.output into the form expected nifti image data.
The first dimension of this transform's output will be treated as the
batch dimension. Each item in the batch will be saved individually.
name (str): identifier of logging.logger to use, defaulting to `engine.logger`.
"""
self.output_path = output_path
self.dtype = dtype
self.output_postfix = output_postfix
self.output_ext = output_ext
self.output_transform = output_transform

self.logger = None if name is None else logging.getLogger(name)

def attach(self, engine):
if self.logger is None:
self.logger = engine.logger
return engine.add_event_handler(Events.ITERATION_COMPLETED, self)

@staticmethod
Expand Down Expand Up @@ -103,4 +108,4 @@ def __call__(self, engine):
output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path)
output_filename = '{}{}'.format(output_filename, self.output_ext)
write_nifti(seg_output, _affine, output_filename, _original_affine, dtype=seg_output.dtype)
print('saved: {}'.format(output_filename))
self.logger.info('saved: {}'.format(output_filename))
9 changes: 7 additions & 2 deletions monai/metrics/compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def compute_meandice(y_pred,
y_pred (torch.Tensor): input data to compute, typical segmentation model output.
it must be One-Hot format and first dim is batch, example shape: [16, 3, 32, 32].
y (torch.Tensor): ground truth to compute mean dice metric, the first dim is batch.
example shape: [16, 3, 32, 32] for 3-class one-hot labels.
alternative shape: [16, 1, 32, 32] and set `to_onehot_y=True` to convert it into [16, 3, 32, 32].
include_background (Bool): whether to skip dice computation on the first channel of the predicted output.
to_onehot_y (Bool): whether to convert `y` into the one-hot format.
mutually_exclusive (Bool): if True, `y_pred` will be converted into a binary matrix using
Expand All @@ -44,8 +46,8 @@ def compute_meandice(y_pred,
n_channels_y_pred = y_pred.shape[1]

if mutually_exclusive:
if logit_thresh is not None:
raise ValueError('`logit_thresh` is incompatible when mutually_exclusive is True.')
if logit_thresh is not None or add_sigmoid:
raise ValueError('`logit_thresh` and `add_sigmoid` are incompatible when mutually_exclusive is True.')
y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
y_pred = one_hot(y_pred, n_channels_y_pred)
else: # channel-wise thresholding
Expand All @@ -61,6 +63,9 @@ def compute_meandice(y_pred,
y = y[:, 1:] if y.shape[1] > 1 else y
y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred

assert y.shape == y_pred.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" %
(y.shape, y_pred.shape))

# reducing only spatial dimensions (not batch nor channels)
reduce_axis = list(range(2, y_pred.dim()))
intersection = torch.sum(y * y_pred, reduce_axis)
Expand Down
80 changes: 80 additions & 0 deletions tests/integration_sliding_window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tempfile

import nibabel as nib
import torch
from ignite.engine import Engine
from torch.utils.data import DataLoader

from monai.data.nifti_reader import NiftiDataset
from monai.data.synthetic import create_test_image_3d
from monai.handlers.segmentation_saver import SegmentationSaver
from monai.networks.nets.unet import UNet
from monai.networks.utils import predict_segmentation
from monai.transforms.transforms import AddChannel
from monai.utils.sliding_window_inference import sliding_window_inference
from tests.utils import make_nifti_image


def run_test(batch_size=2, device=torch.device("cpu:0")):

im, seg = create_test_image_3d(25, 28, 63, rad_max=10, noise_max=1, num_objs=4, num_seg_classes=1)
input_shape = im.shape
img_name = make_nifti_image(im)
seg_name = make_nifti_image(seg)
ds = NiftiDataset([img_name], [seg_name], transform=AddChannel(), seg_transform=AddChannel(), image_only=False)
loader = DataLoader(ds, batch_size=1, pin_memory=torch.cuda.is_available())

net = UNet(
dimensions=3,
in_channels=1,
num_classes=1,
channels=(4, 8, 16, 32),
strides=(2, 2, 2),
num_res_units=2,
)
roi_size = (16, 32, 48)
sw_batch_size = batch_size

def _sliding_window_processor(_engine, batch):
net.eval()
img, seg, meta_data = batch
with torch.no_grad():
seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x)[0], device)
return predict_segmentation(seg_probs)

infer_engine = Engine(_sliding_window_processor)

with tempfile.TemporaryDirectory() as temp_dir:
SegmentationSaver(output_path=temp_dir, output_ext='.nii.gz', output_postfix='seg').attach(infer_engine)

infer_engine.run(loader)

basename = os.path.basename(img_name)[:-len('.nii.gz')]
saved_name = os.path.join(temp_dir, basename, '{}_seg.nii.gz'.format(basename))
testing_shape = nib.load(saved_name).get_fdata().shape

if os.path.exists(img_name):
os.remove(img_name)
if os.path.exists(seg_name):
os.remove(seg_name)

return testing_shape == input_shape


if __name__ == "__main__":
result = run_test()

sys.exit(0 if result else 1)
67 changes: 67 additions & 0 deletions tests/test_handler_mean_dice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.handlers.mean_dice import MeanDice

TEST_CASE_1 = [{'to_onehot_y': True, 'mutually_exclusive': True}, 0.75]
TEST_CASE_2 = [{'include_background': False, 'to_onehot_y': False, 'mutually_exclusive': False}, 0.8333333]

TEST_CASE_3 = [{'mutually_exclusive': True, 'add_sigmoid': True}]


class TestHandlerMeanDice(unittest.TestCase):
# TODO test multi node averaged dice

@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_compute(self, input_params, expected_avg):
dice_metric = MeanDice(**input_params)

y_pred = torch.Tensor([[0, 1], [1, 0]])
y = torch.ones((2, 1))
dice_metric.update([y_pred, y])

y_pred = torch.Tensor([[0, 1], [1, 0]])
y = torch.Tensor([[1.], [0.]])
dice_metric.update([y_pred, y])

avg_dice = dice_metric.compute()
self.assertAlmostEqual(avg_dice, expected_avg)

@parameterized.expand([TEST_CASE_3])
def test_misconfig(self, input_params):
with self.assertRaisesRegex(ValueError, 'compatib'):
dice_metric = MeanDice(**input_params)

y_pred = torch.Tensor([[0, 1], [1, 0]])
y = torch.ones((2, 1))
dice_metric.update([y_pred, y])

@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_shape_mismatch(self, input_params, _expected):
dice_metric = MeanDice(**input_params)
with self.assertRaises((AssertionError, ValueError)):
y_pred = torch.Tensor([[0, 1], [1, 0]])
y = torch.ones((2, 3))
dice_metric.update([y_pred, y])

with self.assertRaises((AssertionError, ValueError)):
y_pred = torch.Tensor([[0, 1], [1, 0]])
y = torch.ones((3, 2))
dice_metric.update([y_pred, y])


if __name__ == '__main__':
unittest.main()
61 changes: 61 additions & 0 deletions tests/test_handler_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
import unittest
from io import StringIO

from ignite.engine import Engine, Events

from monai.handlers.stats_handler import StatsHandler


class TestHandlerStats(unittest.TestCase):

def test_metrics_print(self):
log_stream = StringIO()
logging.basicConfig(stream=log_stream, level=logging.INFO)
key_to_handler = 'test_logging'
key_to_print = 'testing_metric'

# set up engine
def _train_func(engine, batch):
pass

engine = Engine(_train_func)

# set up dummy metric
@engine.on(Events.ITERATION_COMPLETED)
def _update_metric(engine):
current_metric = engine.state.metrics.get(key_to_print, 0.1)
engine.state.metrics[key_to_print] = current_metric + 0.1

# set up testing handler
stats_handler = StatsHandler(name=key_to_handler)
stats_handler.attach(engine)

engine.run(range(3), max_epochs=2)

# check logging output
output_str = log_stream.getvalue()
grep = re.compile('.*{}.*'.format(key_to_handler))
has_key_word = re.compile('.*{}.*'.format(key_to_print))
matched = []
for idx, line in enumerate(output_str.split('\n')):
if grep.match(line):
self.assertTrue(has_key_word.match(line))
matched.append(idx)
self.assertEqual(matched, [1, 2, 3, 5, 6, 7, 8, 10])


if __name__ == '__main__':
unittest.main()
4 changes: 3 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ def skip_if_quick(obj):
return unittest.skipIf(is_quick, "Skipping slow tests")(obj)


def make_nifti_image(array, affine):
def make_nifti_image(array, affine=None):
"""
Create a temporary nifti image on the disk and return the image name.
User is responsible for deleting the temporary file when done with it.
"""
if affine is None:
affine = np.eye(4)
test_image = nib.Nifti1Image(array, affine)

_, image_name = tempfile.mkstemp(suffix='.nii.gz')
Expand Down