diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py
index 73ac5e4173..423137fc46 100644
--- a/monai/handlers/mean_dice.py
+++ b/monai/handlers/mean_dice.py
@@ -27,10 +27,10 @@ class MeanDice(Metric):
     def __init__(
         self,
         include_background=True,
-        to_onehot_y=False,
-        logit_thresh=0.5,
+        to_onehot_y=True,
+        logit_thresh=None,
         add_sigmoid=False,
-        mutually_exclusive=False,
+        mutually_exclusive=True,
         output_transform: Callable = lambda x: x,
         device: Optional[Union[str, torch.device]] = None
     ):
@@ -38,13 +38,18 @@ def __init__(
         Args:
             include_background (Bool): whether to include dice computation on the first channel of the predicted
                 output.
-            to_onehot_y (Bool): whether to convert the output prediction into the one-hot format.
-            logit_thresh (Float): the threshold value to round value to 0.0 and 1.0, default is 0.5.
+                Defaults to True.
+            to_onehot_y (Bool): whether to convert the output prediction into the one-hot format. Defaults to True.
+            logit_thresh (Float): the threshold value used to round values to 0.0 and 1.0. Defaults to None (no thresholding).
             add_sigmoid (Bool): whether to add sigmoid function to the output prediction before computing Dice.
+                Defaults to False.
             mutually_exclusive (Bool): if True, the output prediction will be converted into a binary matrix using
-                a combination of argmax and to_onehot.
+                a combination of argmax and to_onehot. Defaults to True.
             output_transform (Callable): transform the ignite.engine.state.output into [y_pred, y] pair.
             device (torch.device): device specification in case of distributed computation usage.
+
+        See also:
+            monai.metrics.compute_meandice.compute_meandice
        """
         super(MeanDice, self).__init__(output_transform, device=device)
         self.include_background = include_background
diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py
index 31e6881678..a87e517f81 100644
--- a/monai/handlers/segmentation_saver.py
+++ b/monai/handlers/segmentation_saver.py
@@ -23,7 +23,7 @@ class SegmentationSaver:
     """

     def __init__(self, output_path='./', dtype='float32', output_postfix='seg', output_ext='.nii.gz',
-                 output_transform=lambda x: x):
+                 output_transform=lambda x: x, name=None):
         """
         Args:
             output_path (str): output image directory.
@@ -34,6 +34,7 @@ def __init__(self, output_path='./', dtype='float32', output_postfix='seg', outp
                 ignite.engine.output into the form expected nifti image data.
                 The first dimension of this transform's output will be treated as the batch dimension.
                 Each item in the batch will be saved individually.
+            name (str): identifier of logging.logger to use, defaulting to `engine.logger`.
         """
         self.output_path = output_path
         self.dtype = dtype
@@ -41,7 +42,11 @@ def __init__(self, output_path='./', dtype='float32', output_postfix='seg', outp
         self.output_ext = output_ext
         self.output_transform = output_transform

+        self.logger = None if name is None else logging.getLogger(name)
+
     def attach(self, engine):
+        if self.logger is None:
+            self.logger = engine.logger
         return engine.add_event_handler(Events.ITERATION_COMPLETED, self)

     @staticmethod
@@ -103,4 +108,4 @@ def __call__(self, engine):
         output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path)
         output_filename = '{}{}'.format(output_filename, self.output_ext)
         write_nifti(seg_output, _affine, output_filename, _original_affine, dtype=seg_output.dtype)
-        print('saved: {}'.format(output_filename))
+        self.logger.info('saved: {}'.format(output_filename))
diff --git a/monai/metrics/compute_meandice.py b/monai/metrics/compute_meandice.py
index 4a10c03693..66aba387c1 100644
--- a/monai/metrics/compute_meandice.py
+++ b/monai/metrics/compute_meandice.py
@@ -27,6 +27,8 @@ def compute_meandice(y_pred,
         y_pred (torch.Tensor): input data to compute, typical segmentation model output.
             it must be One-Hot format and first dim is batch, example shape: [16, 3, 32, 32].
         y (torch.Tensor): ground truth to compute mean dice metric, the first dim is batch.
+            example shape: [16, 3, 32, 32] for 3-class one-hot labels.
+            alternative shape: [16, 1, 32, 32] and set `to_onehot_y=True` to convert it into [16, 3, 32, 32].
         include_background (Bool): whether to skip dice computation on the first channel of the predicted output.
         to_onehot_y (Bool): whether to convert `y` into the one-hot format.
         mutually_exclusive (Bool): if True, `y_pred` will be converted into a binary matrix using
@@ -44,8 +46,8 @@ def compute_meandice(y_pred,
     n_channels_y_pred = y_pred.shape[1]

     if mutually_exclusive:
-        if logit_thresh is not None:
-            raise ValueError('`logit_thresh` is incompatible when mutually_exclusive is True.')
+        if logit_thresh is not None or add_sigmoid:
+            raise ValueError('`logit_thresh` and `add_sigmoid` are incompatible when mutually_exclusive is True.')
         y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
         y_pred = one_hot(y_pred, n_channels_y_pred)
     else:  # channel-wise thresholding
@@ -61,6 +63,9 @@ def compute_meandice(y_pred,
     y = y[:, 1:] if y.shape[1] > 1 else y
     y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred

+    assert y.shape == y_pred.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" %
+                                     (y.shape, y_pred.shape))
+
     # reducing only spatial dimensions (not batch nor channels)
     reduce_axis = list(range(2, y_pred.dim()))
     intersection = torch.sum(y * y_pred, reduce_axis)
diff --git a/tests/integration_sliding_window.py b/tests/integration_sliding_window.py
new file mode 100644
index 0000000000..32abbfff5c
--- /dev/null
+++ b/tests/integration_sliding_window.py
@@ -0,0 +1,80 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import tempfile
+
+import nibabel as nib
+import torch
+from ignite.engine import Engine
+from torch.utils.data import DataLoader
+
+from monai.data.nifti_reader import NiftiDataset
+from monai.data.synthetic import create_test_image_3d
+from monai.handlers.segmentation_saver import SegmentationSaver
+from monai.networks.nets.unet import UNet
+from monai.networks.utils import predict_segmentation
+from monai.transforms.transforms import AddChannel
+from monai.utils.sliding_window_inference import sliding_window_inference
+from tests.utils import make_nifti_image
+
+
+def run_test(batch_size=2, device=torch.device("cpu:0")):
+
+    im, seg = create_test_image_3d(25, 28, 63, rad_max=10, noise_max=1, num_objs=4, num_seg_classes=1)
+    input_shape = im.shape
+    img_name = make_nifti_image(im)
+    seg_name = make_nifti_image(seg)
+    ds = NiftiDataset([img_name], [seg_name], transform=AddChannel(), seg_transform=AddChannel(), image_only=False)
+    loader = DataLoader(ds, batch_size=1, pin_memory=torch.cuda.is_available())
+
+    net = UNet(
+        dimensions=3,
+        in_channels=1,
+        num_classes=1,
+        channels=(4, 8, 16, 32),
+        strides=(2, 2, 2),
+        num_res_units=2,
+    )
+    roi_size = (16, 32, 48)
+    sw_batch_size = batch_size
+
+    def _sliding_window_processor(_engine, batch):
+        net.eval()
+        img, seg, meta_data = batch
+        with torch.no_grad():
+            seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x)[0], device)
+            return predict_segmentation(seg_probs)
+
+    infer_engine = Engine(_sliding_window_processor)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        SegmentationSaver(output_path=temp_dir, output_ext='.nii.gz', output_postfix='seg').attach(infer_engine)
+
+        infer_engine.run(loader)
+
+        basename = os.path.basename(img_name)[:-len('.nii.gz')]
+        saved_name = os.path.join(temp_dir, basename, '{}_seg.nii.gz'.format(basename))
+        testing_shape = nib.load(saved_name).get_fdata().shape
+
+    if os.path.exists(img_name):
+        os.remove(img_name)
+    if os.path.exists(seg_name):
+        os.remove(seg_name)
+
+    return testing_shape == input_shape
+
+
+if __name__ == "__main__":
+    result = run_test()
+
+    sys.exit(0 if result else 1)
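For reference, a standalone sketch of what this integration test exercises (not part of the diff; it reuses the positional `sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device)` call from the test, with a trivial stand-in predictor instead of a trained UNet):

```python
import torch

from monai.utils.sliding_window_inference import sliding_window_inference


def predictor(x):
    # stand-in for `lambda x: net(x)[0]`: any callable returning a tensor for each window
    return torch.sigmoid(-x)


image = torch.rand(1, 1, 25, 28, 63)  # NCDHW volume, deliberately not a multiple of the ROI
output = sliding_window_inference(image, (16, 32, 48), 2, predictor, torch.device('cpu:0'))
print(output.shape)  # the property the test asserts: spatial shape matches the input volume
```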
diff --git a/tests/test_handler_mean_dice.py b/tests/test_handler_mean_dice.py
new file mode 100644
index 0000000000..dfce516eec
--- /dev/null
+++ b/tests/test_handler_mean_dice.py
@@ -0,0 +1,67 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.handlers.mean_dice import MeanDice
+
+TEST_CASE_1 = [{'to_onehot_y': True, 'mutually_exclusive': True}, 0.75]
+TEST_CASE_2 = [{'include_background': False, 'to_onehot_y': False, 'mutually_exclusive': False}, 0.8333333]
+
+TEST_CASE_3 = [{'mutually_exclusive': True, 'add_sigmoid': True}]
+
+
+class TestHandlerMeanDice(unittest.TestCase):
+    # TODO test multi node averaged dice
+
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+    def test_compute(self, input_params, expected_avg):
+        dice_metric = MeanDice(**input_params)
+
+        y_pred = torch.Tensor([[0, 1], [1, 0]])
+        y = torch.ones((2, 1))
+        dice_metric.update([y_pred, y])
+
+        y_pred = torch.Tensor([[0, 1], [1, 0]])
+        y = torch.Tensor([[1.], [0.]])
+        dice_metric.update([y_pred, y])
+
+        avg_dice = dice_metric.compute()
+        self.assertAlmostEqual(avg_dice, expected_avg)
+
+    @parameterized.expand([TEST_CASE_3])
+    def test_misconfig(self, input_params):
+        with self.assertRaisesRegex(ValueError, 'compatib'):
+            dice_metric = MeanDice(**input_params)
+
+            y_pred = torch.Tensor([[0, 1], [1, 0]])
+            y = torch.ones((2, 1))
+            dice_metric.update([y_pred, y])
+
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+    def test_shape_mismatch(self, input_params, _expected):
+        dice_metric = MeanDice(**input_params)
+        with self.assertRaises((AssertionError, ValueError)):
+            y_pred = torch.Tensor([[0, 1], [1, 0]])
+            y = torch.ones((2, 3))
+            dice_metric.update([y_pred, y])
+
+        with self.assertRaises((AssertionError, ValueError)):
+            y_pred = torch.Tensor([[0, 1], [1, 0]])
+            y = torch.ones((3, 2))
+            dice_metric.update([y_pred, y])
+
+
+if __name__ == '__main__':
+    unittest.main()
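The expected average of 0.75 in `TEST_CASE_1` can be reproduced by hand, assuming (consistently with the expected value) that a class which is empty in both the prediction and the ground truth scores a Dice of 1.0:

```python
# update 1: labels y = [1, 1]; per-sample argmax(y_pred) = [1, 0]
#   sample 0: one-hot prediction [0, 1] vs ground truth [0, 1] -> per-class dice (1.0, 1.0)
#   sample 1: one-hot prediction [1, 0] vs ground truth [0, 1] -> per-class dice (0.0, 0.0)
update_1 = (1.0 + 1.0 + 0.0 + 0.0) / 4  # 0.5

# update 2: labels y = [1, 0]; per-sample argmax(y_pred) = [1, 0] -> every class matches
update_2 = (1.0 + 1.0 + 1.0 + 1.0) / 4  # 1.0

print((update_1 + update_2) / 2)  # 0.75
```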
diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py
new file mode 100644
index 0000000000..5bbe17d1c2
--- /dev/null
+++ b/tests/test_handler_stats.py
@@ -0,0 +1,61 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+import unittest
+from io import StringIO
+
+from ignite.engine import Engine, Events
+
+from monai.handlers.stats_handler import StatsHandler
+
+
+class TestHandlerStats(unittest.TestCase):
+
+    def test_metrics_print(self):
+        log_stream = StringIO()
+        logging.basicConfig(stream=log_stream, level=logging.INFO)
+        key_to_handler = 'test_logging'
+        key_to_print = 'testing_metric'
+
+        # set up engine
+        def _train_func(engine, batch):
+            pass
+
+        engine = Engine(_train_func)
+
+        # set up dummy metric
+        @engine.on(Events.ITERATION_COMPLETED)
+        def _update_metric(engine):
+            current_metric = engine.state.metrics.get(key_to_print, 0.1)
+            engine.state.metrics[key_to_print] = current_metric + 0.1
+
+        # set up testing handler
+        stats_handler = StatsHandler(name=key_to_handler)
+        stats_handler.attach(engine)
+
+        engine.run(range(3), max_epochs=2)
+
+        # check logging output
+        output_str = log_stream.getvalue()
+        grep = re.compile('.*{}.*'.format(key_to_handler))
+        has_key_word = re.compile('.*{}.*'.format(key_to_print))
+        matched = []
+        for idx, line in enumerate(output_str.split('\n')):
+            if grep.match(line):
+                self.assertTrue(has_key_word.match(line))
+                matched.append(idx)
+        self.assertEqual(matched, [1, 2, 3, 5, 6, 7, 8, 10])
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/utils.py b/tests/utils.py
index 27f75fe7c3..b2e4e6d743 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -28,11 +28,13 @@ def skip_if_quick(obj):
     return unittest.skipIf(is_quick, "Skipping slow tests")(obj)


-def make_nifti_image(array, affine):
+def make_nifti_image(array, affine=None):
     """
     Create a temporary nifti image on the disk and return the image name.
     User is responsible for deleting the temporary file when done with it.
     """
+    if affine is None:
+        affine = np.eye(4)
     test_image = nib.Nifti1Image(array, affine)
     _, image_name = tempfile.mkstemp(suffix='.nii.gz')
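With the default affine in place, the helper can now be called with just an array; a brief sketch of the new call pattern (cleanup stays with the caller, as the docstring notes):

```python
import os

import numpy as np

from tests.utils import make_nifti_image

image_name = make_nifti_image(np.zeros((8, 8, 8), dtype=np.float32))  # affine defaults to np.eye(4)
# ... hand the temporary .nii.gz file to a reader ...
os.remove(image_name)  # caller is responsible for deleting the temporary file
```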