From 884c3b14fc28302d92d0fe99765e28fcbd4e8048 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 17 Feb 2020 11:22:24 +0000 Subject: [PATCH 1/4] unit test -- mean_dice handler --- monai/handlers/mean_dice.py | 11 +++++--- monai/metrics/compute_meandice.py | 2 ++ tests/test_handler_mean_dice.py | 43 +++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 4 deletions(-) create mode 100644 tests/test_handler_mean_dice.py diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py index 73ac5e4173..0925d3d261 100644 --- a/monai/handlers/mean_dice.py +++ b/monai/handlers/mean_dice.py @@ -27,10 +27,10 @@ class MeanDice(Metric): def __init__( self, include_background=True, - to_onehot_y=False, - logit_thresh=0.5, + to_onehot_y=True, + logit_thresh=None, add_sigmoid=False, - mutually_exclusive=False, + mutually_exclusive=True, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None ): @@ -39,12 +39,15 @@ def __init__( Args: include_background (Bool): whether to include dice computation on the first channel of the predicted output. to_onehot_y (Bool): whether to convert the output prediction into the one-hot format. - logit_thresh (Float): the threshold value to round value to 0.0 and 1.0, default is 0.5. + logit_thresh (Float): the threshold value to round value to 0.0 and 1.0. add_sigmoid (Bool): whether to add sigmoid function to the output prediction before computing Dice. mutually_exclusive (Bool): if True, the output prediction will be converted into a binary matrix using a combination of argmax and to_onehot. output_transform (Callable): transform the ignite.engine.state.output into [y_pred, y] pair. device (torch.device): device specification in case of distributed computation usage. + + See also: + monai.metrics.compute_meandice.compute_meandice """ super(MeanDice, self).__init__(output_transform, device=device) self.include_background = include_background diff --git a/monai/metrics/compute_meandice.py b/monai/metrics/compute_meandice.py index 4a10c03693..978738af7b 100644 --- a/monai/metrics/compute_meandice.py +++ b/monai/metrics/compute_meandice.py @@ -27,6 +27,8 @@ def compute_meandice(y_pred, y_pred (torch.Tensor): input data to compute, typical segmentation model output. it must be One-Hot format and first dim is batch, example shape: [16, 3, 32, 32]. y (torch.Tensor): ground truth to compute mean dice metric, the first dim is batch. + example shape: [16, 3, 32, 32] for 3-class one-hot labels. + alternative shape: [16, 1, 32, 32] and set `to_onehot_y=True` to convert it into [16, 3, 32, 32]. include_background (Bool): whether to skip dice computation on the first channel of the predicted output. to_onehot_y (Bool): whether to convert `y` into the one-hot format. mutually_exclusive (Bool): if True, `y_pred` will be converted into a binary matrix using diff --git a/tests/test_handler_mean_dice.py b/tests/test_handler_mean_dice.py new file mode 100644 index 0000000000..380a81fa03 --- /dev/null +++ b/tests/test_handler_mean_dice.py @@ -0,0 +1,43 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.handlers.mean_dice import MeanDice + +TEST_CASE_1 = [{'to_onehot_y': True, 'mutually_exclusive': True}, 0.75] +TEST_CASE_2 = [{'include_background': False, 'to_onehot_y': False, 'mutually_exclusive': False}, 0.8333333] + + +class TestHandlerMeanDice(unittest.TestCase): + # TODO test multi node averaged dice + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_compute(self, input_params, expected_avg): + dice_metric = MeanDice(**input_params) + + y_pred = torch.Tensor([[0, 1], [1, 0]]) + y = torch.ones((2, 1)) + dice_metric.update([y_pred, y]) + + y_pred = torch.Tensor([[0, 1], [1, 0]]) + y = torch.Tensor([[1.], [0.]]) + dice_metric.update([y_pred, y]) + + avg_dice = dice_metric.compute() + self.assertAlmostEqual(avg_dice, expected_avg) + + +if __name__ == '__main__': + unittest.main() From 071f3b1b356a6aa210e6de906ffa9362997ba86b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 17 Feb 2020 12:02:25 +0000 Subject: [PATCH 2/4] unit test -- stats handler --- tests/test_handler_stats.py | 61 +++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 tests/test_handler_stats.py diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py new file mode 100644 index 0000000000..5bbe17d1c2 --- /dev/null +++ b/tests/test_handler_stats.py @@ -0,0 +1,61 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +import unittest +from io import StringIO + +from ignite.engine import Engine, Events + +from monai.handlers.stats_handler import StatsHandler + + +class TestHandlerStats(unittest.TestCase): + + def test_metrics_print(self): + log_stream = StringIO() + logging.basicConfig(stream=log_stream, level=logging.INFO) + key_to_handler = 'test_logging' + key_to_print = 'testing_metric' + + # set up engine + def _train_func(engine, batch): + pass + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.ITERATION_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get(key_to_print, 0.1) + engine.state.metrics[key_to_print] = current_metric + 0.1 + + # set up testing handler + stats_handler = StatsHandler(name=key_to_handler) + stats_handler.attach(engine) + + engine.run(range(3), max_epochs=2) + + # check logging output + output_str = log_stream.getvalue() + grep = re.compile('.*{}.*'.format(key_to_handler)) + has_key_word = re.compile('.*{}.*'.format(key_to_print)) + matched = [] + for idx, line in enumerate(output_str.split('\n')): + if grep.match(line): + self.assertTrue(has_key_word.match(line)) + matched.append(idx) + self.assertEqual(matched, [1, 2, 3, 5, 6, 7, 8, 10]) + + +if __name__ == '__main__': + unittest.main() From 59bd2375445aba7595ffd5cc5b090b6fee8534c9 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 17 Feb 2020 12:39:27 +0000 Subject: [PATCH 3/4] integration test -- sliding window --- monai/handlers/segmentation_saver.py | 9 +++- tests/integration_sliding_window.py | 81 ++++++++++++++++++++++++++++ tests/utils.py | 4 +- 3 files changed, 91 insertions(+), 3 deletions(-) create mode 100644 tests/integration_sliding_window.py diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index 31e6881678..a87e517f81 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -23,7 +23,7 @@ class SegmentationSaver: """ def __init__(self, output_path='./', dtype='float32', output_postfix='seg', output_ext='.nii.gz', - output_transform=lambda x: x): + output_transform=lambda x: x, name=None): """ Args: output_path (str): output image directory. @@ -34,6 +34,7 @@ def __init__(self, output_path='./', dtype='float32', output_postfix='seg', outp ignite.engine.output into the form expected nifti image data. The first dimension of this transform's output will be treated as the batch dimension. Each item in the batch will be saved individually. + name (str): identifier of logging.logger to use, defaulting to `engine.logger`. """ self.output_path = output_path self.dtype = dtype @@ -41,7 +42,11 @@ def __init__(self, output_path='./', dtype='float32', output_postfix='seg', outp self.output_ext = output_ext self.output_transform = output_transform + self.logger = None if name is None else logging.getLogger(name) + def attach(self, engine): + if self.logger is None: + self.logger = engine.logger return engine.add_event_handler(Events.ITERATION_COMPLETED, self) @staticmethod @@ -103,4 +108,4 @@ def __call__(self, engine): output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path) output_filename = '{}{}'.format(output_filename, self.output_ext) write_nifti(seg_output, _affine, output_filename, _original_affine, dtype=seg_output.dtype) - print('saved: {}'.format(output_filename)) + self.logger.info('saved: {}'.format(output_filename)) diff --git a/tests/integration_sliding_window.py b/tests/integration_sliding_window.py new file mode 100644 index 0000000000..8857c89bf7 --- /dev/null +++ b/tests/integration_sliding_window.py @@ -0,0 +1,81 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import tempfile + +import nibabel as nib +import torch +from ignite.engine import Engine +from torch.utils.data import DataLoader + +from monai.data.nifti_reader import NiftiDataset +from monai.data.synthetic import create_test_image_3d +from monai.handlers.segmentation_saver import SegmentationSaver +from monai.networks.nets.unet import UNet +from monai.networks.utils import predict_segmentation +from monai.transforms.transforms import AddChannel +from monai.utils.sliding_window_inference import sliding_window_inference +from tests.utils import make_nifti_image + + +def run_test(batch_size=64, train_steps=100, device=torch.device("cuda:0")): + + im, seg = create_test_image_3d(25, 28, 63, rad_max=10, noise_max=1, num_objs=4, num_seg_classes=1) + input_shape = im.shape + img_name = make_nifti_image(im) + seg_name = make_nifti_image(seg) + ds = NiftiDataset([img_name], [seg_name], transform=AddChannel(), seg_transform=AddChannel(), image_only=False) + loader = DataLoader(ds, batch_size=1, pin_memory=torch.cuda.is_available()) + + net = UNet( + dimensions=3, + in_channels=1, + num_classes=1, + channels=(4, 8, 16, 32), + strides=(2, 2, 2), + num_res_units=2, + ) + roi_size = (16, 32, 48) + sw_batch_size = 2 + device = torch.device('cpu:0') + + def _sliding_window_processor(_engine, batch): + net.eval() + img, seg, meta_data = batch + with torch.no_grad(): + seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x)[0], device) + return predict_segmentation(seg_probs) + + infer_engine = Engine(_sliding_window_processor) + + with tempfile.TemporaryDirectory() as temp_dir: + SegmentationSaver(output_path=temp_dir, output_ext='.nii.gz', output_postfix='seg').attach(infer_engine) + + infer_engine.run(loader) + + basename = os.path.basename(img_name)[:-len('.nii.gz')] + saved_name = os.path.join(temp_dir, basename, '{}_seg.nii.gz'.format(basename)) + testing_shape = nib.load(saved_name).get_fdata().shape + + if os.path.exists(img_name): + os.remove(img_name) + if os.path.exists(seg_name): + os.remove(seg_name) + + return testing_shape == input_shape + + +if __name__ == "__main__": + result = run_test() + + sys.exit(0 if result else 1) diff --git a/tests/utils.py b/tests/utils.py index 27f75fe7c3..b2e4e6d743 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -28,11 +28,13 @@ def skip_if_quick(obj): return unittest.skipIf(is_quick, "Skipping slow tests")(obj) -def make_nifti_image(array, affine): +def make_nifti_image(array, affine=None): """ Create a temporary nifti image on the disk and return the image name. User is responsible for deleting the temporary file when done with it. """ + if affine is None: + affine = np.eye(4) test_image = nib.Nifti1Image(array, affine) _, image_name = tempfile.mkstemp(suffix='.nii.gz') From 9863fe251231ed147cadaf9156d2af1dfbce76cc Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 19 Feb 2020 08:27:49 +0000 Subject: [PATCH 4/4] updating mean dice handler tests and docstrings --- monai/handlers/mean_dice.py | 8 +++++--- monai/metrics/compute_meandice.py | 7 +++++-- tests/integration_sliding_window.py | 5 ++--- tests/test_handler_mean_dice.py | 24 ++++++++++++++++++++++++ 4 files changed, 36 insertions(+), 8 deletions(-) diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py index 0925d3d261..423137fc46 100644 --- a/monai/handlers/mean_dice.py +++ b/monai/handlers/mean_dice.py @@ -38,11 +38,13 @@ def __init__( Args: include_background (Bool): whether to include dice computation on the first channel of the predicted output. - to_onehot_y (Bool): whether to convert the output prediction into the one-hot format. - logit_thresh (Float): the threshold value to round value to 0.0 and 1.0. + Defaults to True. + to_onehot_y (Bool): whether to convert the output prediction into the one-hot format. Defaults to True. + logit_thresh (Float): the threshold value to round value to 0.0 and 1.0. Defaults to None (no thresholding). add_sigmoid (Bool): whether to add sigmoid function to the output prediction before computing Dice. + Defaults to False. mutually_exclusive (Bool): if True, the output prediction will be converted into a binary matrix using - a combination of argmax and to_onehot. + a combination of argmax and to_onehot. Defaults to True. output_transform (Callable): transform the ignite.engine.state.output into [y_pred, y] pair. device (torch.device): device specification in case of distributed computation usage. diff --git a/monai/metrics/compute_meandice.py b/monai/metrics/compute_meandice.py index 978738af7b..66aba387c1 100644 --- a/monai/metrics/compute_meandice.py +++ b/monai/metrics/compute_meandice.py @@ -46,8 +46,8 @@ def compute_meandice(y_pred, n_channels_y_pred = y_pred.shape[1] if mutually_exclusive: - if logit_thresh is not None: - raise ValueError('`logit_thresh` is incompatible when mutually_exclusive is True.') + if logit_thresh is not None or add_sigmoid: + raise ValueError('`logit_thresh` and `add_sigmoid` are incompatible when mutually_exclusive is True.') y_pred = torch.argmax(y_pred, dim=1, keepdim=True) y_pred = one_hot(y_pred, n_channels_y_pred) else: # channel-wise thresholding @@ -63,6 +63,9 @@ def compute_meandice(y_pred, y = y[:, 1:] if y.shape[1] > 1 else y y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred + assert y.shape == y_pred.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" % + (y.shape, y_pred.shape)) + # reducing only spatial dimensions (not batch nor channels) reduce_axis = list(range(2, y_pred.dim())) intersection = torch.sum(y * y_pred, reduce_axis) diff --git a/tests/integration_sliding_window.py b/tests/integration_sliding_window.py index 8857c89bf7..32abbfff5c 100644 --- a/tests/integration_sliding_window.py +++ b/tests/integration_sliding_window.py @@ -28,7 +28,7 @@ from tests.utils import make_nifti_image -def run_test(batch_size=64, train_steps=100, device=torch.device("cuda:0")): +def run_test(batch_size=2, device=torch.device("cpu:0")): im, seg = create_test_image_3d(25, 28, 63, rad_max=10, noise_max=1, num_objs=4, num_seg_classes=1) input_shape = im.shape @@ -46,8 +46,7 @@ def run_test(batch_size=64, train_steps=100, device=torch.device("cuda:0")): num_res_units=2, ) roi_size = (16, 32, 48) - sw_batch_size = 2 - device = torch.device('cpu:0') + sw_batch_size = batch_size def _sliding_window_processor(_engine, batch): net.eval() diff --git a/tests/test_handler_mean_dice.py b/tests/test_handler_mean_dice.py index 380a81fa03..dfce516eec 100644 --- a/tests/test_handler_mean_dice.py +++ b/tests/test_handler_mean_dice.py @@ -19,6 +19,8 @@ TEST_CASE_1 = [{'to_onehot_y': True, 'mutually_exclusive': True}, 0.75] TEST_CASE_2 = [{'include_background': False, 'to_onehot_y': False, 'mutually_exclusive': False}, 0.8333333] +TEST_CASE_3 = [{'mutually_exclusive': True, 'add_sigmoid': True}] + class TestHandlerMeanDice(unittest.TestCase): # TODO test multi node averaged dice @@ -38,6 +40,28 @@ def test_compute(self, input_params, expected_avg): avg_dice = dice_metric.compute() self.assertAlmostEqual(avg_dice, expected_avg) + @parameterized.expand([TEST_CASE_3]) + def test_misconfig(self, input_params): + with self.assertRaisesRegex(ValueError, 'compatib'): + dice_metric = MeanDice(**input_params) + + y_pred = torch.Tensor([[0, 1], [1, 0]]) + y = torch.ones((2, 1)) + dice_metric.update([y_pred, y]) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_shape_mismatch(self, input_params, _expected): + dice_metric = MeanDice(**input_params) + with self.assertRaises((AssertionError, ValueError)): + y_pred = torch.Tensor([[0, 1], [1, 0]]) + y = torch.ones((2, 3)) + dice_metric.update([y_pred, y]) + + with self.assertRaises((AssertionError, ValueError)): + y_pred = torch.Tensor([[0, 1], [1, 0]]) + y = torch.ones((3, 2)) + dice_metric.update([y_pred, y]) + if __name__ == '__main__': unittest.main()