From 9b73bf6aea76a3c8e961fff2783ef2adf700440e Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Wed, 11 Mar 2020 20:07:51 +0800
Subject: [PATCH 1/7] [DLMED] add 3D classification inference examples

---
 ...py => densenet_classification_3d_array.py} |  37 ++---
 examples/densenet_classification_3d_dict.py   | 147 ++++++++++++++++++
 examples/densenet_inference_3d_array.py       |  93 +++++++++++
 examples/densenet_inference_3d_dict.py        |  95 +++++++++++
 examples/unet_inference_3d_array.py           |   6 +-
 examples/unet_inference_3d_dict.py            |   2 +-
 examples/unet_segmentation_3d_array.py        |  16 +-
 examples/unet_segmentation_3d_dict.py         |  12 +-
 monai/handlers/classification_saver.py        |  91 +++++++++++
 monai/transforms/composables.py               |  59 ++++++-
 monai/transforms/transforms.py                |  28 ++--
 tests/test_resize.py                          |  21 +--
 12 files changed, 545 insertions(+), 62 deletions(-)
 rename examples/{densenet_classification_3d.py => densenet_classification_3d_array.py} (82%)
 create mode 100644 examples/densenet_classification_3d_dict.py
 create mode 100644 examples/densenet_inference_3d_array.py
 create mode 100644 examples/densenet_inference_3d_dict.py
 create mode 100644 monai/handlers/classification_saver.py

diff --git a/examples/densenet_classification_3d.py b/examples/densenet_classification_3d_array.py
similarity index 82%
rename from examples/densenet_classification_3d.py
rename to examples/densenet_classification_3d_array.py
index 07ac3ffe04..10cc0b2689 100644
--- a/examples/densenet_classification_3d.py
+++ b/examples/densenet_classification_3d_array.py
@@ -21,16 +21,15 @@
 sys.path.append("..")
 import monai
 import monai.transforms.compose as transforms
-
 from monai.data.nifti_reader import NiftiDataset
-from monai.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch)
+from monai.transforms import (AddChannel, Rescale, Resize, RandRotate90)
 from monai.handlers.stats_handler import StatsHandler
 from ignite.metrics import Accuracy
 from monai.handlers.utils import stopping_fn_from_metric
 
 monai.config.print_config()
 
-# FIXME: temp test dataset, Wenqi will replace later
+# demo dataset, users can easily switch to their own dataset
 images = [
     "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
     "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
@@ -57,18 +56,23 @@
     0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0
 ])
 
-# Define transforms for image and segmentation
-imtrans = transforms.Compose([
+# Define transforms for image
+train_transforms = transforms.Compose([
     Rescale(),
     AddChannel(),
-    UniformRandomPatch((96, 96, 96)),
-    ToTensor()
+    Resize((96, 96, 96)),
+    RandRotate90()
+])
+val_transforms = transforms.Compose([
+    Rescale(),
+    AddChannel(),
+    Resize((96, 96, 96))
 ])
 
 # Define nifti dataset, dataloader.
-ds = NiftiDataset(image_files=images, labels=labels, transform=imtrans)
-loader = DataLoader(ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
-im, label = monai.utils.misc.first(loader)
+check_ds = NiftiDataset(image_files=images, labels=labels, transform=train_transforms)
+check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
+im, label = monai.utils.misc.first(check_loader)
 print(type(im), im.shape, label)
 
 lr = 1e-5
@@ -84,7 +88,8 @@
 
 # Create trainer
 device = torch.device("cuda:0")
-trainer = create_supervised_trainer(net, opt, loss, device, False)
+trainer = create_supervised_trainer(net, opt, loss, device, False,
+                                    output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])
 
 # adding checkpoint handler to save models (network params and optimizer stats) during training
 checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
@@ -94,10 +99,6 @@
 train_stats_handler = StatsHandler()
 train_stats_handler.attach(trainer)
 
-@trainer.on(Events.EPOCH_COMPLETED)
-def log_training_loss(engine):
-    engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output)
-
 # Set parameters for validation
 validation_every_n_epochs = 1
 metric_name = 'Accuracy'
@@ -118,8 +119,8 @@ def log_training_loss(engine):
 evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
 
 # create a validation data loader
-val_ds = NiftiDataset(image_files=images[-5:], labels=labels[-5:], transform=imtrans)
-val_loader = DataLoader(ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
+val_ds = NiftiDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms)
+val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
 
 
 @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
@@ -129,7 +130,7 @@ def run_validation(engine):
 
 # create a training data loader
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-train_ds = NiftiDataset(image_files=images[:15], labels=labels[:15], transform=imtrans)
+train_ds = NiftiDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms)
 train_loader = DataLoader(train_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
 
 train_epochs = 30
diff --git a/examples/densenet_classification_3d_dict.py b/examples/densenet_classification_3d_dict.py
new file mode 100644
index 0000000000..1be49de05d
--- /dev/null
+++ b/examples/densenet_classification_3d_dict.py
@@ -0,0 +1,147 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import logging
+import numpy as np
+import torch
+from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, _prepare_batch
+from ignite.handlers import ModelCheckpoint, EarlyStopping
+from torch.utils.data import DataLoader
+
+# assumes the framework is found here, change as necessary
+sys.path.append("..")
+import monai
+import monai.transforms.compose as transforms
+from monai.transforms.composables import \
+    LoadNiftid, AddChanneld, Rescaled, Resized, RandRotate90d
+from monai.handlers.stats_handler import StatsHandler
+from ignite.metrics import Accuracy
+from monai.handlers.utils import stopping_fn_from_metric
+
+monai.config.print_config()
+
+# demo dataset, users can easily switch to their own dataset
+images = [
+    "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz"
+]
+labels = np.array([
+    0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0
+])
+train_files = [{'img': img, 'label': label} for img, label in zip(images[:10], labels[:10])]
+val_files = [{'img': img, 'label': label} for img, label in zip(images[-10:], labels[-10:])]
+
+# Define transforms for image
+train_transforms = transforms.Compose([
+    LoadNiftid(keys=['img']),
+    AddChanneld(keys=['img']),
+    Rescaled(keys=['img']),
+    Resized(keys=['img'], output_shape=(96, 96, 96)),
+    RandRotate90d(keys=['img'], prob=0.8, axes=[1, 3])
+])
+val_transforms = transforms.Compose([
+    LoadNiftid(keys=['img']),
+    AddChanneld(keys=['img']),
+    Rescaled(keys=['img']),
+    Resized(keys=['img'], output_shape=(96, 96, 96))
+])
+
+# Define dataset, dataloader.
+check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
+check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
+check_data = monai.utils.misc.first(check_loader)
+print(check_data['img'].shape, check_data['label'])
+
+lr = 1e-5
+
+# Create DenseNet121, CrossEntropyLoss and Adam optimizer.
+net = monai.networks.nets.densenet3d.densenet121(
+    in_channels=1,
+    out_channels=2,
+)
+
+loss = torch.nn.CrossEntropyLoss()
+opt = torch.optim.Adam(net.parameters(), lr)
+
+
+# prepare_batch extracts images and labels from the dict-based batch
+def prepare_batch(batch, device=None, non_blocking=False):
+    return _prepare_batch((batch['img'], batch['label']), device, non_blocking)
+
+
+# Create trainer
+device = torch.device("cuda:0")
+trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch,
+                                    output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])
+
+# adding checkpoint handler to save models (network params and optimizer stats) during training
+checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
+trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
+                          handler=checkpoint_handler,
+                          to_save={'net': net, 'opt': opt})
+train_stats_handler = StatsHandler()
+train_stats_handler.attach(trainer)
+
+# Set parameters for validation
+validation_every_n_epochs = 1
+metric_name = 'Accuracy'
+
+# add evaluation metric to the evaluator engine
+val_metrics = {metric_name: Accuracy()}
+evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)
+
+# Add stats event handler to print validation stats via evaluator
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+val_stats_handler = StatsHandler()
+val_stats_handler.attach(evaluator)
+
+# Add early stopping handler to evaluator.
+early_stopper = EarlyStopping(patience=4,
+                              score_function=stopping_fn_from_metric(metric_name),
+                              trainer=trainer)
+evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
+
+# create a validation data loader
+val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
+val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
+
+
+@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
+def run_validation(engine):
+    evaluator.run(val_loader)
+
+# create a training data loader
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+
+train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
+train_loader = DataLoader(train_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
+
+train_epochs = 30
+state = trainer.run(train_loader, train_epochs)
diff --git a/examples/densenet_inference_3d_array.py b/examples/densenet_inference_3d_array.py
new file mode 100644
index 0000000000..25908bb821
--- /dev/null
+++ b/examples/densenet_inference_3d_array.py
@@ -0,0 +1,93 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import logging
+import numpy as np
+import torch
+from ignite.engine import create_supervised_evaluator, _prepare_batch
+from torch.utils.data import DataLoader
+
+# assumes the framework is found here, change as necessary
+sys.path.append("..")
+import monai
+import monai.transforms.compose as transforms
+from monai.data.nifti_reader import NiftiDataset
+from monai.transforms import (AddChannel, Rescale, Resize)
+from monai.handlers.stats_handler import StatsHandler
+from monai.handlers.classification_saver import ClassificationSaver
+from monai.handlers.checkpoint_loader import CheckpointLoader
+from ignite.metrics import Accuracy
+
+monai.config.print_config()
+
+# demo dataset, users can easily switch to their own dataset
+images = [
+    "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz"
+]
+labels = np.array([
+    0, 0, 1, 0, 1, 0, 1, 0, 1, 0
+])
+
+# Define transforms for image
+val_transforms = transforms.Compose([
+    Rescale(),
+    AddChannel(),
+    Resize((96, 96, 96))
+])
+# Define nifti dataset
+val_ds = NiftiDataset(image_files=images, labels=labels, transform=val_transforms, image_only=False)
+# Create DenseNet121
+net = monai.networks.nets.densenet3d.densenet121(
+    in_channels=1,
+    out_channels=2,
+)
+
+device = torch.device("cuda:0")
+metric_name = 'Accuracy'
+
+# add evaluation metric to the evaluator engine
+val_metrics = {metric_name: Accuracy()}
+
+
+def prepare_batch(batch, device=None, non_blocking=False):
+    return _prepare_batch((batch[0], batch[1]), device, non_blocking)
+
+
+evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)
+
+# Add stats event handler to print validation stats via evaluator
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+val_stats_handler = StatsHandler()
+val_stats_handler.attach(evaluator)
+
+# for the array data format, assume the 3rd item of batch data is the meta_data
+prediction_saver = ClassificationSaver(output_dir='tempdir', batch_transform=lambda batch: batch[2],
+                                       output_transform=lambda output: output[0].argmax(1))
+prediction_saver.attach(evaluator)
+
+# the model was trained by the "densenet_classification_3d_array" example
+CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator)
+
+# create a validation data loader
+val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
+
+state = evaluator.run(val_loader)
+prediction_saver.finalize()
diff --git a/examples/densenet_inference_3d_dict.py b/examples/densenet_inference_3d_dict.py
new file mode 100644
index 0000000000..8f7dc75c72
--- /dev/null
+++ b/examples/densenet_inference_3d_dict.py
@@ -0,0 +1,95 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ignite.metrics import Accuracy
+import sys
+import logging
+import numpy as np
+import torch
+from ignite.engine import create_supervised_evaluator, _prepare_batch
+from torch.utils.data import DataLoader
+
+# assumes the framework is found here, change as necessary
+sys.path.append("..")
+from monai.handlers.classification_saver import ClassificationSaver
+from monai.handlers.checkpoint_loader import CheckpointLoader
+from monai.handlers.stats_handler import StatsHandler
+from monai.transforms.composables import LoadNiftid, AddChanneld, Rescaled, Resized
+import monai.transforms.compose as transforms
+import monai
+
+monai.config.print_config()
+
+# demo dataset, users can easily switch to their own dataset
+images = [
+    "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz"
+]
+labels = np.array([
+    0, 0, 1, 0, 1, 0, 1, 0, 1, 0
+])
+val_files = [{'img': img, 'label': label} for img, label in zip(images, labels)]
+
+# Define transforms for image
+val_transforms = transforms.Compose([
+    LoadNiftid(keys=['img']),
+    AddChanneld(keys=['img']),
+    Rescaled(keys=['img']),
+    Resized(keys=['img'], output_shape=(96, 96, 96))
+])
+
+# Create DenseNet121
+net = monai.networks.nets.densenet3d.densenet121(
+    in_channels=1,
+    out_channels=2,
+)
+
+
+def prepare_batch(batch, device=None, non_blocking=False):
+    return _prepare_batch((batch['img'], batch['label']), device, non_blocking)
+
+
+# Create evaluator
+device = torch.device("cuda:0")
+metric_name = 'Accuracy'
+
+# add evaluation metric to the evaluator engine
+val_metrics = {metric_name: Accuracy()}
+evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)
+
+# Add stats event handler to print validation stats via evaluator
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+val_stats_handler = StatsHandler()
+val_stats_handler.attach(evaluator)
+
+# for the dict data format, extract the meta_data of the image from the batch dict
+prediction_saver = ClassificationSaver(output_dir='tempdir', batch_transform=lambda batch: {
+                                           'filename_or_obj': batch['img.filename_or_obj']},
+                                       output_transform=lambda output: output[0].argmax(1))
+prediction_saver.attach(evaluator)
+
+# the model was trained by the "densenet_classification_3d_dict" example
+CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator)
+
+# create a validation data loader
+val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
+val_loader = DataLoader(val_ds, batch_size=2, num_workers=4,
+                        pin_memory=torch.cuda.is_available())
+
+state = evaluator.run(val_loader)
+prediction_saver.finalize()
diff --git a/examples/unet_inference_3d_array.py b/examples/unet_inference_3d_array.py
index 8fe417c7dd..ea87590285 100644
--- a/examples/unet_inference_3d_array.py
+++ b/examples/unet_inference_3d_array.py
@@ -28,7 +28,7 @@
 from monai.handlers.segmentation_saver import SegmentationSaver
 import monai.transforms.compose as transforms
 from monai.data.nifti_reader import NiftiDataset
-from monai.transforms import AddChannel, Rescale, ToTensor
+from monai.transforms import AddChannel, Rescale
 from monai.networks.nets.unet import UNet
 from monai.networks.utils import predict_segmentation
 from monai.data.synthetic import create_test_image_3d
@@ -50,8 +50,8 @@
 images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
 segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
 
-imtrans = transforms.Compose([Rescale(), AddChannel(), ToTensor()])
-segtrans = transforms.Compose([AddChannel(), ToTensor()])
+imtrans = transforms.Compose([Rescale(), AddChannel()])
+segtrans = transforms.Compose([AddChannel()])
 ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
 
 device = torch.device("cuda:0")
diff --git a/examples/unet_inference_3d_dict.py b/examples/unet_inference_3d_dict.py
index 405b49aa8d..7274453b22 100644
--- a/examples/unet_inference_3d_dict.py
+++ b/examples/unet_inference_3d_dict.py
@@ -87,7 +87,7 @@ def _sliding_window_processor(engine, batch):
         'original_affine': batch['img.original_affine'],
         'affine': batch['img.affine'],
     }).attach(infer_engine)
-# the model was trained by "unet_segmentation_3d_array" exmple
+# the model was trained by the "unet_segmentation_3d_dict" example
 CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(infer_engine)
 
 val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate,
diff --git a/examples/unet_segmentation_3d_array.py b/examples/unet_segmentation_3d_array.py
index de5996470d..3e1a819e79 100644
--- a/examples/unet_segmentation_3d_array.py
+++ b/examples/unet_segmentation_3d_array.py
@@ -30,7 +30,7 @@
 
 import monai.transforms.compose as transforms
 from monai.data.nifti_reader import NiftiDataset
-from monai.transforms import AddChannel, Rescale, ToTensor, UniformRandomPatch
+from monai.transforms import AddChannel, Rescale, UniformRandomPatch
 from monai.handlers.stats_handler import StatsHandler
 from monai.handlers.mean_dice import MeanDice
 from monai.visualize import img2tensorboard
@@ -58,19 +58,17 @@
 imtrans = transforms.Compose([
     Rescale(),
     AddChannel(),
-    UniformRandomPatch((96, 96, 96)),
-    ToTensor()
+    UniformRandomPatch((96, 96, 96))
 ])
 segtrans = transforms.Compose([
     AddChannel(),
-    UniformRandomPatch((96, 96, 96)),
-    ToTensor()
+    UniformRandomPatch((96, 96, 96))
 ])
 
 # Define nifti dataset, dataloader.
-ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans)
-loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
-im, seg = monai.utils.misc.first(loader)
+check_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans)
+check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
+im, seg = monai.utils.misc.first(check_loader)
 print(im.shape, seg.shape)
 
 lr = 1e-5
@@ -159,7 +157,7 @@ def log_training_loss(engine):
 
 # create a validation data loader
 val_ds = NiftiDataset(images[-20:], segs[-20:], transform=imtrans, seg_transform=segtrans)
-val_loader = DataLoader(ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
+val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
 
 
 @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py
index 640bed21c0..8eb2a76c73 100644
--- a/examples/unet_segmentation_3d_dict.py
+++ b/examples/unet_segmentation_3d_dict.py
@@ -68,11 +68,11 @@
     AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1)
 ])
 
-# Define nifti dataset, dataloader.
-ds = monai.data.Dataset(data=train_files, transform=train_transforms)
-loader = DataLoader(ds, batch_size=2, num_workers=4, collate_fn=list_data_collate,
-                    pin_memory=torch.cuda.is_available())
-check_data = monai.utils.misc.first(loader)
+# Define dataset, dataloader.
+check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
+check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate,
+                          pin_memory=torch.cuda.is_available())
+check_data = monai.utils.misc.first(check_loader)
 print(check_data['img'].shape, check_data['seg'].shape)
 
 lr = 1e-5
@@ -171,7 +171,7 @@ def log_training_loss(engine):
 
 # create a validation data loader
 val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
-val_loader = DataLoader(ds, batch_size=5, num_workers=8, collate_fn=list_data_collate,
+val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, collate_fn=list_data_collate,
                         pin_memory=torch.cuda.is_available())
 
 
diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py
new file mode 100644
index 0000000000..48f0ac59bb
--- /dev/null
+++ b/monai/handlers/classification_saver.py
@@ -0,0 +1,91 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import csv
+import logging
+import numpy as np
+import torch
+from ignite.engine import Events
+
+
+class ClassificationSaver:
+    """
+    Event handler triggered on completing every iteration to save the classification predictions as a CSV file.
+    """
+
+    def __init__(self, output_dir='./', overwrite=True,
+                 batch_transform=lambda x: x, output_transform=lambda x: x, name=None):
+        """
+        Args:
+            output_dir (str): output CSV file directory.
+            overwrite (bool): whether to overwrite the existing CSV file content.
+                If we are not overwriting, then we check if the results have been
+                previously saved, and load them into the prediction_dict.
+            batch_transform (Callable): a callable that is used to transform the
+                ignite.engine.batch into the expected format to extract the meta_data dictionary.
+            output_transform (Callable): a callable that is used to transform the
+                ignite.engine.output into the expected model prediction data.
+                The first dimension of this transform's output will be treated as the
+                batch dimension. Each item in the batch will be saved individually.
+            name (str): identifier of logging.logger to use, defaulting to `engine.logger`.
+
+        """
+        self.output_dir = output_dir
+        self._prediction_dict = {}
+        self._preds_filepath = os.path.join(output_dir, 'predictions.csv')
+        if not overwrite:
+            if os.path.exists(self._preds_filepath):
+                with open(self._preds_filepath, 'r') as f:
+                    reader = csv.reader(f)
+                    for row in reader:
+                        self._prediction_dict[row[0]] = np.array(row[1:]).astype(np.float32)
+
+        self.batch_transform = batch_transform
+        self.output_transform = output_transform
+
+        self.logger = None if name is None else logging.getLogger(name)
+
+    def attach(self, engine):
+        if self.logger is None:
+            self.logger = engine.logger
+        return engine.add_event_handler(Events.ITERATION_COMPLETED, self)
+
+    def finalize(self):
+        """
+        Writes the prediction dict to a CSV file.
+
+        """
+        if not os.path.exists(self.output_dir):
+            os.makedirs(self.output_dir)
+        with open(self._preds_filepath, 'w') as f:
+            for k, v in sorted(self._prediction_dict.items()):
+                f.write(k)
+                for result in v.flatten():
+                    f.write("," + str(result))
+                f.write("\n")
+        self.logger.info('saved classification predictions into: {}'.format(self._preds_filepath))
+
+    def __call__(self, engine):
+        """
+        This method assumes self.batch_transform will extract metadata from the input batch.
+        Metadata should have the following keys:
+
+            - ``'filename_or_obj'`` -- save the prediction corresponding to file name.
+
+        """
+        meta_data = self.batch_transform(engine.state.batch)
+        filenames = meta_data['filename_or_obj']
+
+        engine_output = self.output_transform(engine.state.output)
+        for batch_id, filename in enumerate(filenames):  # save a batch of files
+            output = engine_output[batch_id]
+            if isinstance(output, torch.Tensor):
+                output = output.detach().cpu().numpy()
+            self._prediction_dict[filename] = output
diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py
index 0e404c5233..8422077410 100644
--- a/monai/transforms/composables.py
+++ b/monai/transforms/composables.py
@@ -14,12 +14,12 @@
 """
 
 from collections.abc import Hashable
-
+import numpy as np
 import monai
 from monai.data.utils import get_random_patch, get_valid_patch_size
 from monai.transforms.compose import Randomizable, Transform
-from monai.transforms.transforms import (LoadNifti, AsChannelFirst, Orientation,
-                                         AddChannel, Spacing, Rotate90, SpatialCrop)
+from monai.transforms.transforms import (LoadNifti, AsChannelFirst, Orientation, AddChannel,
+                                         Spacing, Rotate90, Rescale, Resize, SpatialCrop)
 from monai.utils.misc import ensure_tuple
 from monai.transforms.utils import generate_pos_neg_label_crop_centers
 from monai.utils.aliases import alias
@@ -245,6 +245,59 @@ def __call__(self, data):
         return d
 
 
+@export
+@alias('RescaleD', 'RescaleDict')
+class Rescaled(MapTransform):
+    """
+    dictionary-based wrapper of Rescale.
+    """
+
+    def __init__(self, keys, minv=0.0, maxv=1.0, dtype=np.float32):
+        MapTransform.__init__(self, keys)
+        self.rescaler = Rescale(minv, maxv, dtype)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.keys:
+            d[key] = self.rescaler(d[key])
+        return d
+
+
+@export
+@alias('ResizeD', 'ResizeDict')
+class Resized(MapTransform):
+    """
+    dictionary-based wrapper of Resize.
+
+    Args:
+        keys (hashable items): keys of the corresponding items to be transformed.
+            See also: monai.transform.composables.MapTransform
+        output_shape (tuple or list): expected shape after resize operation.
+        order (int): Order of spline interpolation. Default=1.
+        mode (str): Points outside boundaries are filled according to given mode.
+            Options are 'constant', 'edge', 'symmetric', 'reflect', 'wrap'.
+        cval (float): Used with mode 'constant', the value outside image boundaries.
+        clip (bool): Whether to clip range of output values after interpolation. Default: True.
+        preserve_range (bool): Whether to keep original range of values. Default is True.
+            If False, input is converted according to conventions of img_as_float. See
+            https://scikit-image.org/docs/dev/user_guide/data_types.html.
+        anti_aliasing (bool): Whether to apply a gaussian filter to image before down-scaling.
+            Default is True.
+        anti_aliasing_sigma (float, tuple of floats): Standard deviation for gaussian filtering.
+    """
+
+    def __init__(self, keys, output_shape, order=1, mode='reflect', cval=0,
+                 clip=True, preserve_range=True, anti_aliasing=True, anti_aliasing_sigma=None):
+        MapTransform.__init__(self, keys)
+        self.resizer = Resize(output_shape, order, mode, cval, clip, preserve_range,
+                              anti_aliasing, anti_aliasing_sigma)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.keys:
+            d[key] = self.resizer(d[key])
+        return d
+
+
 @export
 @alias('UniformRandomPatchD', 'UniformRandomPatchDict')
 class UniformRandomPatchd(Randomizable, MapTransform):
diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py
index a2352e5db8..06ce58d9fb 100644
--- a/monai/transforms/transforms.py
+++ b/monai/transforms/transforms.py
@@ -290,6 +290,7 @@ class Resize:
     For additional details, see https://scikit-image.org/docs/dev/api/skimage.transform.html#skimage.transform.resize.
 
     Args:
+        output_shape (tuple or list): expected shape after resize operation.
         order (int): Order of spline interpolation. Default=1.
         mode (str): Points outside boundaries are filled according to given mode.
            Options are 'constant', 'edge', 'symmetric', 'reflect', 'wrap'.
@@ -316,11 +317,16 @@
         self.anti_aliasing_sigma = anti_aliasing_sigma
 
     def __call__(self, img):
-        return resize(img, self.output_shape, order=self.order,
-                      mode=self.mode, cval=self.cval,
-                      clip=self.clip, preserve_range=self.preserve_range,
-                      anti_aliasing=self.anti_aliasing,
-                      anti_aliasing_sigma=self.anti_aliasing_sigma)
+        resized = list()
+        for channel in img:
+            resized.append(
+                resize(channel, self.output_shape, order=self.order,
+                       mode=self.mode, cval=self.cval,
+                       clip=self.clip, preserve_range=self.preserve_range,
+                       anti_aliasing=self.anti_aliasing,
+                       anti_aliasing_sigma=self.anti_aliasing_sigma)
+            )
+        return np.stack(resized).astype(np.float32)
 
 
 @export
@@ -354,7 +360,7 @@ def __init__(self, angle, axes=(1, 2), reshape=True, order=1, mode='constant', c
     def __call__(self, img):
         return scipy.ndimage.rotate(img, self.angle, self.axes, reshape=self.reshape,
                                     order=self.order, mode=self.mode, cval=self.cval,
-                                    prefilter=self.prefilter)
+                                    prefilter=self.prefilter).astype(np.float32)
 
 
 @export
@@ -421,7 +427,7 @@ def __call__(self, img):
                                  mode=self.mode, cval=self.cval,
                                  prefilter=self.prefilter))
 
-        zoomed = np.stack(zoomed)
+        zoomed = np.stack(zoomed).astype(np.float32)
         if not self.keep_size or np.allclose(img.shape, zoomed.shape):
             return zoomed
 
@@ -482,13 +488,12 @@ class IntensityNormalizer:
         dtype: output data format
     """
 
-    def __init__(self, subtrahend=None, divisor=None, dtype=np.float32):
+    def __init__(self, subtrahend=None, divisor=None):
         if subtrahend is not None or divisor is not None:
             assert isinstance(subtrahend, np.ndarray) and isinstance(divisor, np.ndarray), \
                 'subtrahend and divisor must be set in pair and in numpy array.'
         self.subtrahend = subtrahend
         self.divisor = divisor
-        self.dtype = dtype
 
     def __call__(self, img):
         if self.subtrahend is not None and self.divisor is not None:
@@ -498,8 +503,6 @@ def __call__(self, img):
             img -= np.mean(img)
             img /= np.std(img)
 
-        if self.dtype != img.dtype:
-            img = img.astype(self.dtype)
         return img
 
 
@@ -515,12 +518,11 @@ class ImageEndPadder:
         dtype: output data format.
""" - def __init__(self, out_size, mode, dtype=np.float32): + def __init__(self, out_size, mode): assert out_size is not None and isinstance(out_size, (list, tuple)), 'out_size must be list or tuple' self.out_size = out_size assert isinstance(mode, str), 'mode must be str' self.mode = mode - self.dtype = dtype def _determine_data_pad_width(self, data_shape): return [(0, max(self.out_size[i] - data_shape[i], 0)) for i in range(len(self.out_size))] diff --git a/tests/test_resize.py b/tests/test_resize.py index 7feaf9f634..7a3d327b28 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -30,9 +30,9 @@ def test_invalid_inputs(self, _, order, raises): resize(self.imt) @parameterized.expand([ - ((1, 1, 64, 64), 1, 'reflect', 0, True, True, True, None), - ((1, 1, 32, 32), 2, 'constant', 3, False, False, False, None), - ((1, 1, 256, 256), 3, 'constant', 3, False, False, False, None), + ((64, 64), 1, 'reflect', 0, True, True, True, None), + ((32, 32), 2, 'constant', 3, False, False, False, None), + ((256, 256), 3, 'constant', 3, False, False, False, None), ]) def test_correct_results(self, output_shape, order, mode, cval, clip, preserve_range, @@ -40,12 +40,15 @@ def test_correct_results(self, output_shape, order, mode, resize = Resize(output_shape, order, mode, cval, clip, preserve_range, anti_aliasing, anti_aliasing_sigma) - expected = skimage.transform.resize(self.imt, output_shape, - order=order, mode=mode, - cval=cval, clip=clip, - preserve_range=preserve_range, - anti_aliasing=anti_aliasing, - anti_aliasing_sigma=anti_aliasing_sigma) + expected = list() + for channel in self.imt: + expected.append(skimage.transform.resize(channel, output_shape, + order=order, mode=mode, + cval=cval, clip=clip, + preserve_range=preserve_range, + anti_aliasing=anti_aliasing, + anti_aliasing_sigma=anti_aliasing_sigma)) + expected = np.stack(expected).astype(np.float32) self.assertTrue(np.allclose(resize(self.imt), expected)) From 84e5c06a457e0a1e177994158051f4455cb5dbb1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 11 Mar 2020 22:30:28 +0800 Subject: [PATCH 2/7] [DLMED] change UNet to 1 output --- monai/networks/nets/unet.py | 3 +-- tests/integration_sliding_window.py | 2 +- tests/integration_unet2d.py | 2 +- tests/test_unet.py | 8 ++++---- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index b0d42612eb..6018cd06e8 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -13,7 +13,6 @@ from monai.networks.blocks.convolutions import Convolution, ResidualUnit from monai.networks.layers.simplelayers import SkipConnection -from monai.networks.utils import predict_segmentation from monai.utils import export from monai.utils.aliases import alias @@ -98,4 +97,4 @@ def _get_up_layer(self, in_channels, out_channels, strides, is_top): def forward(self, x): x = self.model(x) - return x, predict_segmentation(x) + return x diff --git a/tests/integration_sliding_window.py b/tests/integration_sliding_window.py index db10d7cc49..8025a4c821 100644 --- a/tests/integration_sliding_window.py +++ b/tests/integration_sliding_window.py @@ -52,7 +52,7 @@ def _sliding_window_processor(_engine, batch): net.eval() img, seg, meta_data = batch with torch.no_grad(): - seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x)[0], device) + seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x), device) return predict_segmentation(seg_probs) infer_engine = 
diff --git a/tests/integration_unet2d.py b/tests/integration_unet2d.py
index 1fd9074c66..ed1afa9872 100644
--- a/tests/integration_unet2d.py
+++ b/tests/integration_unet2d.py
@@ -46,7 +46,7 @@ def __len__(self):
     src = DataLoader(_TestBatch(), batch_size=batch_size)
 
     def loss_fn(pred, grnd):
-        return loss(pred[0], grnd)
+        return loss(pred, grnd)
 
     trainer = create_supervised_trainer(net, opt, loss_fn, device, False)
 
diff --git a/tests/test_unet.py b/tests/test_unet.py
index 98102375a6..d64407bb4b 100644
--- a/tests/test_unet.py
+++ b/tests/test_unet.py
@@ -26,7 +26,7 @@
         'num_res_units': 1,
     },
     torch.randn(16, 1, 32, 32),
-    (16, 32, 32),
+    (16, 3, 32, 32),
 ]
 
 TEST_CASE_2 = [  # single channel 3D, batch 16
@@ -39,7 +39,7 @@
         'num_res_units': 1,
     },
     torch.randn(16, 1, 32, 24, 48),
-    (16, 32, 24, 48),
+    (16, 3, 32, 24, 48),
 ]
 
 TEST_CASE_3 = [  # 4-channel 3D, batch 16
@@ -52,7 +52,7 @@
         'num_res_units': 1,
     },
     torch.randn(16, 4, 32, 64, 48),
-    (16, 32, 64, 48),
+    (16, 3, 32, 64, 48),
 ]
 
 
@@ -63,7 +63,7 @@ def test_shape(self, input_param, input_data, expected_shape):
         net = UNet(**input_param)
         net.eval()
         with torch.no_grad():
-            result = net.forward(input_data)[1]
+            result = net.forward(input_data)
         self.assertEqual(result.shape, expected_shape)
 

From f0d4ebfe3fd0e7b412e9fb7317f8bd1354bfff9a Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 12 Mar 2020 02:39:45 +0000
Subject: [PATCH 3/7] [DLMED] add 3D classification inference examples

---
 ...py => densenet_classification_3d_array.py} |  36 +++--
 examples/densenet_classification_3d_dict.py   | 147 ++++++++++++++++++
 examples/densenet_inference_3d_array.py       |  93 +++++++++++
 examples/densenet_inference_3d_dict.py        |  95 +++++++++++
 examples/unet_inference_3d_array.py           |   6 +-
 examples/unet_inference_3d_dict.py            |   2 +-
 examples/unet_segmentation_3d_array.py        |  16 +-
 examples/unet_segmentation_3d_dict.py         |  12 +-
 monai/handlers/classification_saver.py        |  91 +++++++++++
 monai/transforms/composables.py               |  56 ++++++-
 monai/transforms/transforms.py                |  34 ++--
 tests/test_dice_loss.py                       |   2 +-
 tests/test_generalized_dice_loss.py           |   2 +-
 tests/test_resize.py                          |  21 +--
 14 files changed, 548 insertions(+), 65 deletions(-)
 rename examples/{densenet_classification_3d.py => densenet_classification_3d_array.py} (83%)
 create mode 100644 examples/densenet_classification_3d_dict.py
 create mode 100644 examples/densenet_inference_3d_array.py
 create mode 100644 examples/densenet_inference_3d_dict.py
 create mode 100644 monai/handlers/classification_saver.py

diff --git a/examples/densenet_classification_3d.py b/examples/densenet_classification_3d_array.py
similarity index 83%
rename from examples/densenet_classification_3d.py
rename to examples/densenet_classification_3d_array.py
index 753097a2a9..001eb079e2 100644
--- a/examples/densenet_classification_3d.py
+++ b/examples/densenet_classification_3d_array.py
@@ -23,14 +23,14 @@
 import monai.transforms.compose as transforms
 
 from monai.data.nifti_reader import NiftiDataset
-from monai.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch)
+from monai.transforms import (AddChannel, Rescale, Resize, RandRotate90)
 from monai.handlers.stats_handler import StatsHandler
 from ignite.metrics import Accuracy
 from monai.handlers.utils import stopping_fn_from_metric
 
 monai.config.print_config()
 
-# FIXME: temp test dataset, Wenqi will replace later
+# demo dataset, users can easily switch to their own dataset
 images = [
     "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
"/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz", @@ -57,18 +57,23 @@ 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 ]) -# Define transforms for image and segmentation -imtrans = transforms.Compose([ +# Define transforms +train_transforms = transforms.Compose([ Rescale(), AddChannel(), - UniformRandomPatch((96, 96, 96)), - ToTensor() + Resize((96, 96, 96)), + RandRotate90() +]) +val_transforms = transforms.Compose([ + Rescale(), + AddChannel(), + Resize((96, 96, 96)) ]) # Define nifti dataset, dataloader. -ds = NiftiDataset(image_files=images, labels=labels, transform=imtrans) -loader = DataLoader(ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) -im, label = monai.utils.misc.first(loader) +check_ds = NiftiDataset(image_files=images, labels=labels, transform=train_transforms) +check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) +im, label = monai.utils.misc.first(check_loader) print(type(im), im.shape, label) lr = 1e-5 @@ -84,7 +89,8 @@ # Create trainer device = torch.device("cuda:0") -trainer = create_supervised_trainer(net, opt, loss, device, False) +trainer = create_supervised_trainer(net, opt, loss, device, False, + output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y]) # adding checkpoint handler to save models (network params and optimizer stats) during training checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) @@ -94,10 +100,6 @@ train_stats_handler = StatsHandler(output_transform=lambda x: x[3]) train_stats_handler.attach(trainer) -@trainer.on(Events.EPOCH_COMPLETED) -def log_training_loss(engine): - engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output) - # Set parameters for validation validation_every_n_epochs = 1 metric_name = 'Accuracy' @@ -118,8 +120,8 @@ def log_training_loss(engine): evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) # create a validation data loader -val_ds = NiftiDataset(image_files=images[-5:], labels=labels[-5:], transform=imtrans) -val_loader = DataLoader(ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) +val_ds = NiftiDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms) +val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) @@ -129,7 +131,7 @@ def run_validation(engine): # create a training data loader logging.basicConfig(stream=sys.stdout, level=logging.INFO) -train_ds = NiftiDataset(image_files=images[:15], labels=labels[:15], transform=imtrans) +train_ds = NiftiDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms) train_loader = DataLoader(train_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) train_epochs = 30 diff --git a/examples/densenet_classification_3d_dict.py b/examples/densenet_classification_3d_dict.py new file mode 100644 index 0000000000..1be49de05d --- /dev/null +++ b/examples/densenet_classification_3d_dict.py @@ -0,0 +1,147 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import logging
+import numpy as np
+import torch
+from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, _prepare_batch
+from ignite.handlers import ModelCheckpoint, EarlyStopping
+from torch.utils.data import DataLoader
+
+# assumes the framework is found here, change as necessary
+sys.path.append("..")
+import monai
+import monai.transforms.compose as transforms
+from monai.transforms.composables import \
+    LoadNiftid, AddChanneld, Rescaled, Resized, RandRotate90d
+from monai.handlers.stats_handler import StatsHandler
+from ignite.metrics import Accuracy
+from monai.handlers.utils import stopping_fn_from_metric
+
+monai.config.print_config()
+
+# demo dataset, users can easily switch to their own dataset
+images = [
+    "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz"
+]
+labels = np.array([
+    0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0
+])
+train_files = [{'img': img, 'label': label} for img, label in zip(images[:10], labels[:10])]
+val_files = [{'img': img, 'label': label} for img, label in zip(images[-10:], labels[-10:])]
+
+# Define transforms for image
+train_transforms = transforms.Compose([
+    LoadNiftid(keys=['img']),
+    AddChanneld(keys=['img']),
+    Rescaled(keys=['img']),
+    Resized(keys=['img'], output_shape=(96, 96, 96)),
+    RandRotate90d(keys=['img'], prob=0.8, axes=[1, 3])
+])
+val_transforms = transforms.Compose([
+    LoadNiftid(keys=['img']),
+    AddChanneld(keys=['img']),
+    Rescaled(keys=['img']),
+    Resized(keys=['img'], output_shape=(96, 96, 96))
+])
+
+# Define dataset, dataloader.
+check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
+check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
+check_data = monai.utils.misc.first(check_loader)
+print(check_data['img'].shape, check_data['label'])
+
+lr = 1e-5
+
+# Create DenseNet121, CrossEntropyLoss and Adam optimizer.
+net = monai.networks.nets.densenet3d.densenet121(
+    in_channels=1,
+    out_channels=2,
+)
+
+loss = torch.nn.CrossEntropyLoss()
+opt = torch.optim.Adam(net.parameters(), lr)
+
+
+# prepare_batch extracts images and labels from the dict-based batch
+def prepare_batch(batch, device=None, non_blocking=False):
+    return _prepare_batch((batch['img'], batch['label']), device, non_blocking)
+
+
+# Create trainer
+device = torch.device("cuda:0")
+trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch,
+                                    output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])
+
+# adding checkpoint handler to save models (network params and optimizer stats) during training
+checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
+trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
+                          handler=checkpoint_handler,
+                          to_save={'net': net, 'opt': opt})
+train_stats_handler = StatsHandler()
+train_stats_handler.attach(trainer)
+
+# Set parameters for validation
+validation_every_n_epochs = 1
+metric_name = 'Accuracy'
+
+# add evaluation metric to the evaluator engine
+val_metrics = {metric_name: Accuracy()}
+evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)
+
+# Add stats event handler to print validation stats via evaluator
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+val_stats_handler = StatsHandler()
+val_stats_handler.attach(evaluator)
+
+# Add early stopping handler to evaluator.
+early_stopper = EarlyStopping(patience=4,
+                              score_function=stopping_fn_from_metric(metric_name),
+                              trainer=trainer)
+evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
+
+# create a validation data loader
+val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
+val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
+
+
+@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
+def run_validation(engine):
+    evaluator.run(val_loader)
+
+# create a training data loader
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+
+train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
+train_loader = DataLoader(train_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
+
+train_epochs = 30
+state = trainer.run(train_loader, train_epochs)
diff --git a/examples/densenet_inference_3d_array.py b/examples/densenet_inference_3d_array.py
new file mode 100644
index 0000000000..25908bb821
--- /dev/null
+++ b/examples/densenet_inference_3d_array.py
@@ -0,0 +1,93 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import logging
+import numpy as np
+import torch
+from ignite.engine import create_supervised_evaluator, _prepare_batch
+from torch.utils.data import DataLoader
+
+# assumes the framework is found here, change as necessary
+sys.path.append("..")
+import monai
+import monai.transforms.compose as transforms
+from monai.data.nifti_reader import NiftiDataset
+from monai.transforms import (AddChannel, Rescale, Resize)
+from monai.handlers.stats_handler import StatsHandler
+from monai.handlers.classification_saver import ClassificationSaver
+from monai.handlers.checkpoint_loader import CheckpointLoader
+from ignite.metrics import Accuracy
+
+monai.config.print_config()
+
+# demo dataset, users can easily switch to their own dataset
+images = [
+    "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz"
+]
+labels = np.array([
+    0, 0, 1, 0, 1, 0, 1, 0, 1, 0
+])
+
+# Define transforms for image
+val_transforms = transforms.Compose([
+    Rescale(),
+    AddChannel(),
+    Resize((96, 96, 96))
+])
+# Define nifti dataset
+val_ds = NiftiDataset(image_files=images, labels=labels, transform=val_transforms, image_only=False)
+# Create DenseNet121
+net = monai.networks.nets.densenet3d.densenet121(
+    in_channels=1,
+    out_channels=2,
+)
+
+device = torch.device("cuda:0")
+metric_name = 'Accuracy'
+
+# add evaluation metric to the evaluator engine
+val_metrics = {metric_name: Accuracy()}
+
+
+def prepare_batch(batch, device=None, non_blocking=False):
+    return _prepare_batch((batch[0], batch[1]), device, non_blocking)
+
+
+evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)
+
+# Add stats event handler to print validation stats via evaluator
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+val_stats_handler = StatsHandler()
+val_stats_handler.attach(evaluator)
+
+# for the array data format, assume the 3rd item of batch data is the meta_data
+prediction_saver = ClassificationSaver(output_dir='tempdir', batch_transform=lambda batch: batch[2],
+                                       output_transform=lambda output: output[0].argmax(1))
+prediction_saver.attach(evaluator)
+
+# the model was trained by the "densenet_classification_3d_array" example
+CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator)
+
+# create a validation data loader
+val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
+
+state = evaluator.run(val_loader)
+prediction_saver.finalize()
diff --git a/examples/densenet_inference_3d_dict.py b/examples/densenet_inference_3d_dict.py
new file mode 100644
index 0000000000..8f7dc75c72
--- /dev/null
+++ b/examples/densenet_inference_3d_dict.py
@@ -0,0 +1,95 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ignite.metrics import Accuracy
+import sys
+import logging
+import numpy as np
+import torch
+from ignite.engine import create_supervised_evaluator, _prepare_batch
+from torch.utils.data import DataLoader
+
+# assumes the framework is found here, change as necessary
+sys.path.append("..")
+from monai.handlers.classification_saver import ClassificationSaver
+from monai.handlers.checkpoint_loader import CheckpointLoader
+from monai.handlers.stats_handler import StatsHandler
+from monai.transforms.composables import LoadNiftid, AddChanneld, Rescaled, Resized
+import monai.transforms.compose as transforms
+import monai
+
+monai.config.print_config()
+
+# demo dataset, users can easily switch to their own dataset
+images = [
+    "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz"
+]
+labels = np.array([
+    0, 0, 1, 0, 1, 0, 1, 0, 1, 0
+])
+val_files = [{'img': img, 'label': label} for img, label in zip(images, labels)]
+
+# Define transforms for image
+val_transforms = transforms.Compose([
+    LoadNiftid(keys=['img']),
+    AddChanneld(keys=['img']),
+    Rescaled(keys=['img']),
+    Resized(keys=['img'], output_shape=(96, 96, 96))
+])
+
+# Create DenseNet121
+net = monai.networks.nets.densenet3d.densenet121(
+    in_channels=1,
+    out_channels=2,
+)
+
+
+def prepare_batch(batch, device=None, non_blocking=False):
+    return _prepare_batch((batch['img'], batch['label']), device, non_blocking)
+
+
+# Create evaluator
+device = torch.device("cuda:0")
+metric_name = 'Accuracy'
+
+# add evaluation metric to the evaluator engine
+val_metrics = {metric_name: Accuracy()}
+evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)
+
+# Add stats event handler to print validation stats via evaluator
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+val_stats_handler = StatsHandler()
+val_stats_handler.attach(evaluator)
+
+# for the dict data format, extract the meta_data of the image from the batch dict
+prediction_saver = ClassificationSaver(output_dir='tempdir', batch_transform=lambda batch: {
+                                           'filename_or_obj': batch['img.filename_or_obj']},
+                                       output_transform=lambda output: output[0].argmax(1))
+prediction_saver.attach(evaluator)
+
+# the model was trained by the "densenet_classification_3d_dict" example
+CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator)
+
+# create a validation data loader
+val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
+val_loader = DataLoader(val_ds, batch_size=2, num_workers=4,
pin_memory=torch.cuda.is_available()) + +state = evaluator.run(val_loader) +prediction_saver.finalize() diff --git a/examples/unet_inference_3d_array.py b/examples/unet_inference_3d_array.py index 8fe417c7dd..ea87590285 100644 --- a/examples/unet_inference_3d_array.py +++ b/examples/unet_inference_3d_array.py @@ -28,7 +28,7 @@ from monai.handlers.segmentation_saver import SegmentationSaver import monai.transforms.compose as transforms from monai.data.nifti_reader import NiftiDataset -from monai.transforms import AddChannel, Rescale, ToTensor +from monai.transforms import AddChannel, Rescale from monai.networks.nets.unet import UNet from monai.networks.utils import predict_segmentation from monai.data.synthetic import create_test_image_3d @@ -50,8 +50,8 @@ images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) -imtrans = transforms.Compose([Rescale(), AddChannel(), ToTensor()]) -segtrans = transforms.Compose([AddChannel(), ToTensor()]) +imtrans = transforms.Compose([Rescale(), AddChannel()]) +segtrans = transforms.Compose([AddChannel()]) ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False) device = torch.device("cuda:0") diff --git a/examples/unet_inference_3d_dict.py b/examples/unet_inference_3d_dict.py index 405b49aa8d..7274453b22 100644 --- a/examples/unet_inference_3d_dict.py +++ b/examples/unet_inference_3d_dict.py @@ -87,7 +87,7 @@ def _sliding_window_processor(engine, batch): 'original_affine': batch['img.original_affine'], 'affine': batch['img.affine'], }).attach(infer_engine) -# the model was trained by "unet_segmentation_3d_array" exmple +# the model was trained by "unet_segmentation_3d_dict" exmple CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(infer_engine) val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, diff --git a/examples/unet_segmentation_3d_array.py b/examples/unet_segmentation_3d_array.py index 3b0d880f10..d43ead1d2b 100644 --- a/examples/unet_segmentation_3d_array.py +++ b/examples/unet_segmentation_3d_array.py @@ -29,7 +29,7 @@ import monai.transforms.compose as transforms from monai.data.nifti_reader import NiftiDataset -from monai.transforms import AddChannel, Rescale, ToTensor, UniformRandomPatch +from monai.transforms import AddChannel, Rescale, UniformRandomPatch from monai.handlers.stats_handler import StatsHandler from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler, TensorBoardImageHandler from monai.handlers.mean_dice import MeanDice @@ -57,19 +57,17 @@ imtrans = transforms.Compose([ Rescale(), AddChannel(), - UniformRandomPatch((96, 96, 96)), - ToTensor() + UniformRandomPatch((96, 96, 96)) ]) segtrans = transforms.Compose([ AddChannel(), - UniformRandomPatch((96, 96, 96)), - ToTensor() + UniformRandomPatch((96, 96, 96)) ]) # Define nifti dataset, dataloader. 
-ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans)
-loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
-im, seg = monai.utils.misc.first(loader)
+check_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans)
+check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
+im, seg = monai.utils.misc.first(check_loader)
 print(im.shape, seg.shape)
 
 lr = 1e-5
@@ -159,7 +157,7 @@ def log_training_loss(engine):
 
 # create a validation data loader
 val_ds = NiftiDataset(images[-20:], segs[-20:], transform=imtrans, seg_transform=segtrans)
-val_loader = DataLoader(ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
+val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
 
 
 @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py
index 0e1e78811b..b8622bb7eb 100644
--- a/examples/unet_segmentation_3d_dict.py
+++ b/examples/unet_segmentation_3d_dict.py
@@ -68,11 +68,11 @@
     AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1)
 ])
 
-# Define nifti dataset, dataloader.
-ds = monai.data.Dataset(data=train_files, transform=train_transforms)
-loader = DataLoader(ds, batch_size=2, num_workers=4, collate_fn=list_data_collate,
-                    pin_memory=torch.cuda.is_available())
-check_data = monai.utils.misc.first(loader)
+# Define dataset, dataloader.
+check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
+check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate,
+                          pin_memory=torch.cuda.is_available())
+check_data = monai.utils.misc.first(check_loader)
 print(check_data['img'].shape, check_data['seg'].shape)
 
 lr = 1e-5
@@ -171,7 +171,7 @@ def log_training_loss(engine):
 
 # create a validation data loader
 val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
-val_loader = DataLoader(ds, batch_size=5, num_workers=8, collate_fn=list_data_collate,
+val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, collate_fn=list_data_collate,
                         pin_memory=torch.cuda.is_available())
 
 
diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py
new file mode 100644
index 0000000000..48f0ac59bb
--- /dev/null
+++ b/monai/handlers/classification_saver.py
@@ -0,0 +1,91 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import csv
+import numpy as np
+import torch
+from ignite.engine import Events
+
+
+class ClassificationSaver:
+    """
+    Event handler triggered on completing every iteration to save the classification predictions as a CSV file.
+    """
+
+    def __init__(self, output_dir='./', overwrite=True,
+                 batch_transform=lambda x: x, output_transform=lambda x: x, name=None):
+        """
+        Args:
+            output_dir (str): output CSV file directory.
+            overwrite (bool): whether to overwrite the existing CSV file content. If not overwriting,
+                then check whether the results were previously saved, and load them into the prediction_dict.
+            batch_transform (Callable): a callable that is used to transform the
+                ignite.engine.batch into expected format to extract the meta_data dictionary.
+            output_transform (Callable): a callable that is used to transform the
+                ignite.engine.output into the expected model prediction data.
+                The first dimension of this transform's output will be treated as the
+                batch dimension. Each item in the batch will be saved individually.
+            name (str): identifier of logging.logger to use, defaulting to `engine.logger`.
+
+        """
+        self.output_dir = output_dir
+        self._prediction_dict = {}
+        self._preds_filepath = os.path.join(output_dir, 'predictions.csv')
+        if not overwrite:
+            if os.path.exists(self._preds_filepath):
+                with open(self._path, 'r') as f:
+                    reader = csv.reader(f)
+                    for row in reader:
+                        self._prediction_dict[row[0]] = np.array(row[1:]).astype(np.float32)
+
+        self.batch_transform = batch_transform
+        self.output_transform = output_transform
+
+        self.logger = None if name is None else logging.getLogger(name)
+
+    def attach(self, engine):
+        if self.logger is None:
+            self.logger = engine.logger
+        return engine.add_event_handler(Events.ITERATION_COMPLETED, self)
+
+    def finalize(self):
+        """
+        Writes the prediction dict to a csv
+
+        """
+        if not os.path.exists(self.output_dir):
+            os.makedirs(self.output_dir)
+        with open(self._preds_filepath, 'w') as f:
+            for k, v in sorted(self._prediction_dict.items()):
+                f.write(k)
+                for result in v.flatten():
+                    f.write("," + str(result))
+                f.write("\n")
+        self.logger.info('saved classification predictions into: {}'.format(self._preds_filepath))
+
+    def __call__(self, engine):
+        """
+        This method assumes self.batch_transform will extract Metadata from the input batch.
+        Metadata should have the following keys:
+
+        - ``'filename_or_obj'`` -- save the prediction corresponding to file name.
+
+        """
+        meta_data = self.batch_transform(engine.state.batch)
+        filenames = meta_data['filename_or_obj']
+
+        engine_output = self.output_transform(engine.state.output)
+        for batch_id, filename in enumerate(filenames):  # save a batch of files
+            output = engine_output[batch_id]
+            if isinstance(output, torch.Tensor):
+                output = output.detach().cpu().numpy()
+            self._prediction_dict[filename] = output
diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py
index c19c3f7df9..b27b64780e 100644
--- a/monai/transforms/composables.py
+++ b/monai/transforms/composables.py
@@ -14,6 +14,7 @@
 """
 
 import torch
+import numpy as np
 from collections.abc import Hashable
 
 import monai
@@ -23,7 +24,7 @@
 from monai.transforms.transforms import (LoadNifti, AsChannelFirst, Orientation,
                                          AddChannel, Spacing, Rotate90, SpatialCrop,
                                          RandAffine, Rand2DElastic, Rand3DElastic,
-                                         Flip, Rotate, Zoom)
+                                         Rescale, Resize, Flip, Rotate, Zoom)
 from monai.utils.misc import ensure_tuple
 from monai.transforms.utils import generate_pos_neg_label_crop_centers, create_grid
 from monai.utils.aliases import alias
@@ -249,6 +250,59 @@ def __call__(self, data):
         return d
 
 
+@export
+@alias('RescaleD', 'RescaleDict')
+class Rescaled(MapTransform):
+    """
+    dictionary-based wrapper of Rescale.
+    """
+
+    def __init__(self, keys, minv=0.0, maxv=1.0, dtype=np.float32):
+        MapTransform.__init__(self, keys)
+        self.rescaler = Rescale(minv, maxv, dtype)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.keys:
+            d[key] = self.rescaler(d[key])
+        return d
+
+
+@export
+@alias('ResizeD', 'ResizeDict')
+class Resized(MapTransform):
+    """
+    dictionary-based wrapper of Resize.
+    Args:
+        keys (hashable items): keys of the corresponding items to be transformed.
+            See also: monai.transforms.composables.MapTransform
+        output_shape (tuple or list): expected shape after resize operation.
+        order (int): Order of spline interpolation. Default=1.
+        mode (str): Points outside boundaries are filled according to given mode.
+            Options are 'constant', 'edge', 'symmetric', 'reflect', 'wrap'.
+        cval (float): Used with mode 'constant', the value outside image boundaries.
+        clip (bool): Whether to clip the range of output values after interpolation. Default: True.
+        preserve_range (bool): Whether to keep original range of values. Default is True.
+            If False, input is converted according to conventions of img_as_float. See
+            https://scikit-image.org/docs/dev/user_guide/data_types.html.
+        anti_aliasing (bool): Whether to apply a gaussian filter to image before down-scaling.
+            Default is True.
+        anti_aliasing_sigma (float, tuple of floats): Standard deviation for gaussian filtering.
+    """
+
+    def __init__(self, keys, output_shape, order=1, mode='reflect', cval=0,
+                 clip=True, preserve_range=True, anti_aliasing=True, anti_aliasing_sigma=None):
+        MapTransform.__init__(self, keys)
+        self.resizer = Resize(output_shape, order, mode, cval, clip, preserve_range,
+                              anti_aliasing, anti_aliasing_sigma)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.keys:
+            d[key] = self.resizer(d[key])
+        return d
+
+
 @export
 @alias('UniformRandomPatchD', 'UniformRandomPatchDict')
 class UniformRandomPatchd(Randomizable, MapTransform):
diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py
index 8f140972f6..be6d38a4fa 100644
--- a/monai/transforms/transforms.py
+++ b/monai/transforms/transforms.py
@@ -289,6 +289,7 @@ class Resize:
     For additional details, see https://scikit-image.org/docs/dev/api/skimage.transform.html#skimage.transform.resize.
 
     Args:
+        output_shape (tuple or list): expected shape after resize operation.
         order (int): Order of spline interpolation. Default=1.
         mode (str): Points outside boundaries are filled according to given mode.
             Options are 'constant', 'edge', 'symmetric', 'reflect', 'wrap'.
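For reference, a minimal usage sketch of the transforms touched by this patch (not part of the patch itself; it assumes the patched monai.transforms modules are importable as shown): Resize now expects a channel-first array and resizes each channel independently, and Resized applies the same logic to selected keys of a data dictionary.

    import numpy as np
    from monai.transforms.transforms import Resize
    from monai.transforms.composables import Resized

    img = np.random.rand(1, 64, 64, 64).astype(np.float32)  # channel-first: (C, D, H, W)
    print(Resize(output_shape=(96, 96, 96))(img).shape)      # (1, 96, 96, 96), float32

    data = {'img': img, 'label': 1}
    data = Resized(keys=['img'], output_shape=(96, 96, 96))(data)
    print(data['img'].shape)                                 # only 'img' is resized
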
@@ -315,11 +316,16 @@ def __init__(self, output_shape, order=1, mode='reflect', cval=0, self.anti_aliasing_sigma = anti_aliasing_sigma def __call__(self, img): - return resize(img, self.output_shape, order=self.order, - mode=self.mode, cval=self.cval, - clip=self.clip, preserve_range=self.preserve_range, - anti_aliasing=self.anti_aliasing, - anti_aliasing_sigma=self.anti_aliasing_sigma) + resized = list() + for channel in img: + resized.append( + resize(channel, self.output_shape, order=self.order, + mode=self.mode, cval=self.cval, + clip=self.clip, preserve_range=self.preserve_range, + anti_aliasing=self.anti_aliasing, + anti_aliasing_sigma=self.anti_aliasing_sigma) + ) + return np.stack(resized).astype(np.float32) @export @@ -353,7 +359,7 @@ def __init__(self, angle, axes=(1, 2), reshape=True, order=1, mode='constant', c def __call__(self, img): return scipy.ndimage.rotate(img, self.angle, self.axes, reshape=self.reshape, order=self.order, mode=self.mode, cval=self.cval, - prefilter=self.prefilter) + prefilter=self.prefilter).astype(np.float32) @export @@ -420,7 +426,7 @@ def __call__(self, img): mode=self.mode, cval=self.cval, prefilter=self.prefilter)) - zoomed = np.stack(zoomed) + zoomed = np.stack(zoomed).astype(np.float32) if not self.keep_size or np.allclose(img.shape, zoomed.shape): return zoomed @@ -478,16 +484,14 @@ class IntensityNormalizer: Args: subtrahend (ndarray): the amount to subtract by (usually the mean) divisor (ndarray): the amount to divide by (usually the standard deviation) - dtype: output data format """ - def __init__(self, subtrahend=None, divisor=None, dtype=np.float32): + def __init__(self, subtrahend=None, divisor=None): if subtrahend is not None or divisor is not None: assert isinstance(subtrahend, np.ndarray) and isinstance(divisor, np.ndarray), \ 'subtrahend and divisor must be set in pair and in numpy array.' self.subtrahend = subtrahend self.divisor = divisor - self.dtype = dtype def __call__(self, img): if self.subtrahend is not None and self.divisor is not None: @@ -497,8 +501,6 @@ def __call__(self, img): img -= np.mean(img) img /= np.std(img) - if self.dtype != img.dtype: - img = img.astype(self.dtype) return img @@ -511,15 +513,13 @@ class ImageEndPadder: Args: out_size (list): the size of region of interest at the end of the operation. mode (string): a portion from numpy.lib.arraypad.pad is copied below. - dtype: output data format. """ - def __init__(self, out_size, mode, dtype=np.float32): - assert out_size is not None and isinstance(out_size, (list, tuple)), 'out_size must be list or tuple' + def __init__(self, out_size, mode): + assert out_size is not None and isinstance(out_size, (list, tuple)), 'out_size must be list or tuple.' self.out_size = out_size - assert isinstance(mode, str), 'mode must be str' + assert isinstance(mode, str), 'mode must be str.' 
self.mode = mode - self.dtype = dtype def _determine_data_pad_width(self, data_shape): return [(0, max(self.out_size[i] - data_shape[i], 0)) for i in range(len(self.out_size))] diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index c5640a5660..e937185f91 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -82,7 +82,7 @@ TEST_CASE_6 = [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) { - 'include_background': False, + 'include_background': True, 'do_sigmoid': True, }, { diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index e08ff1d296..b2ce96169e 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -94,7 +94,7 @@ TEST_CASE_6 = [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) { - 'include_background': False, + 'include_background': True, 'do_sigmoid': True, }, { diff --git a/tests/test_resize.py b/tests/test_resize.py index 7feaf9f634..7a3d327b28 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -30,9 +30,9 @@ def test_invalid_inputs(self, _, order, raises): resize(self.imt) @parameterized.expand([ - ((1, 1, 64, 64), 1, 'reflect', 0, True, True, True, None), - ((1, 1, 32, 32), 2, 'constant', 3, False, False, False, None), - ((1, 1, 256, 256), 3, 'constant', 3, False, False, False, None), + ((64, 64), 1, 'reflect', 0, True, True, True, None), + ((32, 32), 2, 'constant', 3, False, False, False, None), + ((256, 256), 3, 'constant', 3, False, False, False, None), ]) def test_correct_results(self, output_shape, order, mode, cval, clip, preserve_range, @@ -40,12 +40,15 @@ def test_correct_results(self, output_shape, order, mode, resize = Resize(output_shape, order, mode, cval, clip, preserve_range, anti_aliasing, anti_aliasing_sigma) - expected = skimage.transform.resize(self.imt, output_shape, - order=order, mode=mode, - cval=cval, clip=clip, - preserve_range=preserve_range, - anti_aliasing=anti_aliasing, - anti_aliasing_sigma=anti_aliasing_sigma) + expected = list() + for channel in self.imt: + expected.append(skimage.transform.resize(channel, output_shape, + order=order, mode=mode, + cval=cval, clip=clip, + preserve_range=preserve_range, + anti_aliasing=anti_aliasing, + anti_aliasing_sigma=anti_aliasing_sigma)) + expected = np.stack(expected).astype(np.float32) self.assertTrue(np.allclose(resize(self.imt), expected)) From c4f31aa3438e2c8dce5a11b5d0cc09622c2965b4 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 12 Mar 2020 18:59:10 +0800 Subject: [PATCH 4/7] [DLMED] update and clear all the examples also added unit tests for ClassificationSaver and Resized --- .../densenet_evaluation_array.py} | 16 ++- .../densenet_evaluation_dict.py} | 21 ++-- .../densenet_training_array.py} | 51 +++++--- .../densenet_training_dict.py} | 52 +++++--- .../unet_evaluation_array.py} | 44 +++++-- .../unet_evaluation_dict.py} | 51 +++++--- .../unet_training_array.py} | 93 +++++++------- .../unet_training_dict.py} | 119 ++++++++---------- monai/handlers/classification_saver.py | 18 +-- monai/handlers/segmentation_saver.py | 2 +- monai/handlers/tensorboard_handlers.py | 45 ++++--- tests/test_handler_classification_saver.py | 56 +++++++++ tests/test_resize.py | 2 +- tests/test_resized.py | 56 +++++++++ 14 files changed, 401 insertions(+), 225 deletions(-) rename examples/{densenet_inference_3d_array.py => classification_3d/densenet_evaluation_array.py} (89%) rename examples/{densenet_inference_3d_dict.py => classification_3d/densenet_evaluation_dict.py} (85%) rename examples/{densenet_classification_3d_array.py => 
classification_3d/densenet_training_array.py} (75%)
 rename examples/{densenet_classification_3d_dict.py => classification_3d/densenet_training_dict.py} (76%)
 rename examples/{unet_inference_3d_array.py => segmentation_3d/unet_evaluation_array.py} (66%)
 rename examples/{unet_inference_3d_dict.py => segmentation_3d/unet_evaluation_dict.py} (67%)
 rename examples/{unet_segmentation_3d_array.py => segmentation_3d/unet_training_array.py} (66%)
 rename examples/{unet_segmentation_3d_dict.py => segmentation_3d/unet_training_dict.py} (58%)
 create mode 100644 tests/test_handler_classification_saver.py
 create mode 100644 tests/test_resized.py

diff --git a/examples/densenet_inference_3d_array.py b/examples/classification_3d/densenet_evaluation_array.py
similarity index 89%
rename from examples/densenet_inference_3d_array.py
rename to examples/classification_3d/densenet_evaluation_array.py
index 25908bb821..7b1015ea4b 100644
--- a/examples/densenet_inference_3d_array.py
+++ b/examples/classification_3d/densenet_evaluation_array.py
@@ -17,7 +17,7 @@
 from torch.utils.data import DataLoader
 
 # assumes the framework is found here, change as necessary
-sys.path.append("..")
+sys.path.append("../..")
 import monai
 import monai.transforms.compose as transforms
 from monai.data.nifti_reader import NiftiDataset
@@ -28,6 +28,7 @@
 from ignite.metrics import Accuracy
 
 monai.config.print_config()
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
 # demo dataset, user can easily change to own dataset
 images = [
@@ -59,10 +60,9 @@
     in_channels=1,
     out_channels=2,
 )
-
 device = torch.device("cuda:0")
 
-metric_name = 'Accuracy'
+metric_name = 'Accuracy'
 # add evaluation metric to the evaluator engine
 val_metrics = {metric_name: Accuracy()}
@@ -71,11 +71,15 @@ def prepare_batch(batch, device=None, non_blocking=False):
     return _prepare_batch((batch[0], batch[1]), device, non_blocking)
 
 
+# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
+# user can add output_transform to return other values
 evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)
 
 # Add stats event handler to print validation stats via evaluator
-logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-val_stats_handler = StatsHandler()
+val_stats_handler = StatsHandler(
+    name='evaluator',
+    output_transform=lambda x: None  # no need to print loss value, so disable per iteration output
+)
 val_stats_handler.attach(evaluator)
 
 # for the array data format, assume the 3rd item of batch data is the meta_data
@@ -83,7 +87,7 @@ def prepare_batch(batch, device=None, non_blocking=False):
                                        output_transform=lambda output: output[0].argmax(1))
 prediction_saver.attach(evaluator)
 
-# the model was trained by "densenet_classification_3d_array" example
+# the model was trained by "densenet_training_array" example
 CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator)
 
 # create a validation data loader
diff --git a/examples/densenet_inference_3d_dict.py b/examples/classification_3d/densenet_evaluation_dict.py
similarity index 85%
rename from examples/densenet_inference_3d_dict.py
rename to examples/classification_3d/densenet_evaluation_dict.py
index 8f7dc75c72..7f4852afa8 100644
--- a/examples/densenet_inference_3d_dict.py
+++ b/examples/classification_3d/densenet_evaluation_dict.py
@@ -18,7 +18,7 @@
 from torch.utils.data import DataLoader
 
 # assumes the framework is found here, change as necessary
-sys.path.append("..")
+sys.path.append("../..")
 from monai.handlers.classification_saver import ClassificationSaver
 from monai.handlers.checkpoint_loader import CheckpointLoader
 from monai.handlers.stats_handler import StatsHandler
@@ -27,6 +27,7 @@
 import monai
 
 monai.config.print_config()
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
 # demo dataset, user can easily change to own dataset
 images = [
@@ -59,32 +60,34 @@
     in_channels=1,
     out_channels=2,
 )
+device = torch.device("cuda:0")
 
 
 def prepare_batch(batch, device=None, non_blocking=False):
     return _prepare_batch((batch['img'], batch['label']), device, non_blocking)
 
 
-# Create evaluator
-device = torch.device("cuda:0")
 metric_name = 'Accuracy'
-
 # add evaluation metric to the evaluator engine
 val_metrics = {metric_name: Accuracy()}
+# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
+# user can add output_transform to return other values
 evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)
 
 # Add stats event handler to print validation stats via evaluator
-logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-val_stats_handler = StatsHandler()
+val_stats_handler = StatsHandler(
+    name='evaluator',
+    output_transform=lambda x: None  # no need to print loss value, so disable per iteration output
+)
 val_stats_handler.attach(evaluator)
 
 # for the dict data format, extract the meta_data from the batch dict
-prediction_saver = ClassificationSaver(output_dir='tempdir', batch_transform=lambda batch: {
-    'filename_or_obj': batch['img.filename_or_obj']},
+prediction_saver = ClassificationSaver(output_dir='tempdir', name='evaluator',
+                                       batch_transform=lambda batch: {'filename_or_obj': batch['img.filename_or_obj']},
                                        output_transform=lambda output: output[0].argmax(1))
 prediction_saver.attach(evaluator)
 
-# the model was trained by "densenet_classification_3d_dict" example
+# the model was trained by "densenet_training_dict" example
 CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator)
 
 # create a validation data loader
diff --git a/examples/densenet_classification_3d_array.py b/examples/classification_3d/densenet_training_array.py
similarity index 75%
rename from examples/densenet_classification_3d_array.py
rename to examples/classification_3d/densenet_training_array.py
index 8fa61ca284..be8504adec 100644
--- a/examples/densenet_classification_3d_array.py
+++ b/examples/classification_3d/densenet_training_array.py
@@ -18,16 +18,18 @@
 from torch.utils.data import DataLoader
 
 # assumes the framework is found here, change as necessary
-sys.path.append("..")
+sys.path.append("../..")
 import monai
 import monai.transforms.compose as transforms
 from monai.data.nifti_reader import NiftiDataset
 from monai.transforms import (AddChannel, Rescale, Resize, RandRotate90)
 from monai.handlers.stats_handler import StatsHandler
+from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler
 from ignite.metrics import Accuracy
 from monai.handlers.utils import stopping_fn_from_metric
 
 monai.config.print_config()
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
 # demo dataset, user can easily change to own dataset
 images = [
@@ -69,27 +71,25 @@
     Resize((96, 96, 96))
 ])
 
-# Define nifti dataset, dataloader.
+# Define nifti dataset, dataloader check_ds = NiftiDataset(image_files=images, labels=labels, transform=train_transforms) check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) im, label = monai.utils.misc.first(check_loader) print(type(im), im.shape, label) -lr = 1e-5 - -# Create DenseNet121, CrossEntropyLoss and Adam optimizer. +# Create DenseNet121, CrossEntropyLoss and Adam optimizer net = monai.networks.nets.densenet3d.densenet121( in_channels=1, out_channels=2, ) - loss = torch.nn.CrossEntropyLoss() +lr = 1e-5 opt = torch.optim.Adam(net.parameters(), lr) - -# Create trainer device = torch.device("cuda:0") -trainer = create_supervised_trainer(net, opt, loss, device, False, - output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y]) + +# ignite trainer expects batch=(img, label) and returns output=loss at every iteration, +# user can add output_transform to return other values, like: y_pred, y, etc. +trainer = create_supervised_trainer(net, opt, loss, device, False) # adding checkpoint handler to save models (network params and optimizer stats) during training checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) @@ -97,24 +97,40 @@ handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -train_stats_handler = StatsHandler(output_transform=lambda x: x[3]) +# StatsHandler prints loss at every iteration and print metrics at every epoch, +# we don't set metrics for trainer here, so just print loss, user can also customize print functions +# and can use output_transform to convert engine.state.output if it's not loss value +train_stats_handler = StatsHandler(name='trainer') train_stats_handler.attach(trainer) +# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler +train_tensorboard_stats_handler = TensorBoardStatsHandler() +train_tensorboard_stats_handler.attach(trainer) + # Set parameters for validation validation_every_n_epochs = 1 -metric_name = 'Accuracy' +metric_name = 'Accuracy' # add evaluation metric to the evaluator engine val_metrics = {metric_name: Accuracy()} +# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, +# user can add output_transform to return other values evaluator = create_supervised_evaluator(net, val_metrics, device, True) # Add stats event handler to print validation stats via evaluator -logging.basicConfig(stream=sys.stdout, level=logging.INFO) - -val_stats_handler = StatsHandler(output_transform=lambda x: None) +val_stats_handler = StatsHandler( + name='evaluator', + output_transform=lambda x: None, # no need to print loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer val_stats_handler.attach(evaluator) -# Add early stopping handler to evaluator. 
+# add handler to record metrics to TensorBoard at every epoch +val_tensorboard_stats_handler = TensorBoardStatsHandler( + output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer +val_tensorboard_stats_handler.attach(evaluator) + +# Add early stopping handler to evaluator early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer) @@ -129,9 +145,8 @@ def run_validation(engine): evaluator.run(val_loader) -# create a training data loader -logging.basicConfig(stream=sys.stdout, level=logging.INFO) +# create a training data loader train_ds = NiftiDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms) train_loader = DataLoader(train_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) diff --git a/examples/densenet_classification_3d_dict.py b/examples/classification_3d/densenet_training_dict.py similarity index 76% rename from examples/densenet_classification_3d_dict.py rename to examples/classification_3d/densenet_training_dict.py index 1be49de05d..8f5d994f08 100644 --- a/examples/densenet_classification_3d_dict.py +++ b/examples/classification_3d/densenet_training_dict.py @@ -18,16 +18,18 @@ from torch.utils.data import DataLoader # assumes the framework is found here, change as necessary -sys.path.append("..") +sys.path.append("../..") import monai import monai.transforms.compose as transforms from monai.transforms.composables import \ LoadNiftid, AddChanneld, Rescaled, Resized, RandRotate90d from monai.handlers.stats_handler import StatsHandler +from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler from ignite.metrics import Accuracy from monai.handlers.utils import stopping_fn_from_metric monai.config.print_config() +logging.basicConfig(stream=sys.stdout, level=logging.INFO) # demo dataset, user can easily change to own dataset images = [ @@ -73,56 +75,71 @@ Resized(keys=['img'], output_shape=(96, 96, 96)) ]) -# Define dataset, dataloader. +# Define dataset, dataloader check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) check_data = monai.utils.misc.first(check_loader) print(check_data['img'].shape, check_data['label']) -lr = 1e-5 - -# Create DenseNet121, CrossEntropyLoss and Adam optimizer. +# Create DenseNet121, CrossEntropyLoss and Adam optimizer net = monai.networks.nets.densenet3d.densenet121( in_channels=1, out_channels=2, ) - loss = torch.nn.CrossEntropyLoss() +lr = 1e-5 opt = torch.optim.Adam(net.parameters(), lr) +device = torch.device("cuda:0") -# Create trainer +# ignite trainer expects batch=(img, label) and returns output=loss at every iteration, +# user can add output_transform to return other values, like: y_pred, y, etc. 
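+# (a rough sketch of ignite's default behaviour, assuming standard ignite
+# semantics -- inside create_supervised_trainer each iteration effectively runs:
+#     net.train(); opt.zero_grad()
+#     x, y = prepare_batch(batch, device, non_blocking)
+#     y_pred = net(x); l = loss(y_pred, y)
+#     l.backward(); opt.step()
+#     return output_transform(x, y, y_pred, l)  # default: l.item()
+# so the `prepare_batch` below only has to unpack the dict batch into (x, y))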
def prepare_batch(batch, device=None, non_blocking=False): return _prepare_batch((batch['img'], batch['label']), device, non_blocking) -# Create trainer -device = torch.device("cuda:0") -trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch, - output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y]) +trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch) # adding checkpoint handler to save models (network params and optimizer stats) during training checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -train_stats_handler = StatsHandler() + +# StatsHandler prints loss at every iteration and print metrics at every epoch, +# we don't set metrics for trainer here, so just print loss, user can also customize print functions +# and can use output_transform to convert engine.state.output if it's not loss value +train_stats_handler = StatsHandler(name='trainer') train_stats_handler.attach(trainer) +# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler +train_tensorboard_stats_handler = TensorBoardStatsHandler() +train_tensorboard_stats_handler.attach(trainer) + # Set parameters for validation validation_every_n_epochs = 1 -metric_name = 'Accuracy' +metric_name = 'Accuracy' # add evaluation metric to the evaluator engine val_metrics = {metric_name: Accuracy()} +# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, +# user can add output_transform to return other values evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) # Add stats event handler to print validation stats via evaluator -logging.basicConfig(stream=sys.stdout, level=logging.INFO) -val_stats_handler = StatsHandler() +val_stats_handler = StatsHandler( + name='evaluator', + output_transform=lambda x: None, # no need to print loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer val_stats_handler.attach(evaluator) -# Add early stopping handler to evaluator. 
+# add handler to record metrics to TensorBoard at every epoch +val_tensorboard_stats_handler = TensorBoardStatsHandler( + output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer +val_tensorboard_stats_handler.attach(evaluator) + +# Add early stopping handler to evaluator early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer) @@ -137,9 +154,8 @@ def prepare_batch(batch, device=None, non_blocking=False): def run_validation(engine): evaluator.run(val_loader) -# create a training data loader -logging.basicConfig(stream=sys.stdout, level=logging.INFO) +# create a training data loader train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) train_loader = DataLoader(train_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) diff --git a/examples/unet_inference_3d_array.py b/examples/segmentation_3d/unet_evaluation_array.py similarity index 66% rename from examples/unet_inference_3d_array.py rename to examples/segmentation_3d/unet_evaluation_array.py index ea87590285..988af02bc9 100644 --- a/examples/unet_inference_3d_array.py +++ b/examples/segmentation_3d/unet_evaluation_array.py @@ -13,7 +13,7 @@ import sys import tempfile from glob import glob - +import logging import nibabel as nib import numpy as np import torch @@ -21,7 +21,7 @@ from torch.utils.data import DataLoader # assumes the framework is found here, change as necessary -sys.path.append("..") +sys.path.append("../..") from monai import config from monai.handlers.checkpoint_loader import CheckpointLoader @@ -33,8 +33,11 @@ from monai.networks.utils import predict_segmentation from monai.data.synthetic import create_test_image_3d from monai.utils.sliding_window_inference import sliding_window_inference +from monai.handlers.stats_handler import StatsHandler +from monai.handlers.mean_dice import MeanDice config.print_config() +logging.basicConfig(stream=sys.stdout, level=logging.INFO) tempdir = tempfile.mkdtemp() # tempdir = './temp' @@ -50,13 +53,13 @@ images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) + +# Define transforms for image and segmentation imtrans = transforms.Compose([Rescale(), AddChannel()]) segtrans = transforms.Compose([AddChannel()]) ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False) device = torch.device("cuda:0") -roi_size = (64, 64, 64) -sw_batch_size = 4 net = UNet( dimensions=3, in_channels=1, @@ -67,22 +70,39 @@ ) net.to(device) +# define sliding window size and batch size for windows inference +roi_size = (64, 64, 64) +sw_batch_size = 4 + def _sliding_window_processor(engine, batch): net.eval() img, seg, meta_data = batch with torch.no_grad(): - seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x)[0], device) - return predict_segmentation(seg_probs) + seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x), device) + return seg_probs, seg.to(device) + +evaluator = Engine(_sliding_window_processor) -infer_engine = Engine(_sliding_window_processor) +# add evaluation metric to the evaluator engine +MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice') + +# StatsHandler prints loss at every iteration and print metrics at every epoch, +# we don't need to print loss for evaluator, so just print metrics, 
user can also customize print functions +val_stats_handler = StatsHandler( + name='evaluator', + output_transform=lambda x: None # no need to print loss value, so disable per iteration output +) +val_stats_handler.attach(evaluator) # for the arrary data format, assume the 3rd item of batch data is the meta_data -SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg', - batch_transform=lambda x: x[2]).attach(infer_engine) -# the model was trained by "unet_segmentation_3d_array" exmple -CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(infer_engine) +SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator', + batch_transform=lambda x: x[2], output_transform=lambda output: predict_segmentation(output[0]) + ).attach(evaluator) +# the model was trained by "unet_training_array" exmple +CheckpointLoader(load_path='./runs/net_checkpoint_320.pth', load_dict={'net': net}).attach(evaluator) +# sliding window inferene need to input 1 image in every iteration loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available()) -state = infer_engine.run(loader) +state = evaluator.run(loader) diff --git a/examples/unet_inference_3d_dict.py b/examples/segmentation_3d/unet_evaluation_dict.py similarity index 67% rename from examples/unet_inference_3d_dict.py rename to examples/segmentation_3d/unet_evaluation_dict.py index 7274453b22..23e79e4484 100644 --- a/examples/unet_inference_3d_dict.py +++ b/examples/segmentation_3d/unet_evaluation_dict.py @@ -13,7 +13,7 @@ import sys import tempfile from glob import glob - +import logging import nibabel as nib import numpy as np import torch @@ -21,7 +21,7 @@ from torch.utils.data import DataLoader # assumes the framework is found here, change as necessary -sys.path.append("..") +sys.path.append("../..") import monai from monai.data.utils import list_data_collate @@ -29,13 +29,16 @@ from monai.data.synthetic import create_test_image_3d from monai.networks.utils import predict_segmentation from monai.networks.nets.unet import UNet -from monai.transforms.composables import LoadNiftid, AsChannelFirstd +from monai.transforms.composables import LoadNiftid, AsChannelFirstd, Rescaled import monai.transforms.compose as transforms from monai.handlers.segmentation_saver import SegmentationSaver from monai.handlers.checkpoint_loader import CheckpointLoader +from monai.handlers.stats_handler import StatsHandler +from monai.handlers.mean_dice import MeanDice from monai import config config.print_config() +logging.basicConfig(stream=sys.stdout, level=logging.INFO) tempdir = tempfile.mkdtemp() # tempdir = './temp' @@ -52,15 +55,16 @@ images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) val_files = [{'img': img, 'seg': seg} for img, seg in zip(images, segs)] + +# Define transforms for image and segmentation val_transforms = transforms.Compose([ LoadNiftid(keys=['img', 'seg']), - AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1) + AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1), + Rescaled(keys=['img', 'seg']) ]) val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) device = torch.device("cuda:0") -roi_size = (64, 64, 64) -sw_batch_size = 4 net = UNet( dimensions=3, in_channels=1, @@ -71,25 +75,42 @@ ) net.to(device) +# define sliding window size and batch size for windows inference +roi_size = (64, 64, 64) +sw_batch_size = 4 + def 
_sliding_window_processor(engine, batch): net.eval() with torch.no_grad(): - seg_probs = sliding_window_inference(batch['img'], roi_size, sw_batch_size, lambda x: net(x)[0], device) - return predict_segmentation(seg_probs) + seg_probs = sliding_window_inference(batch['img'], roi_size, sw_batch_size, lambda x: net(x), device) + return seg_probs, batch['seg'].to(device) + +evaluator = Engine(_sliding_window_processor) -infer_engine = Engine(_sliding_window_processor) +# add evaluation metric to the evaluator engine +MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice') + +# StatsHandler prints loss at every iteration and print metrics at every epoch, +# we don't need to print loss for evaluator, so just print metrics, user can also customize print functions +val_stats_handler = StatsHandler( + name='evaluator', + output_transform=lambda x: None # no need to print loss value, so disable per iteration output +) +val_stats_handler.attach(evaluator) -# for the arrary data format, assume the 3rd item of batch data is the meta_data -SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg', +# convert the necessary metadata from batch data +SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator', batch_transform=lambda batch: {'filename_or_obj': batch['img.filename_or_obj'], 'original_affine': batch['img.original_affine'], 'affine': batch['img.affine'], - }).attach(infer_engine) -# the model was trained by "unet_segmentation_3d_dict" exmple -CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(infer_engine) + }, + output_transform=lambda output: predict_segmentation(output[0])).attach(evaluator) +# the model was trained by "unet_training_dict" exmple +CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(evaluator) +# sliding window inferene need to input 1 image in every iteration val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()) -state = infer_engine.run(val_loader) +state = evaluator.run(val_loader) diff --git a/examples/unet_segmentation_3d_array.py b/examples/segmentation_3d/unet_training_array.py similarity index 66% rename from examples/unet_segmentation_3d_array.py rename to examples/segmentation_3d/unet_training_array.py index d43ead1d2b..abf8ac8b1a 100644 --- a/examples/unet_segmentation_3d_array.py +++ b/examples/segmentation_3d/unet_training_array.py @@ -14,7 +14,6 @@ import tempfile from glob import glob import logging - import nibabel as nib import numpy as np import torch @@ -23,20 +22,22 @@ from torch.utils.data import DataLoader # assumes the framework is found here, change as necessary -sys.path.append("..") +sys.path.append("../..") import monai import monai.transforms.compose as transforms from monai.data.nifti_reader import NiftiDataset -from monai.transforms import AddChannel, Rescale, UniformRandomPatch +from monai.transforms import AddChannel, Rescale, UniformRandomPatch, Resize from monai.handlers.stats_handler import StatsHandler from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler, TensorBoardImageHandler from monai.handlers.mean_dice import MeanDice from monai.data.synthetic import create_test_image_3d from monai.handlers.utils import stopping_fn_from_metric +from monai.networks.utils import predict_segmentation monai.config.print_config() +logging.basicConfig(stream=sys.stdout, level=logging.INFO) # Create 
a temporary directory and 50 random image, mask paris tempdir = tempfile.mkdtemp() @@ -54,25 +55,32 @@ segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) # Define transforms for image and segmentation -imtrans = transforms.Compose([ +train_imtrans = transforms.Compose([ Rescale(), AddChannel(), UniformRandomPatch((96, 96, 96)) ]) -segtrans = transforms.Compose([ +train_segtrans = transforms.Compose([ AddChannel(), UniformRandomPatch((96, 96, 96)) ]) +val_imtrans = transforms.Compose([ + Rescale(), + AddChannel(), + Resize((96, 96, 96)) +]) +val_segtrans = transforms.Compose([ + AddChannel(), + Resize((96, 96, 96)) +]) -# Define nifti dataset, dataloader. -check_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans) +# Define nifti dataset, dataloader +check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans) check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) im, seg = monai.utils.misc.first(check_loader) print(im.shape, seg.shape) -lr = 1e-5 - -# Create UNet, DiceLoss and Adam optimizer. +# Create UNet, DiceLoss and Adam optimizer net = monai.networks.nets.UNet( dimensions=3, in_channels=1, @@ -81,70 +89,59 @@ strides=(2, 2, 2, 2), num_res_units=2, ) - loss = monai.losses.DiceLoss(do_sigmoid=True) +lr = 1e-5 opt = torch.optim.Adam(net.parameters(), lr) +device = torch.device("cuda:0") - -# Since network outputs logits and segmentation, we need a custom function. -def _loss_fn(i, j): - return loss(i[0], j) - - -# Create trainer -device = torch.device("cpu:0") -trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, - output_transform=lambda x, y, y_pred, loss: [y_pred[1], loss.item(), y]) +# ignite trainer expects batch=(img, seg) and returns output=loss at every iteration, +# user can add output_transform to return other values, like: y_pred, y, etc. 
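+# for example, a variant that also keeps the predictions would look like
+# (an illustrative sketch, not part of this change):
+#     trainer = create_supervised_trainer(
+#         net, opt, loss, device, False,
+#         output_transform=lambda x, y, y_pred, loss: (y_pred, loss.item()))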
+trainer = create_supervised_trainer(net, opt, loss, device, False) # adding checkpoint handler to save models (network params and optimizer stats) during training checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -logging.basicConfig(stream=sys.stdout, level=logging.INFO) -# print training loss to commandline -train_stats_handler = StatsHandler(output_transform=lambda x: x[1]) +# StatsHandler prints loss at every iteration and print metrics at every epoch, +# we don't set metrics for trainer here, so just print loss, user can also customize print functions +# and can use output_transform to convert engine.state.output if it's not loss value +train_stats_handler = StatsHandler(name='trainer') train_stats_handler.attach(trainer) -# record training loss to TensorBoard at every iteration -train_tensorboard_stats_handler = TensorBoardStatsHandler( - output_transform=lambda x: {'training_dice_loss': x[1]}, # plot under tag name taining_dice_loss - global_epoch_transform=lambda x: trainer.state.epoch) +# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler +train_tensorboard_stats_handler = TensorBoardStatsHandler() train_tensorboard_stats_handler.attach(trainer) - -@trainer.on(Events.EPOCH_COMPLETED) -def log_training_loss(engine): - engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1]) - - # Set parameters for validation validation_every_n_epochs = 1 -metric_name = 'Mean_Dice' +metric_name = 'Mean_Dice' # add evaluation metric to the evaluator engine -val_metrics = {metric_name: MeanDice( - add_sigmoid=True, to_onehot_y=False, output_transform=lambda output: (output[0][0], output[1])) -} +val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)} + +# ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration, +# user can add output_transform to return other values evaluator = create_supervised_evaluator(net, val_metrics, device, True) # Add stats event handler to print validation stats via evaluator val_stats_handler = StatsHandler( - output_transform=lambda x: None, # disable per iteration output - global_epoch_transform=lambda x: trainer.state.epoch) + name='evaluator', + output_transform=lambda x: None, # no need to print loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer val_stats_handler.attach(evaluator) # add handler to record metrics to TensorBoard at every epoch val_tensorboard_stats_handler = TensorBoardStatsHandler( - output_transform=lambda x: None, # no iteration plot - global_epoch_transform=lambda x: trainer.state.epoch) # use epoch number from trainer + output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer val_tensorboard_stats_handler.attach(evaluator) -# add handler to draw several images and the corresponding labels and model outputs -# here we draw the first 3 images(draw the first channel) as GIF format along Depth axis +# add handler to draw the first image and the corresponding label and model output in the last batch +# here we draw the 3D output as GIF format along Depth axis val_tensorboard_image_handler = TensorBoardImageHandler( 
batch_transform=lambda batch: (batch[0], batch[1]), - output_transform=lambda output: output[0][1], + output_transform=lambda output: predict_segmentation(output[0]), global_iter_transform=lambda x: trainer.state.epoch ) evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler) @@ -156,7 +153,7 @@ def log_training_loss(engine): evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) # create a validation data loader -val_ds = NiftiDataset(images[-20:], segs[-20:], transform=imtrans, seg_transform=segtrans) +val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans) val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available()) @@ -166,9 +163,7 @@ def run_validation(engine): # create a training data loader -logging.basicConfig(stream=sys.stdout, level=logging.INFO) - -train_ds = NiftiDataset(images[:20], segs[:20], transform=imtrans, seg_transform=segtrans) +train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans) train_loader = DataLoader(train_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available()) train_epochs = 30 diff --git a/examples/unet_segmentation_3d_dict.py b/examples/segmentation_3d/unet_training_dict.py similarity index 58% rename from examples/unet_segmentation_3d_dict.py rename to examples/segmentation_3d/unet_training_dict.py index b8622bb7eb..223faf8707 100644 --- a/examples/unet_segmentation_3d_dict.py +++ b/examples/segmentation_3d/unet_training_dict.py @@ -14,30 +14,30 @@ import tempfile from glob import glob import logging - import nibabel as nib import numpy as np import torch -from torch.utils.tensorboard import SummaryWriter from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, _prepare_batch from ignite.handlers import ModelCheckpoint, EarlyStopping from torch.utils.data import DataLoader # assumes the framework is found here, change as necessary -sys.path.append("..") +sys.path.append("../..") import monai import monai.transforms.compose as transforms from monai.transforms.composables import \ - LoadNiftid, AsChannelFirstd, RandCropByPosNegLabeld, RandRotate90d + LoadNiftid, AsChannelFirstd, Rescaled, RandCropByPosNegLabeld, RandRotate90d from monai.handlers.stats_handler import StatsHandler +from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler, TensorBoardImageHandler from monai.handlers.mean_dice import MeanDice -from monai.visualize import img2tensorboard from monai.data.synthetic import create_test_image_3d from monai.handlers.utils import stopping_fn_from_metric from monai.data.utils import list_data_collate +from monai.networks.utils import predict_segmentation monai.config.print_config() +logging.basicConfig(stream=sys.stdout, level=logging.INFO) # Create a temporary directory and 50 random image, mask paris tempdir = tempfile.mkdtemp() @@ -60,24 +60,25 @@ train_transforms = transforms.Compose([ LoadNiftid(keys=['img', 'seg']), AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1), + Rescaled(keys=['img', 'seg']), RandCropByPosNegLabeld(keys=['img', 'seg'], label_key='seg', size=[96, 96, 96], pos=1, neg=1, num_samples=4), RandRotate90d(keys=['img', 'seg'], prob=0.8, axes=[1, 3]) ]) val_transforms = transforms.Compose([ LoadNiftid(keys=['img', 'seg']), - AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1) + AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1), + Rescaled(keys=['img', 'seg']) ]) -# 
Define dataset, dataloader. +# Define dataset, dataloader check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) +# use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()) check_data = monai.utils.misc.first(check_loader) print(check_data['img'].shape, check_data['seg'].shape) -lr = 1e-5 - -# Create UNet, DiceLoss and Adam optimizer. +# Create UNet, DiceLoss and Adam optimizer net = monai.networks.nets.UNet( dimensions=3, in_channels=1, @@ -86,84 +87,69 @@ strides=(2, 2, 2, 2), num_res_units=2, ) - loss = monai.losses.DiceLoss(do_sigmoid=True) +lr = 1e-5 opt = torch.optim.Adam(net.parameters(), lr) +device = torch.device("cuda:0") -# Since network outputs logits and segmentation, we need a custom function. -def _loss_fn(i, j): - return loss(i[0], j) - - -# Create trainer +# ignite trainer expects batch=(img, seg) and returns output=loss at every iteration, +# user can add output_transform to return other values, like: y_pred, y, etc. def prepare_batch(batch, device=None, non_blocking=False): return _prepare_batch((batch['img'], batch['seg']), device, non_blocking) -device = torch.device("cuda:0") -trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, - prepare_batch=prepare_batch, - output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y]) +trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch) # adding checkpoint handler to save models (network params and optimizer stats) during training checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -train_stats_handler = StatsHandler(output_transform=lambda x: x[1]) -train_stats_handler.attach(trainer) +# StatsHandler prints loss at every iteration and print metrics at every epoch, +# we don't set metrics for trainer here, so just print loss, user can also customize print functions +# and can use output_transform to convert engine.state.output if it's not loss value +train_stats_handler = StatsHandler(name='trainer') +train_stats_handler.attach(trainer) -@trainer.on(Events.EPOCH_COMPLETED) -def log_training_loss(engine): - # log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform - writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch) - - # tensor of ones to use where for converting labels to zero and ones - ones = torch.ones(engine.state.batch['seg'][0].shape, dtype=torch.int32) - first_output_tensor = engine.state.output[0][1][0].detach().cpu() - # log model output to tensorboard, as three dimensional tensor with no channels dimension - img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64, - 255, engine.state.epoch) - # get label tensor and convert to single class - first_label_tensor = torch.where(engine.state.batch['seg'][0] > 0, ones, engine.state.batch['seg'][0]) - # log label tensor to tensorboard, there is a channel dimension when getting label from batch - img2tensorboard.add_animated_gif(writer, "first_label_final_batch", first_label_tensor, 64, - 255, engine.state.epoch) - second_output_tensor = engine.state.output[0][1][1].detach().cpu() - img2tensorboard.add_animated_gif_no_channels(writer, 
"second_output_final_batch", second_output_tensor, 64, - 255, engine.state.epoch) - second_label_tensor = torch.where(engine.state.batch['seg'][1] > 0, ones, engine.state.batch['seg'][1]) - img2tensorboard.add_animated_gif(writer, "second_label_final_batch", second_label_tensor, 64, - 255, engine.state.epoch) - third_output_tensor = engine.state.output[0][1][2].detach().cpu() - img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64, - 255, engine.state.epoch) - third_label_tensor = torch.where(engine.state.batch['seg'][2] > 0, ones, engine.state.batch['seg'][2]) - img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64, - 255, engine.state.epoch) - engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1]) - - -writer = SummaryWriter() +# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler +train_tensorboard_stats_handler = TensorBoardStatsHandler() +train_tensorboard_stats_handler.attach(trainer) # Set parameters for validation validation_every_n_epochs = 1 -metric_name = 'Mean_Dice' +metric_name = 'Mean_Dice' # add evaluation metric to the evaluator engine val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)} -evaluator = create_supervised_evaluator(net, val_metrics, device, True, - prepare_batch=prepare_batch, - output_transform=lambda x, y, y_pred: (y_pred[0], y)) + +# ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration, +# user can add output_transform to return other values +evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) # Add stats event handler to print validation stats via evaluator -logging.basicConfig(stream=sys.stdout, level=logging.INFO) -val_stats_handler = StatsHandler(output_transform=lambda x: None) +val_stats_handler = StatsHandler( + name='evaluator', + output_transform=lambda x: None, # no need to print loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer val_stats_handler.attach(evaluator) -# Add early stopping handler to evaluator. 
+# add handler to record metrics to TensorBoard at every epoch +val_tensorboard_stats_handler = TensorBoardStatsHandler( + output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer +val_tensorboard_stats_handler.attach(evaluator) +# add handler to draw the first image and the corresponding label and model output in the last batch +# here we draw the 3D output as GIF format along Depth axis +val_tensorboard_image_handler = TensorBoardImageHandler( + batch_transform=lambda batch: (batch['img'], batch['seg']), + output_transform=lambda output: predict_segmentation(output[0]), + global_iter_transform=lambda x: trainer.state.epoch +) +evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler) + +# Add early stopping handler to evaluator early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer) @@ -180,16 +166,9 @@ def run_validation(engine): evaluator.run(val_loader) -@evaluator.on(Events.EPOCH_COMPLETED) -def log_metrics_to_tensorboard(engine): - for _, value in engine.state.metrics.items(): - writer.add_scalar('Metrics/' + metric_name, value, trainer.state.epoch) - - # create a training data loader -logging.basicConfig(stream=sys.stdout, level=logging.INFO) - train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) +# use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training train_loader = DataLoader(train_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()) diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 48f0ac59bb..d0c9f2126c 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -14,6 +14,7 @@ import numpy as np import torch from ignite.engine import Events +import logging class ClassificationSaver: @@ -40,13 +41,7 @@ def __init__(self, output_dir='./', overwrite=True, self.output_dir = output_dir self._prediction_dict = {} self._preds_filepath = os.path.join(output_dir, 'predictions.csv') - if not overwrite: - if os.path.exists(self._preds_filepath): - with open(self._path, 'r') as f: - reader = csv.reader(f) - for row in reader: - self._prediction_dict[row[0]] = np.array(row[1:]).astype(np.float32) - + self.overwrite = overwrite self.batch_transform = batch_transform self.output_transform = output_transform @@ -62,6 +57,13 @@ def finalize(self): Writes the prediction dict to a csv """ + if not self.overwrite: + if os.path.exists(self._preds_filepath): + with open(self._preds_filepath, 'r') as f: + reader = csv.reader(f) + for row in reader: + self._prediction_dict[row[0]] = np.array(row[1:]).astype(np.float32) + if not os.path.exists(self.output_dir): os.makedirs(self.output_dir) with open(self._preds_filepath, 'w') as f: @@ -88,4 +90,4 @@ def __call__(self, engine): output = engine_output[batch_id] if isinstance(output, torch.Tensor): output = output.detach().cpu().numpy() - self._prediction_dict[filename] = output + self._prediction_dict[filename] = output.astype(np.float32) diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index 1f3fe2615d..3a136ee3ec 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -13,7 +13,7 @@ import numpy as np import torch from ignite.engine 
import Events -+import logging from monai.data.nifti_writer import write_nifti diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index 2d5116bd59..80d524b56e 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -18,6 +18,8 @@ from monai.utils.misc import is_scalar from monai.transforms.utils import rescale_array +DEFAULT_TAG = 'Loss' + class TensorBoardStatsHandler(object): """TensorBoardStatsHandler defines a set of Ignite Event-handlers for all the TensorBoard logics. @@ -37,7 +39,8 @@ def __init__(self, epoch_event_writer=None, iteration_event_writer=None, output_transform=lambda x: {'Loss': x}, - global_epoch_transform=lambda x: x): + global_epoch_transform=lambda x: x, + tag_name=DEFAULT_TAG): """ Args: summary_writer (SummaryWriter): user can specify TensorBoard SummaryWriter, @@ -52,12 +55,14 @@ def __init__(self, global_epoch_transform (Callable): a callable that is used to customize global epoch number. For example, in evaluation, the evaluator engine might want to use the trainer engine's epoch number when plotting epoch vs metric curves. + tag_name (string): when iteration output is a scalar, tag_name is used to plot, defaults to ``'Loss'``. """ self._writer = SummaryWriter() if summary_writer is None else summary_writer self.epoch_event_writer = epoch_event_writer self.iteration_event_writer = iteration_event_writer self.output_transform = output_transform self.global_epoch_transform = global_epoch_transform + self.tag_name = tag_name def attach(self, engine: Engine): """Register a set of Ignite Event-Handlers to a specified Ignite engine. @@ -121,23 +126,27 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter): writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler. """ - loss_dict = self.output_transform(engine.state.output) - if loss_dict is None: + loss = self.output_transform(engine.state.output) + if loss is None: return # do nothing if output is empty - if not isinstance(loss_dict, dict): - raise ValueError('TensorBoardStatsHandler requires' - ' output_transform(engine.state.output) returning a dictionary' - ' of key and scalar pairs to plot' - ' got {}.'.format(type(loss_dict))) - for name, value in loss_dict.items(): - if not is_scalar(value): - warnings.warn('ignoring non-scalar output in tensorboard curve plotting,' + if isinstance(loss, dict): + for name in sorted(loss): + value = loss[name] + if not is_scalar(value): + warnings.warn('ignoring non-scalar output in TensorBoardStatsHandler,' + ' make sure `output_transform(engine.state.output)` returns' + ' a scalar or dictionary of key and scalar pairs to avoid this warning.' + ' {}:{}'.format(name, type(value))) + continue # do not plot multi-dimensional output + writer.add_scalar(name, value.item() if torch.is_tensor(value) else value, engine.state.iteration) + else: + if is_scalar(loss): # do not plot multi-dimensional output + writer.add_scalar(self.tag_name, loss.item() if torch.is_tensor(loss) else loss, engine.state.iteration) + else: + warnings.warn('ignoring non-scalar output in TensorBoardStatsHandler,' ' make sure `output_transform(engine.state.output)` returns' - ' a dictionary of key and scalar pairs to avoid this warning.' - ' Got {}:{}'.format(name, type(value))) - continue - plot_value = value.item() if torch.is_tensor(value) else value - writer.add_scalar(name, plot_value, engine.state.iteration) + ' a scalar or a dictionary of key and scalar pairs to avoid this warning.'
+ ' {}'.format(type(loss))) writer.flush() @@ -230,9 +239,9 @@ def _add_2_or_3_d(self, data, step, tag='output'): return if d.ndim == 3: - if d.shape[0] == 3 and self.max_channels == 3: # rgb? + if d.shape[0] == 3 and self.max_channels == 3: # RGB dataformats = 'CHW' - self._writer.add_image('{}_{}'.format(tag, dataformats), d, step, dataformats='CHW') + self._writer.add_image('{}_{}'.format(tag, dataformats), d, step, dataformats=dataformats) return for j, d2 in enumerate(d[:self.max_channels]): d2 = rescale_array(d2, 0, 1) diff --git a/tests/test_handler_classification_saver.py b/tests/test_handler_classification_saver.py new file mode 100644 index 0000000000..78a773e092 --- /dev/null +++ b/tests/test_handler_classification_saver.py @@ -0,0 +1,56 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import csv +import shutil +import unittest +import numpy as np +import torch +from ignite.engine import Engine + +from monai.handlers.classification_saver import ClassificationSaver + + +class TestHandlerClassificationSaver(unittest.TestCase): + + def test_saved_content(self): + default_dir = os.path.join('.', 'tempdir') + shutil.rmtree(default_dir, ignore_errors=True) + + # set up engine + def _train_func(engine, batch): + return torch.zeros(8) + + engine = Engine(_train_func) + + # set up testing handler + saver = ClassificationSaver(output_dir=default_dir) + saver.attach(engine) + + data = [{'filename_or_obj': ['testfile' + str(i) for i in range(8)]}] + engine.run(data, epoch_length=2, max_epochs=1) + saver.finalize() + filepath = os.path.join(default_dir, 'predictions.csv') + self.assertTrue(os.path.exists(filepath)) + with open(filepath, 'r') as f: + reader = csv.reader(f) + i = 0 + for row in reader: + self.assertEqual(row[0], 'testfile' + str(i)) + self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0) + i += 1 + self.assertEqual(i, 8) + shutil.rmtree(default_dir) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_resize.py b/tests/test_resize.py index 7a3d327b28..8040c6ebe4 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -19,7 +19,7 @@ from tests.utils import NumpyImageTestCase2D -class ResizeTest(NumpyImageTestCase2D): +class TestResized(NumpyImageTestCase2D): @parameterized.expand([ ("invalid_order", "order", AssertionError) diff --git a/tests/test_resized.py b/tests/test_resized.py new file mode 100644 index 0000000000..f1ec07b74b --- /dev/null +++ b/tests/test_resized.py @@ -0,0 +1,56 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import skimage +from parameterized import parameterized + +from monai.transforms import Resized +from tests.utils import NumpyImageTestCase2D + + +class TestResized(NumpyImageTestCase2D): + + @parameterized.expand([ + ("invalid_order", "order", AssertionError) + ]) + def test_invalid_inputs(self, _, order, raises): + with self.assertRaises(raises): + resize = Resized(keys='img', output_shape=(128, 128, 3), order=order) + resize({'img': self.imt}) + + @parameterized.expand([ + ((64, 64), 1, 'reflect', 0, True, True, True, None), + ((32, 32), 2, 'constant', 3, False, False, False, None), + ((256, 256), 3, 'constant', 3, False, False, False, None), + ]) + def test_correct_results(self, output_shape, order, mode, + cval, clip, preserve_range, + anti_aliasing, anti_aliasing_sigma): + resize = Resized('img', output_shape, order, mode, cval, clip, + preserve_range, anti_aliasing, + anti_aliasing_sigma) + expected = list() + for channel in self.imt: + expected.append(skimage.transform.resize(channel, output_shape, + order=order, mode=mode, + cval=cval, clip=clip, + preserve_range=preserve_range, + anti_aliasing=anti_aliasing, + anti_aliasing_sigma=anti_aliasing_sigma)) + expected = np.stack(expected).astype(np.float32) + self.assertTrue(np.allclose(resize({'img': self.imt})['img'], expected)) + + +if __name__ == '__main__': + unittest.main() From 56ce7cd23f4771750003260278ba9f8df7607bcd Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 13 Mar 2020 00:41:44 +0800 Subject: [PATCH 5/7] [DLMED] fix transforms spatial axis issue and update unit tests --- .../densenet_evaluation_dict.py | 2 +- .../densenet_training_dict.py | 6 +- .../segmentation_3d/unet_evaluation_array.py | 4 +- .../segmentation_3d/unet_evaluation_dict.py | 2 +- .../segmentation_3d/unet_training_dict.py | 2 +- monai/transforms/composables.py | 80 +++++++----- monai/transforms/transforms.py | 123 +++++++++++------- tests/integration_sliding_window.py | 2 +- tests/test_flip.py | 36 ++--- tests/test_flipd.py | 48 +++++++ ...t_random_affine.py => test_rand_affine.py} | 0 ...ffine_grid.py => test_rand_affine_grid.py} | 0 ...random_affined.py => test_rand_affined.py} | 0 ...eform_grid.py => test_rand_deform_grid.py} | 0 ..._elastic_2d.py => test_rand_elastic_2d.py} | 0 ..._elastic_3d.py => test_rand_elastic_3d.py} | 0 ...lasticd_2d.py => test_rand_elasticd_2d.py} | 0 ...lasticd_3d.py => test_rand_elasticd_3d.py} | 0 tests/test_rand_flip.py | 31 ++--- tests/test_rand_flipd.py | 38 ++++++ tests/test_rand_rotate.py | 21 +-- tests/test_rand_rotate90.py | 38 ++++-- tests/test_rand_rotate90d.py | 50 ++++--- tests/test_rand_rotated.py | 46 +++++++ tests/test_rand_zoom.py | 45 +++---- tests/test_rand_zoomd.py | 83 ++++++++++++ tests/test_resize.py | 16 +-- tests/test_resized.py | 14 +- tests/test_rotate.py | 37 ++---- tests/test_rotate90.py | 38 ++++-- tests/test_rotate90d.py | 48 ++++--- tests/test_rotated.py | 43 ++++++ tests/test_uniform_rand_patch.py | 10 +- tests/test_uniform_rand_patchd.py | 14 +- tests/test_zoom.py | 32 ++--- tests/test_zoomd.py | 82 ++++++++++++ 36 files changed, 693 insertions(+), 298 deletions(-) create mode 100644 tests/test_flipd.py rename tests/{test_random_affine.py => test_rand_affine.py} (100%) rename tests/{test_random_affine_grid.py => test_rand_affine_grid.py} (100%) rename tests/{test_random_affined.py => test_rand_affined.py} (100%) rename 
tests/{test_random_deform_grid.py => test_rand_deform_grid.py} (100%) rename tests/{test_random_elastic_2d.py => test_rand_elastic_2d.py} (100%) rename tests/{test_random_elastic_3d.py => test_rand_elastic_3d.py} (100%) rename tests/{test_random_elasticd_2d.py => test_rand_elasticd_2d.py} (100%) rename tests/{test_random_elasticd_3d.py => test_rand_elasticd_3d.py} (100%) create mode 100644 tests/test_rand_flipd.py create mode 100644 tests/test_rand_rotated.py create mode 100644 tests/test_rand_zoomd.py create mode 100644 tests/test_rotated.py create mode 100644 tests/test_zoomd.py diff --git a/examples/classification_3d/densenet_evaluation_dict.py b/examples/classification_3d/densenet_evaluation_dict.py index 7f4852afa8..99dd98f546 100644 --- a/examples/classification_3d/densenet_evaluation_dict.py +++ b/examples/classification_3d/densenet_evaluation_dict.py @@ -52,7 +52,7 @@ LoadNiftid(keys=['img']), AddChanneld(keys=['img']), Rescaled(keys=['img']), - Resized(keys=['img'], output_shape=(96, 96, 96)) + Resized(keys=['img'], output_spatial_shape=(96, 96, 96)) ]) # Create DenseNet121 diff --git a/examples/classification_3d/densenet_training_dict.py b/examples/classification_3d/densenet_training_dict.py index 8f5d994f08..9633459522 100644 --- a/examples/classification_3d/densenet_training_dict.py +++ b/examples/classification_3d/densenet_training_dict.py @@ -65,14 +65,14 @@ LoadNiftid(keys=['img']), AddChanneld(keys=['img']), Rescaled(keys=['img']), - Resized(keys=['img'], output_shape=(96, 96, 96)), - RandRotate90d(keys=['img'], prob=0.8, axes=[1, 3]) + Resized(keys=['img'], output_spatial_shape=(96, 96, 96)), + RandRotate90d(keys=['img'], prob=0.8, spatial_axes=[0, 2]) ]) val_transforms = transforms.Compose([ LoadNiftid(keys=['img']), AddChanneld(keys=['img']), Rescaled(keys=['img']), - Resized(keys=['img'], output_shape=(96, 96, 96)) + Resized(keys=['img'], output_spatial_shape=(96, 96, 96)) ]) # Define dataset, dataloader diff --git a/examples/segmentation_3d/unet_evaluation_array.py b/examples/segmentation_3d/unet_evaluation_array.py index 988af02bc9..005b9d0d85 100644 --- a/examples/segmentation_3d/unet_evaluation_array.py +++ b/examples/segmentation_3d/unet_evaluation_array.py @@ -79,7 +79,7 @@ def _sliding_window_processor(engine, batch): net.eval() img, seg, meta_data = batch with torch.no_grad(): - seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x), device) + seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, net, device) return seg_probs, seg.to(device) @@ -101,7 +101,7 @@ def _sliding_window_processor(engine, batch): batch_transform=lambda x: x[2], output_transform=lambda output: predict_segmentation(output[0]) ).attach(evaluator) # the model was trained by the "unet_training_array" example -CheckpointLoader(load_path='./runs/net_checkpoint_320.pth', load_dict={'net': net}).attach(evaluator) +CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(evaluator) # sliding window inference needs to input 1 image in every iteration loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available()) diff --git a/examples/segmentation_3d/unet_evaluation_dict.py b/examples/segmentation_3d/unet_evaluation_dict.py index 23e79e4484..3f43299f76 100644 --- a/examples/segmentation_3d/unet_evaluation_dict.py +++ b/examples/segmentation_3d/unet_evaluation_dict.py @@ -83,7 +83,7 @@ def _sliding_window_processor(engine, batch): net.eval() with torch.no_grad(): - seg_probs =
sliding_window_inference(batch['img'], roi_size, sw_batch_size, lambda x: net(x), device) + seg_probs = sliding_window_inference(batch['img'], roi_size, sw_batch_size, net, device) return seg_probs, batch['seg'].to(device) diff --git a/examples/segmentation_3d/unet_training_dict.py b/examples/segmentation_3d/unet_training_dict.py index 223faf8707..e3b8d4fe88 100644 --- a/examples/segmentation_3d/unet_training_dict.py +++ b/examples/segmentation_3d/unet_training_dict.py @@ -62,7 +62,7 @@ AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1), Rescaled(keys=['img', 'seg']), RandCropByPosNegLabeld(keys=['img', 'seg'], label_key='seg', size=[96, 96, 96], pos=1, neg=1, num_samples=4), - RandRotate90d(keys=['img', 'seg'], prob=0.8, axes=[1, 3]) + RandRotate90d(keys=['img', 'seg'], prob=0.8, spatial_axes=[0, 2]) ]) val_transforms = transforms.Compose([ LoadNiftid(keys=['img', 'seg']), diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py index 8b4ee7d2b1..f86afd546e 100644 --- a/monai/transforms/composables.py +++ b/monai/transforms/composables.py @@ -230,17 +230,18 @@ class Rotate90d(MapTransform): dictionary-based wrapper of Rotate90. """ - def __init__(self, keys, k=1, axes=(1, 2)): + def __init__(self, keys, k=1, spatial_axes=(0, 1)): """ Args: k (int): number of times to rotate by 90 degrees. - axes (2 ints): defines the plane to rotate with 2 axes. + spatial_axes (2 ints): defines the plane to rotate with 2 spatial axes. + Default: (0, 1), this is the first two axis in spatial dimensions. """ MapTransform.__init__(self, keys) self.k = k - self.plane_axes = axes + self.spatial_axes = spatial_axes - self.rotator = Rotate90(self.k, self.plane_axes) + self.rotator = Rotate90(self.k, self.spatial_axes) def __call__(self, data): d = dict(data) @@ -275,7 +276,7 @@ class Resized(MapTransform): Args: keys (hashable items): keys of the corresponding items to be transformed. See also: monai.transform.composables.MapTransform - output_shape (tuple or list): expected shape after resize operation. + output_spatial_shape (tuple or list): expected shape of spatial dimensions after resize operation. order (int): Order of spline interpolation. Default=1. mode (str): Points outside boundaries are filled according to given mode. Options are 'constant', 'edge', 'symmetric', 'reflect', 'wrap'. @@ -289,10 +290,10 @@ class Resized(MapTransform): anti_aliasing_sigma (float, tuple of floats): Standard deviation for gaussian filtering. """ - def __init__(self, keys, output_shape, order=1, mode='reflect', cval=0, + def __init__(self, keys, output_spatial_shape, order=1, mode='reflect', cval=0, clip=True, preserve_range=True, anti_aliasing=True, anti_aliasing_sigma=None): MapTransform.__init__(self, keys) - self.resizer = Resize(output_shape, order, mode, cval, clip, preserve_range, + self.resizer = Resize(output_spatial_shape, order, mode, cval, clip, preserve_range, anti_aliasing, anti_aliasing_sigma) def __call__(self, data): @@ -307,12 +308,15 @@ def __call__(self, data): class UniformRandomPatchd(Randomizable, MapTransform): """ Selects a patch of the given size chosen at a uniformly random position in the image. + + Args: + patch_spatial_size (tuple or list): Expected patch size of spatial dimensions. 
""" - def __init__(self, keys, patch_size): + def __init__(self, keys, patch_spatial_size): MapTransform.__init__(self, keys) - self.patch_size = (None,) + tuple(patch_size) + self.patch_spatial_size = (None,) + tuple(patch_spatial_size) self._slices = None @@ -323,8 +327,8 @@ def __call__(self, data): d = dict(data) image_shape = d[self.keys[0]].shape # image shape from the first data key - patch_size = get_valid_patch_size(image_shape, self.patch_size) - self.randomize(image_shape, patch_size) + patch_spatial_size = get_valid_patch_size(image_shape, self.patch_spatial_size) + self.randomize(image_shape, patch_spatial_size) for key in self.keys: d[key] = d[key][self._slices] return d @@ -335,10 +339,10 @@ def __call__(self, data): class RandRotate90d(Randomizable, MapTransform): """ With probability `prob`, input arrays are rotated by 90 degrees - in the plane specified by `axes`. + in the plane specified by `spatial_axes`. """ - def __init__(self, keys, prob=0.1, max_k=3, axes=(1, 2)): + def __init__(self, keys, prob=0.1, max_k=3, spatial_axes=(0, 1)): """ Args: keys (hashable items): keys of the corresponding items to be transformed. @@ -347,14 +351,14 @@ def __init__(self, keys, prob=0.1, max_k=3, axes=(1, 2)): (Default 0.1, with 10% probability it returns a rotated array.) max_k (int): number of rotations will be sampled from `np.random.randint(max_k) + 1`. (Default 3) - axes (2 ints): defines the plane to rotate with 2 axes. - (Default to (1, 2)) + spatial_axes (2 ints): defines the plane to rotate with 2 spatial axes. + Default: (0, 1), this is the first two axis in spatial dimensions. """ MapTransform.__init__(self, keys) self.prob = min(max(prob, 0.0), 1.0) self.max_k = max_k - self.axes = axes + self.spatial_axes = spatial_axes self._do_transform = False self._rand_k = 0 @@ -368,7 +372,7 @@ def __call__(self, data): if not self._do_transform: return data - rotator = Rotate90(self._rand_k, self.axes) + rotator = Rotate90(self._rand_k, self.spatial_axes) d = dict(data) for key in self.keys: d[key] = rotator(d[key]) @@ -652,15 +656,17 @@ def __call__(self, data): @alias('FlipD', 'FlipDict') class Flipd(MapTransform): """Dictionary-based wrapper of Flip. + See numpy.flip for additional details. + https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html Args: keys (dict): Keys to pick data for transformation. - axis (None, int or tuple of ints): Axes along which to flip over. Default is None. + spatial_axis (None, int or tuple of ints): Spatial axes along which to flip over. Default is None. """ - def __init__(self, keys, axis=None): + def __init__(self, keys, spatial_axis=None): MapTransform.__init__(self, keys) - self.flipper = Flip(axis=axis) + self.flipper = Flip(spatial_axis=spatial_axis) def __call__(self, data): d = dict(data) @@ -673,19 +679,21 @@ def __call__(self, data): @alias('RandFlipD', 'RandFlipDict') class RandFlipd(Randomizable, MapTransform): """Dict-based wrapper of RandFlip. + See numpy.flip for additional details. + https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html Args: prob (float): Probability of flipping. - axis (None, int or tuple of ints): Axes along which to flip over. Default is None. + spatial_axis (None, int or tuple of ints): Spatial axes along which to flip over. Default is None. 
""" - def __init__(self, keys, prob=0.1, axis=None): + def __init__(self, keys, prob=0.1, spatial_axis=None): MapTransform.__init__(self, keys) - self.axis = axis + self.spatial_axis = spatial_axis self.prob = prob self._do_transform = False - self.flipper = Flip(axis=axis) + self.flipper = Flip(spatial_axis=spatial_axis) def randomize(self): self._do_transform = self.R.random_sample() < self.prob @@ -708,8 +716,8 @@ class Rotated(MapTransform): Args: keys (dict): Keys to pick data for transformation. angle (float): Rotation angle in degrees. - axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two - axis in spatial dimensions according to MONAI channel first shape assumption. + spatial_axes (tuple of 2 ints): Spatial axes of rotation. Default: (0, 1). + This is the first two axis in spatial dimensions. reshape (bool): If true, output shape is made same as input. Default: True. order (int): Order of spline interpolation. Range 0-5. Default: 1. This is different from scipy where default interpolation is 3. @@ -719,10 +727,10 @@ class Rotated(MapTransform): prefiter (bool): Apply spline_filter before interpolation. Default: True. """ - def __init__(self, keys, angle, axes=(1, 2), reshape=True, order=1, + def __init__(self, keys, angle, spatial_axes=(0, 1), reshape=True, order=1, mode='constant', cval=0, prefilter=True): MapTransform.__init__(self, keys) - self.rotator = Rotate(angle=angle, axes=axes, reshape=reshape, + self.rotator = Rotate(angle=angle, spatial_axes=spatial_axes, reshape=reshape, order=order, mode=mode, cval=cval, prefilter=prefilter) def __call__(self, data): @@ -741,8 +749,8 @@ class RandRotated(Randomizable, MapTransform): prob (float): Probability of rotation. degrees (tuple of float or float): Range of rotation in degrees. If single number, angle is picked from (-degrees, degrees). - axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two - axis in spatial dimensions according to MONAI channel first shape assumption. + spatial_axes (tuple of 2 ints): Spatial axes of rotation. Default: (0, 1). + This is the first two axis in spatial dimensions. reshape (bool): If true, output shape is made same as input. Default: True. order (int): Order of spline interpolation. Range 0-5. Default: 1. This is different from scipy where default interpolation is 3. @@ -751,7 +759,7 @@ class RandRotated(Randomizable, MapTransform): cval (scalar): Value to fill outside boundary. Default: 0. prefiter (bool): Apply spline_filter before interpolation. Default: True. 
""" - def __init__(self, keys, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1, + def __init__(self, keys, degrees, prob=0.1, spatial_axes=(0, 1), reshape=True, order=1, mode='constant', cval=0, prefilter=True): MapTransform.__init__(self, keys) self.prob = prob @@ -761,7 +769,7 @@ def __init__(self, keys, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1, self.mode = mode self.cval = cval self.prefilter = prefilter - self.axes = axes + self.spatial_axes = spatial_axes if not hasattr(self.degrees, '__iter__'): self.degrees = (-self.degrees, self.degrees) @@ -779,10 +787,10 @@ def __call__(self, data): d = dict(data) if not self._do_transform: return d - rotator = Rotate(self.angle, self.axes, self.reshape, self.order, + rotator = Rotate(self.angle, self.spatial_axes, self.reshape, self.order, self.mode, self.cval, self.prefilter) for key in self.keys: - d[key] = self.flipper(d[key]) + d[key] = rotator(d[key]) return d @@ -825,7 +833,11 @@ class RandZoomd(Randomizable, MapTransform): keys (dict): Keys to pick data for transformation. prob (float): Probability of zooming. min_zoom (float or sequence): Min zoom factor. Can be float or sequence same size as image. + If a float, min_zoom is the same for each spatial axis. + If a sequence, min_zoom should contain one value for each spatial axis. max_zoom (float or sequence): Max zoom factor. Can be float or sequence same size as image. + If a float, max_zoom is the same for each spatial axis. + If a sequence, max_zoom should contain one value for each spatial axis. order (int): order of interpolation. Default=3. mode ('reflect', 'constant', 'nearest', 'mirror', 'wrap'): Determines how input is extended beyond boundaries. Default: 'constant'. diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index be6d38a4fa..8d9d18ccda 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -253,8 +253,7 @@ class GaussianNoise(Randomizable): Args: mean (float or array of floats): Mean or “centre” of the distribution. - scale (float): Standard deviation (spread) of distribution. - size (int or tuple of ints): Output shape. Default: None (single value is returned). + std (float): Standard deviation (spread) of distribution. """ def __init__(self, mean=0.0, std=0.1): @@ -267,19 +266,28 @@ def __call__(self, img): @export class Flip: - """Reverses the order of elements along the given axis. Preserves shape. + """Reverses the order of elements along the given spatial axis. Preserves shape. Uses ``np.flip`` in practice. See numpy.flip for additional details. https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html Args: - axis (None, int or tuple of ints): Axes along which to flip over. Default is None. + spatial_axis (None, int or tuple of ints): spatial axes along which to flip over. Default is None. """ - def __init__(self, axis=None): - self.axis = axis + def __init__(self, spatial_axis=None): + self.spatial_axis = spatial_axis def __call__(self, img): - return np.flip(img, self.axis) + """ + Args: + img (ndarray): channel first array, must have shape: (num_channels, H[, W, ..., ]), + """ + flipped = list() + for channel in img: + flipped.append( + np.flip(channel, self.spatial_axis) + ) + return np.stack(flipped) @export @@ -289,7 +297,7 @@ class Resize: For additional details, see https://scikit-image.org/docs/dev/api/skimage.transform.html#skimage.transform.resize. Args: - output_shape (tuple or list): expected shape after resize operation. 
+ output_spatial_shape (tuple or list): expected shape of spatial dimensions after resize operation. order (int): Order of spline interpolation. Default=1. mode (str): Points outside boundaries are filled according to given mode. Options are 'constant', 'edge', 'symmetric', 'reflect', 'wrap'. @@ -303,10 +311,10 @@ class Resize: anti_aliasing_sigma (float, tuple of floats): Standard deviation for gaussian filtering. """ - def __init__(self, output_shape, order=1, mode='reflect', cval=0, + def __init__(self, output_spatial_shape, order=1, mode='reflect', cval=0, clip=True, preserve_range=True, anti_aliasing=True, anti_aliasing_sigma=None): assert isinstance(order, int), "order must be integer." - self.output_shape = output_shape + self.output_spatial_shape = output_spatial_shape self.order = order self.mode = mode self.cval = cval @@ -316,10 +324,14 @@ def __init__(self, output_shape, order=1, mode='reflect', cval=0, self.anti_aliasing_sigma = anti_aliasing_sigma def __call__(self, img): + """ + Args: + img (ndarray): channel first array, must have shape: (num_channels, H[, W, ..., ]), + """ resized = list() for channel in img: resized.append( - resize(channel, self.output_shape, order=self.order, + resize(channel, self.output_spatial_shape, order=self.order, mode=self.mode, cval=self.cval, clip=self.clip, preserve_range=self.preserve_range, anti_aliasing=self.anti_aliasing, @@ -336,8 +348,8 @@ class Rotate: Args: angle (float): Rotation angle in degrees. - axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two - axis in spatial dimensions according to MONAI channel first shape assumption. + spatial_axes (tuple of 2 ints): Spatial axes of rotation. Default: (0, 1). + This is the first two axis in spatial dimensions. reshape (bool): If true, output shape is made same as input. Default: True. order (int): Order of spline interpolation. Range 0-5. Default: 1. This is different from scipy where default interpolation is 3. @@ -347,19 +359,27 @@ class Rotate: prefiter (bool): Apply spline_filter before interpolation. Default: True. """ - def __init__(self, angle, axes=(1, 2), reshape=True, order=1, mode='constant', cval=0, prefilter=True): + def __init__(self, angle, spatial_axes=(0, 1), reshape=True, order=1, mode='constant', cval=0, prefilter=True): self.angle = angle self.reshape = reshape self.order = order self.mode = mode self.cval = cval self.prefilter = prefilter - self.axes = axes + self.spatial_axes = spatial_axes def __call__(self, img): - return scipy.ndimage.rotate(img, self.angle, self.axes, - reshape=self.reshape, order=self.order, mode=self.mode, cval=self.cval, - prefilter=self.prefilter).astype(np.float32) + """ + Args: + img (ndarray): channel first array, must have shape: (num_channels, H[, W, ..., ]), + """ + rotated = list() + for channel in img: + rotated.append( + scipy.ndimage.rotate(channel, self.angle, self.spatial_axes, reshape=self.reshape, + order=self.order, mode=self.mode, cval=self.cval, prefilter=self.prefilter) + ) + return np.stack(rotated).astype(np.float32) @export @@ -406,7 +426,7 @@ def __call__(self, img): Args: img (ndarray): channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - zoomed = [] + zoomed = list() if self.use_gpu: import cupy for channel in cupy.array(img): @@ -458,10 +478,13 @@ def __call__(self, img): class UniformRandomPatch(Randomizable): """ Selects a patch of the given size chosen at a uniformly random position in the image. 
+ + Args: + patch_spatial_size (tuple or list): Expected patch size of spatial dimensions. """ - def __init__(self, patch_size): - self.patch_size = (None,) + tuple(patch_size) + def __init__(self, patch_spatial_size): + self.patch_spatial_size = (None,) + tuple(patch_spatial_size) self._slices = None @@ -469,8 +492,8 @@ def randomize(self, image_shape, patch_shape): self._slices = get_random_patch(image_shape, patch_shape, self.R) def __call__(self, img): - patch_size = get_valid_patch_size(img.shape, self.patch_size) - self.randomize(img.shape, patch_size) + patch_spatial_size = get_valid_patch_size(img.shape, self.patch_spatial_size) + self.randomize(img.shape, patch_spatial_size) return img[self._slices] @@ -537,39 +560,49 @@ class Rotate90: Rotate an array by 90 degrees in the plane specified by `axes`. """ - def __init__(self, k=1, axes=(1, 2)): + def __init__(self, k=1, spatial_axes=(0, 1)): """ Args: k (int): number of times to rotate by 90 degrees. - axes (2 ints): defines the plane to rotate with 2 axes. + spatial_axes (2 ints): defines the plane to rotate with 2 spatial axes. + Default: (0, 1), this is the first two axis in spatial dimensions. """ self.k = k - self.plane_axes = axes + self.spatial_axes = spatial_axes def __call__(self, img): - return np.ascontiguousarray(np.rot90(img, self.k, self.plane_axes)) + """ + Args: + img (ndarray): channel first array, must have shape: (num_channels, H[, W, ..., ]), + """ + rotated = list() + for channel in img: + rotated.append( + np.rot90(channel, self.k, self.spatial_axes) + ) + return np.stack(rotated) @export class RandRotate90(Randomizable): """ With probability `prob`, input arrays are rotated by 90 degrees - in the plane specified by `axes`. + in the plane specified by `spatial_axes`. """ - def __init__(self, prob=0.1, max_k=3, axes=(1, 2)): + def __init__(self, prob=0.1, max_k=3, spatial_axes=(0, 1)): """ Args: prob (float): probability of rotating. (Default 0.1, with 10% probability it returns a rotated array) max_k (int): number of rotations will be sampled from `np.random.randint(max_k) + 1`. (Default 3) - axes (2 ints): defines the plane to rotate with 2 axes. - (Default (1, 2)) + spatial_axes (2 ints): defines the plane to rotate with 2 spatial axes. + Default: (0, 1), this is the first two axis in spatial dimensions. """ self.prob = min(max(prob, 0.0), 1.0) self.max_k = max_k - self.axes = axes + self.spatial_axes = spatial_axes self._do_transform = False self._rand_k = 0 @@ -582,7 +615,7 @@ def __call__(self, img): self.randomize() if not self._do_transform: return img - rotator = Rotate90(self._rand_k, self.axes) + rotator = Rotate90(self._rand_k, self.spatial_axes) return rotator(img) @@ -590,7 +623,7 @@ def __call__(self, img): class SpatialCrop: """General purpose cropper to produce sub-volume region of interest (ROI). It can support to crop ND spatial (channel-first) data. - Either a center and size must be provided, or alternatively if center and size + Either a spatial center and size must be provided, or alternatively if center and size are not provided, the start and end coordinates of the ROI must be provided. The sub-volume must sit the within original image. @@ -638,8 +671,8 @@ class RandRotate(Randomizable): prob (float): Probability of rotation. degrees (tuple of float or float): Range of rotation in degrees. If single number, angle is picked from (-degrees, degrees). - axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). 
This is the first two - axis in spatial dimensions according to MONAI channel first shape assumption. + spatial_axes (tuple of 2 ints): Spatial axes of rotation. Default: (0, 1). + These are the first two axes in spatial dimensions. reshape (bool): If true, output shape is made same as input. Default: True. order (int): Order of spline interpolation. Range 0-5. Default: 1. This is different from scipy where default interpolation is 3. @@ -649,7 +682,7 @@ class RandRotate(Randomizable): prefilter (bool): Apply spline_filter before interpolation. Default: True. """ - def __init__(self, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1, + def __init__(self, degrees, prob=0.1, spatial_axes=(0, 1), reshape=True, order=1, mode='constant', cval=0, prefilter=True): self.prob = prob self.degrees = degrees @@ -658,7 +691,7 @@ def __init__(self, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1, self.mode = mode self.cval = cval self.prefilter = prefilter - self.axes = axes + self.spatial_axes = spatial_axes if not hasattr(self.degrees, '__iter__'): self.degrees = (-self.degrees, self.degrees) @@ -675,7 +708,7 @@ def __call__(self, img): self.randomize() if not self._do_transform: return img - rotator = Rotate(self.angle, self.axes, self.reshape, self.order, + rotator = Rotate(self.angle, self.spatial_axes, self.reshape, self.order, self.mode, self.cval, self.prefilter) return rotator(img) @@ -688,15 +721,13 @@ class RandFlip(Randomizable): Args: prob (float): Probability of flipping. - axis (None, int or tuple of ints): Axes along which to flip over. Default is None. + spatial_axis (None, int or tuple of ints): Spatial axes along which to flip over. Default is None. """ - def __init__(self, prob=0.1, axis=None): + def __init__(self, prob=0.1, spatial_axis=None): self.prob = prob - self.flipper = Flip(axis=axis) - + self.flipper = Flip(spatial_axis=spatial_axis) self._do_transform = False - self.flipper = Flip(axis=axis) def randomize(self): self._do_transform = self.R.random_sample() < self.prob @@ -715,7 +746,11 @@ class RandZoom(Randomizable): Args: prob (float): Probability of zooming. min_zoom (float or sequence): Min zoom factor. Can be float or sequence same size as image. + If a float, min_zoom is the same for each spatial axis. + If a sequence, min_zoom should contain one value for each spatial axis. max_zoom (float or sequence): Max zoom factor. Can be float or sequence same size as image. + If a float, max_zoom is the same for each spatial axis. + If a sequence, max_zoom should contain one value for each spatial axis. order (int): order of interpolation. Default=3. mode ('reflect', 'constant', 'nearest', 'mirror', 'wrap'): Determines how input is extended beyond boundaries. Default: 'constant'.
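# --- a minimal sketch, not part of the patch, of the channel-first convention
# behind the renamed `spatial_axis`/`spatial_axes` arguments: they index only the
# spatial dimensions (H, W, ...), so spatial axis 0 is array axis 1 once the
# leading channel dimension is counted. The shapes below are illustrative assumptions.
import numpy as np

img = np.arange(12, dtype=np.float32).reshape(1, 3, 4)  # (num_channels, H, W)

# flipping spatial axis 0 (H) per channel equals flipping array axis 1 directly
flipped = np.stack([np.flip(c, 0) for c in img])
assert np.allclose(flipped, np.flip(img, 1))

# rotating 90 degrees in the (H, W) plane per channel swaps the spatial extents only
rotated = np.stack([np.rot90(c, 1, (0, 1)) for c in img])
assert rotated.shape == (1, 4, 3)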
diff --git a/tests/integration_sliding_window.py b/tests/integration_sliding_window.py index 4677b33b3a..e41f6574d3 100644 --- a/tests/integration_sliding_window.py +++ b/tests/integration_sliding_window.py @@ -52,7 +52,7 @@ def _sliding_window_processor(_engine, batch): net.eval() img, seg, meta_data = batch with torch.no_grad(): - seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x), device) + seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, net, device) return predict_segmentation(seg_probs) infer_engine = Engine(_sliding_window_processor) diff --git a/tests/test_flip.py b/tests/test_flip.py index a261c315e2..050d66d8db 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -14,7 +14,7 @@ import numpy as np from parameterized import parameterized -from monai.transforms import Flip, Flipd +from monai.transforms import Flip from tests.utils import NumpyImageTestCase2D INVALID_CASES = [("wrong_axis", ['s', 1], TypeError), @@ -22,35 +22,25 @@ VALID_CASES = [("no_axis", None), ("one_axis", 1), - ("many_axis", [0, 1, 2])] + ("many_axis", [0, 1])] -class FlipTest(NumpyImageTestCase2D): +class TestFlip(NumpyImageTestCase2D): @parameterized.expand(INVALID_CASES) - def test_invalid_inputs(self, _, axis, raises): + def test_invalid_inputs(self, _, spatial_axis, raises): with self.assertRaises(raises): - flip = Flip(axis) - flip(self.imt) - - @parameterized.expand(INVALID_CASES) - def test_invalid_cases_dict(self, _, axis, raises): - with self.assertRaises(raises): - flip = Flipd(keys='img', axis=axis) - flip({'img': self.imt}) - - @parameterized.expand(VALID_CASES) - def test_correct_results(self, _, axis): - flip = Flip(axis=axis) - expected = np.flip(self.imt, axis) - self.assertTrue(np.allclose(expected, flip(self.imt))) + flip = Flip(spatial_axis) + flip(self.imt[0]) @parameterized.expand(VALID_CASES) - def test_correct_results_dict(self, _, axis): - flip = Flipd(keys='img', axis=axis) - expected = np.flip(self.imt, axis) - res = flip({'img': self.imt}) - assert np.allclose(expected, res['img']) + def test_correct_results(self, _, spatial_axis): + flip = Flip(spatial_axis=spatial_axis) + expected = list() + for channel in self.imt[0]: + expected.append(np.flip(channel, spatial_axis)) + expected = np.stack(expected) + self.assertTrue(np.allclose(expected, flip(self.imt[0]))) if __name__ == '__main__': diff --git a/tests/test_flipd.py b/tests/test_flipd.py new file mode 100644 index 0000000000..e2fcb6b915 --- /dev/null +++ b/tests/test_flipd.py @@ -0,0 +1,48 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
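+
+# note: Flipd consumes channel-first arrays, so the cases below feed self.imt[0]
+# (one image without the batch dimension) and build the expected result channel
+# by channel, mirroring the per-channel np.flip loop inside the Flip transform.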
+ +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import Flipd +from tests.utils import NumpyImageTestCase2D + +INVALID_CASES = [("wrong_axis", ['s', 1], TypeError), + ("not_numbers", 's', TypeError)] + +VALID_CASES = [("no_axis", None), + ("one_axis", 1), + ("many_axis", [0, 1])] + + +class TestFlipd(NumpyImageTestCase2D): + + @parameterized.expand(INVALID_CASES) + def test_invalid_cases(self, _, spatial_axis, raises): + with self.assertRaises(raises): + flip = Flipd(keys='img', spatial_axis=spatial_axis) + flip({'img': self.imt[0]}) + + @parameterized.expand(VALID_CASES) + def test_correct_results(self, _, spatial_axis): + flip = Flipd(keys='img', spatial_axis=spatial_axis) + expected = list() + for channel in self.imt[0]: + expected.append(np.flip(channel, spatial_axis)) + expected = np.stack(expected) + res = flip({'img': self.imt[0]}) + assert np.allclose(expected, res['img']) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_random_affine.py b/tests/test_rand_affine.py similarity index 100% rename from tests/test_random_affine.py rename to tests/test_rand_affine.py diff --git a/tests/test_random_affine_grid.py b/tests/test_rand_affine_grid.py similarity index 100% rename from tests/test_random_affine_grid.py rename to tests/test_rand_affine_grid.py diff --git a/tests/test_random_affined.py b/tests/test_rand_affined.py similarity index 100% rename from tests/test_random_affined.py rename to tests/test_rand_affined.py diff --git a/tests/test_random_deform_grid.py b/tests/test_rand_deform_grid.py similarity index 100% rename from tests/test_random_deform_grid.py rename to tests/test_rand_deform_grid.py diff --git a/tests/test_random_elastic_2d.py b/tests/test_rand_elastic_2d.py similarity index 100% rename from tests/test_random_elastic_2d.py rename to tests/test_rand_elastic_2d.py diff --git a/tests/test_random_elastic_3d.py b/tests/test_rand_elastic_3d.py similarity index 100% rename from tests/test_random_elastic_3d.py rename to tests/test_rand_elastic_3d.py diff --git a/tests/test_random_elasticd_2d.py b/tests/test_rand_elasticd_2d.py similarity index 100% rename from tests/test_random_elasticd_2d.py rename to tests/test_rand_elasticd_2d.py diff --git a/tests/test_random_elasticd_3d.py b/tests/test_rand_elasticd_3d.py similarity index 100% rename from tests/test_random_elasticd_3d.py rename to tests/test_rand_elasticd_3d.py diff --git a/tests/test_rand_flip.py b/tests/test_rand_flip.py index be03ff5a28..1206c85571 100644 --- a/tests/test_rand_flip.py +++ b/tests/test_rand_flip.py @@ -14,7 +14,7 @@ import numpy as np from parameterized import parameterized -from monai.transforms import RandFlip, RandFlipd +from monai.transforms import RandFlip from tests.utils import NumpyImageTestCase2D INVALID_CASES = [("wrong_axis", ['s', 1], TypeError), @@ -22,29 +22,24 @@ VALID_CASES = [("no_axis", None), ("one_axis", 1), - ("many_axis", [0, 1, 2])] + ("many_axis", [0, 1])] -class RandFlipTest(NumpyImageTestCase2D): +class TestRandFlip(NumpyImageTestCase2D): @parameterized.expand(INVALID_CASES) - def test_invalid_inputs(self, _, axis, raises): + def test_invalid_inputs(self, _, spatial_axis, raises): with self.assertRaises(raises): - flip = RandFlip(prob=1.0, axis=axis) - flip(self.imt) + flip = RandFlip(prob=1.0, spatial_axis=spatial_axis) + flip(self.imt[0]) @parameterized.expand(VALID_CASES) - def test_correct_results(self, _, axis): - flip = RandFlip(prob=1.0, axis=axis) - expected = np.flip(self.imt, axis) - 
self.assertTrue(np.allclose(expected, flip(self.imt))) - - @parameterized.expand(VALID_CASES) - def test_correct_results_dict(self, _, axis): - flip = RandFlipd(keys='img', prob=1.0, axis=axis) - res = flip({'img': self.imt}) - - expected = np.flip(self.imt, axis) - self.assertTrue(np.allclose(expected, res['img'])) + def test_correct_results(self, _, spatial_axis): + flip = RandFlip(prob=1.0, spatial_axis=spatial_axis) + expected = list() + for channel in self.imt[0]: + expected.append(np.flip(channel, spatial_axis)) + expected = np.stack(expected) + self.assertTrue(np.allclose(expected, flip(self.imt[0]))) if __name__ == '__main__': diff --git a/tests/test_rand_flipd.py b/tests/test_rand_flipd.py new file mode 100644 index 0000000000..bcda54eecd --- /dev/null +++ b/tests/test_rand_flipd.py @@ -0,0 +1,38 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandFlipd +from tests.utils import NumpyImageTestCase2D + +VALID_CASES = [("no_axis", None), + ("one_axis", 1), + ("many_axis", [0, 1])] + +class TestRandFlipd(NumpyImageTestCase2D): + + @parameterized.expand(VALID_CASES) + def test_correct_results(self, _, spatial_axis): + flip = RandFlipd(keys='img', prob=1.0, spatial_axis=spatial_axis) + res = flip({'img': self.imt[0]}) + expected = list() + for channel in self.imt[0]: + expected.append(np.flip(channel, spatial_axis)) + expected = np.stack(expected) + self.assertTrue(np.allclose(expected, res['img'])) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py index 29036663af..1e5a18bfc8 100644 --- a/tests/test_rand_rotate.py +++ b/tests/test_rand_rotate.py @@ -19,23 +19,26 @@ from tests.utils import NumpyImageTestCase2D -class RandomRotateTest(NumpyImageTestCase2D): +class TestRandRotate(NumpyImageTestCase2D): @parameterized.expand([ - (90, (1, 2), True, 1, 'reflect', 0, True), - ((-45, 45), (2, 1), True, 3, 'constant', 0, True), - (180, (2, 3), False, 2, 'constant', 4, False), + (90, (0, 1), True, 1, 'reflect', 0, True), + ((-45, 45), (1, 0), True, 3, 'constant', 0, True), + (180, (1, 0), False, 2, 'constant', 4, False), ]) - def test_correct_results(self, degrees, axes, reshape, + def test_correct_results(self, degrees, spatial_axes, reshape, order, mode, cval, prefilter): - rotate_fn = RandRotate(degrees, prob=1.0, axes=axes, reshape=reshape, + rotate_fn = RandRotate(degrees, prob=1.0, spatial_axes=spatial_axes, reshape=reshape, order=order, mode=mode, cval=cval, prefilter=prefilter) rotate_fn.set_random_state(243) - rotated = rotate_fn(self.imt) + rotated = rotate_fn(self.imt[0]) angle = rotate_fn.angle - expected = scipy.ndimage.rotate(self.imt, angle, axes, reshape, order=order, - mode=mode, cval=cval, prefilter=prefilter) + expected = list() + for channel in self.imt[0]: + expected.append(scipy.ndimage.rotate(channel, angle, spatial_axes, reshape, order=order, + mode=mode, cval=cval, 
prefilter=prefilter)) + expected = np.stack(expected).astype(np.float32) self.assertTrue(np.allclose(expected, rotated)) diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py index 4b291d8cf0..e50c3e0c67 100644 --- a/tests/test_rand_rotate90.py +++ b/tests/test_rand_rotate90.py @@ -17,34 +17,46 @@ from tests.utils import NumpyImageTestCase2D -class Rotate90Test(NumpyImageTestCase2D): +class TestRandRotate90(NumpyImageTestCase2D): def test_default(self): rotate = RandRotate90() rotate.set_random_state(123) - rotated = rotate(self.imt) - expected = np.rot90(self.imt, 0, (1, 2)) + rotated = rotate(self.imt[0]) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) self.assertTrue(np.allclose(rotated, expected)) def test_k(self): rotate = RandRotate90(max_k=2) rotate.set_random_state(234) - rotated = rotate(self.imt) - expected = np.rot90(self.imt, 0, (1, 2)) + rotated = rotate(self.imt[0]) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) self.assertTrue(np.allclose(rotated, expected)) - def test_axes(self): - rotate = RandRotate90(axes=(1, 2)) + def test_spatial_axes(self): + rotate = RandRotate90(spatial_axes=(0, 1)) rotate.set_random_state(234) - rotated = rotate(self.imt) - expected = np.rot90(self.imt, 0, (1, 2)) + rotated = rotate(self.imt[0]) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) self.assertTrue(np.allclose(rotated, expected)) - def test_prob_k_axes(self): - rotate = RandRotate90(prob=1.0, max_k=2, axes=(2, 3)) + def test_prob_k_spatial_axes(self): + rotate = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1)) rotate.set_random_state(234) - rotated = rotate(self.imt) - expected = np.rot90(self.imt, 1, (2, 3)) + rotated = rotate(self.imt[0]) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 1, (0, 1))) + expected = np.stack(expected) self.assertTrue(np.allclose(rotated, expected)) diff --git a/tests/test_rand_rotate90d.py b/tests/test_rand_rotate90d.py index c52a82389f..193627fef1 100644 --- a/tests/test_rand_rotate90d.py +++ b/tests/test_rand_rotate90d.py @@ -17,45 +17,57 @@ from tests.utils import NumpyImageTestCase2D -class Rotate90Test(NumpyImageTestCase2D): +class TestRandRotate90d(NumpyImageTestCase2D): def test_default(self): key = None rotate = RandRotate90d(keys=key) rotate.set_random_state(123) - rotated = rotate({key: self.imt}) - expected = np.rot90(self.imt, 0, (1, 2)) + rotated = rotate({key: self.imt[0]}) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) self.assertTrue(np.allclose(rotated[key], expected)) def test_k(self): key = 'test' rotate = RandRotate90d(keys=key, max_k=2) rotate.set_random_state(234) - rotated = rotate({key: self.imt}) - expected = np.rot90(self.imt, 0, (1, 2)) + rotated = rotate({key: self.imt[0]}) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) self.assertTrue(np.allclose(rotated[key], expected)) - def test_axes(self): - key = ['test'] - rotate = RandRotate90d(keys=key, axes=(1, 2)) + def test_spatial_axes(self): + key = 'test' + rotate = RandRotate90d(keys=key, spatial_axes=(0, 1)) rotate.set_random_state(234) - rotated = rotate({key[0]: self.imt}) - expected = np.rot90(self.imt, 0, (1, 2)) - 
self.assertTrue(np.allclose(rotated[key[0]], expected)) + rotated = rotate({key: self.imt[0]}) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) + self.assertTrue(np.allclose(rotated[key], expected)) - def test_prob_k_axes(self): - key = ('test',) - rotate = RandRotate90d(keys=key, prob=1.0, max_k=2, axes=(2, 3)) + def test_prob_k_spatial_axes(self): + key = 'test' + rotate = RandRotate90d(keys=key, prob=1.0, max_k=2, spatial_axes=(0, 1)) rotate.set_random_state(234) - rotated = rotate({key[0]: self.imt}) - expected = np.rot90(self.imt, 1, (2, 3)) - self.assertTrue(np.allclose(rotated[key[0]], expected)) + rotated = rotate({key: self.imt[0]}) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 1, (0, 1))) + expected = np.stack(expected) + self.assertTrue(np.allclose(rotated[key], expected)) def test_no_key(self): key = 'unknown' - rotate = RandRotate90d(keys=key, prob=1.0, max_k=2, axes=(2, 3)) + rotate = RandRotate90d(keys=key, prob=1.0, max_k=2, spatial_axes=(0, 1)) with self.assertRaisesRegex(KeyError, ''): - rotated = rotate({'test': self.imt}) + rotated = rotate({'test': self.imt[0]}) if __name__ == '__main__': diff --git a/tests/test_rand_rotated.py b/tests/test_rand_rotated.py new file mode 100644 index 0000000000..1c9d98e83e --- /dev/null +++ b/tests/test_rand_rotated.py @@ -0,0 +1,46 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
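+
+# note: RandRotated is exercised with prob=1.0 and a seeded random state so the
+# sampled angle is deterministic; the expected image is computed per channel with
+# scipy.ndimage.rotate, matching the dictionary transform's implementation.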
+ +import unittest +import numpy as np + +import scipy.ndimage +from parameterized import parameterized + +from monai.transforms import RandRotated +from tests.utils import NumpyImageTestCase2D + + +class TestRandRotated(NumpyImageTestCase2D): + + @parameterized.expand([ + (90, (0, 1), True, 1, 'reflect', 0, True), + ((-45, 45), (1, 0), True, 3, 'constant', 0, True), + (180, (1, 0), False, 2, 'constant', 4, False), + ]) + def test_correct_results(self, degrees, spatial_axes, reshape, + order, mode, cval, prefilter): + rotate_fn = RandRotated('img', degrees, prob=1.0, spatial_axes=spatial_axes, reshape=reshape, + order=order, mode=mode, cval=cval, prefilter=prefilter) + rotate_fn.set_random_state(243) + rotated = rotate_fn({'img': self.imt[0]}) + + angle = rotate_fn.angle + expected = list() + for channel in self.imt[0]: + expected.append(scipy.ndimage.rotate(channel, angle, spatial_axes, reshape, order=order, + mode=mode, cval=cval, prefilter=prefilter)) + expected = np.stack(expected).astype(np.float32) + self.assertTrue(np.allclose(expected, rotated['img'])) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_rand_zoom.py b/tests/test_rand_zoom.py index 530504b887..7dfdb7a522 100644 --- a/tests/test_rand_zoom.py +++ b/tests/test_rand_zoom.py @@ -17,12 +17,12 @@ from scipy.ndimage import zoom as zoom_scipy from parameterized import parameterized -from monai.transforms import RandZoom, RandZoomd +from monai.transforms import RandZoom from tests.utils import NumpyImageTestCase2D VALID_CASES = [(0.9, 1.1, 3, 'constant', 0, True, False, False)] -class ZoomTest(NumpyImageTestCase2D): +class TestRandZoom(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, min_zoom, max_zoom, order, mode, @@ -31,28 +31,14 @@ def test_correct_results(self, min_zoom, max_zoom, order, mode, mode=mode, cval=cval, prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size) random_zoom.set_random_state(234) - - zoomed = random_zoom(self.imt) - expected = zoom_scipy(self.imt, zoom=random_zoom._zoom, mode=mode, - order=order, cval=cval, prefilter=prefilter) - + zoomed = random_zoom(self.imt[0]) + expected = list() + for channel in self.imt[0]: + expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode=mode, order=order, + cval=cval, prefilter=prefilter)) + expected = np.stack(expected).astype(np.float32) self.assertTrue(np.allclose(expected, zoomed)) - @parameterized.expand(VALID_CASES) - def test_correct_results_dict(self, min_zoom, max_zoom, order, mode, - cval, prefilter, use_gpu, keep_size): - keys = 'img' - random_zoom = RandZoomd(keys, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order, - mode=mode, cval=cval, prefilter=prefilter, use_gpu=use_gpu, - keep_size=keep_size) - random_zoom.set_random_state(234) - - zoomed = random_zoom({keys: self.imt}) - expected = zoom_scipy(self.imt, zoom=random_zoom._zoom, mode=mode, - order=order, cval=cval, prefilter=prefilter) - - self.assertTrue(np.allclose(expected, zoomed[keys])) - @parameterized.expand([ (0.8, 1.2, 1, 'constant', 0, True) ]) @@ -64,17 +50,20 @@ def test_gpu_zoom(self, min_zoom, max_zoom, order, mode, cval, prefilter): keep_size=False) random_zoom.set_random_state(234) - zoomed = random_zoom(self.imt) - expected = zoom_scipy(self.imt, zoom=random_zoom._zoom, mode=mode, order=order, - cval=cval, prefilter=prefilter) + zoomed = random_zoom(self.imt[0]) + expected = list() + for channel in self.imt[0]: + expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode=mode, order=order, + 
cval=cval, prefilter=prefilter))
+        expected = np.stack(expected).astype(np.float32)
         self.assertTrue(np.allclose(expected, zoomed))
 
     def test_keep_size(self):
         random_zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True)
-        zoomed = random_zoom(self.imt)
-        self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape))
+        zoomed = random_zoom(self.imt[0])
+        self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
 
     @parameterized.expand([
         ("no_min_zoom", None, 1.1, 1, TypeError),
@@ -83,7 +72,7 @@ def test_keep_size(self):
     def test_invalid_inputs(self, _, min_zoom, max_zoom, order, raises):
         with self.assertRaises(raises):
             random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order)
-            zoomed = random_zoom(self.imt)
+            zoomed = random_zoom(self.imt[0])
 
 if __name__ == '__main__':
diff --git a/tests/test_rand_zoomd.py b/tests/test_rand_zoomd.py
new file mode 100644
index 0000000000..9a5838da4b
--- /dev/null
+++ b/tests/test_rand_zoomd.py
@@ -0,0 +1,83 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import importlib
+
+from scipy.ndimage import zoom as zoom_scipy
+from parameterized import parameterized
+
+from monai.transforms import RandZoomd
+from tests.utils import NumpyImageTestCase2D
+
+VALID_CASES = [(0.9, 1.1, 3, 'constant', 0, True, False, False)]
+
+class TestRandZoomd(NumpyImageTestCase2D):
+
+    @parameterized.expand(VALID_CASES)
+    def test_correct_results(self, min_zoom, max_zoom, order, mode,
+                             cval, prefilter, use_gpu, keep_size):
+        key = 'img'
+        random_zoom = RandZoomd(key, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order,
+                                mode=mode, cval=cval, prefilter=prefilter, use_gpu=use_gpu,
+                                keep_size=keep_size)
+        random_zoom.set_random_state(234)
+
+        zoomed = random_zoom({key: self.imt[0]})
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode=mode, order=order,
+                                       cval=cval, prefilter=prefilter))
+        expected = np.stack(expected).astype(np.float32)
+        self.assertTrue(np.allclose(expected, zoomed[key]))
+
+    @parameterized.expand([
+        (0.8, 1.2, 1, 'constant', 0, True)
+    ])
+    def test_gpu_zoom(self, min_zoom, max_zoom, order, mode, cval, prefilter):
+        key = 'img'
+        if importlib.util.find_spec('cupy'):
+            random_zoom = RandZoomd(
+                key, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order,
+                mode=mode, cval=cval, prefilter=prefilter, use_gpu=True,
+                keep_size=False)
+            random_zoom.set_random_state(234)
+
+            zoomed = random_zoom({key: self.imt[0]})
+            expected = list()
+            for channel in self.imt[0]:
+                expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode=mode, order=order,
+                                           cval=cval, prefilter=prefilter))
+            expected = np.stack(expected).astype(np.float32)
+            self.assertTrue(np.allclose(expected, zoomed[key]))
+
+    def test_keep_size(self):
+        key = 'img'
+        random_zoom = RandZoomd(key, prob=1.0, min_zoom=0.6,
+                                max_zoom=0.7, keep_size=True)
+        zoomed = random_zoom({key: self.imt[0]})
+
self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:]))
+
+    @parameterized.expand([
+        ("no_min_zoom", None, 1.1, 1, TypeError),
+        ("invalid_order", 0.9, 1.1, 's', AssertionError)
+    ])
+    def test_invalid_inputs(self, _, min_zoom, max_zoom, order, raises):
+        key = 'img'
+        with self.assertRaises(raises):
+            random_zoom = RandZoomd(key, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order)
+            zoomed = random_zoom({key: self.imt[0]})
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_resize.py b/tests/test_resize.py
index 8040c6ebe4..30f8101baa 100644
--- a/tests/test_resize.py
+++ b/tests/test_resize.py
@@ -19,37 +19,37 @@
 from tests.utils import NumpyImageTestCase2D
 
-class TestResized(NumpyImageTestCase2D):
+class TestResize(NumpyImageTestCase2D):
 
     @parameterized.expand([
         ("invalid_order", "order", AssertionError)
     ])
     def test_invalid_inputs(self, _, order, raises):
         with self.assertRaises(raises):
-            resize = Resize(output_shape=(128, 128, 3), order=order)
-            resize(self.imt)
+            resize = Resize(output_spatial_shape=(128, 128, 3), order=order)
+            resize(self.imt[0])
 
     @parameterized.expand([
         ((64, 64), 1, 'reflect', 0, True, True, True, None),
         ((32, 32), 2, 'constant', 3, False, False, False, None),
         ((256, 256), 3, 'constant', 3, False, False, False, None),
     ])
-    def test_correct_results(self, output_shape, order, mode,
+    def test_correct_results(self, output_spatial_shape, order, mode,
                              cval, clip, preserve_range,
                              anti_aliasing, anti_aliasing_sigma):
-        resize = Resize(output_shape, order, mode, cval, clip,
+        resize = Resize(output_spatial_shape, order, mode, cval, clip,
                         preserve_range, anti_aliasing, anti_aliasing_sigma)
         expected = list()
-        for channel in self.imt:
-            expected.append(skimage.transform.resize(channel, output_shape,
+        for channel in self.imt[0]:
+            expected.append(skimage.transform.resize(channel, output_spatial_shape,
                                                      order=order, mode=mode, cval=cval, clip=clip,
                                                      preserve_range=preserve_range, anti_aliasing=anti_aliasing,
                                                      anti_aliasing_sigma=anti_aliasing_sigma))
         expected = np.stack(expected).astype(np.float32)
-        self.assertTrue(np.allclose(resize(self.imt), expected))
+        self.assertTrue(np.allclose(resize(self.imt[0]), expected))
 
 if __name__ == '__main__':
diff --git a/tests/test_resized.py b/tests/test_resized.py
index f1ec07b74b..d7830d3e1d 100644
--- a/tests/test_resized.py
+++ b/tests/test_resized.py
@@ -26,30 +26,30 @@ class TestResized(NumpyImageTestCase2D):
     ])
     def test_invalid_inputs(self, _, order, raises):
         with self.assertRaises(raises):
-            resize = Resized(keys='img', output_shape=(128, 128, 3), order=order)
-            resize({'img': self.imt})
+            resize = Resized(keys='img', output_spatial_shape=(128, 128, 3), order=order)
+            resize({'img': self.imt[0]})
 
     @parameterized.expand([
         ((64, 64), 1, 'reflect', 0, True, True, True, None),
         ((32, 32), 2, 'constant', 3, False, False, False, None),
         ((256, 256), 3, 'constant', 3, False, False, False, None),
     ])
-    def test_correct_results(self, output_shape, order, mode,
+    def test_correct_results(self, output_spatial_shape, order, mode,
                              cval, clip, preserve_range,
                              anti_aliasing, anti_aliasing_sigma):
-        resize = Resized('img', output_shape, order, mode, cval, clip,
+        resize = Resized('img', output_spatial_shape, order, mode, cval, clip,
                          preserve_range, anti_aliasing, anti_aliasing_sigma)
         expected = list()
-        for channel in self.imt:
-            expected.append(skimage.transform.resize(channel, output_shape,
+        for channel in self.imt[0]:
+            expected.append(skimage.transform.resize(channel, output_spatial_shape,
                                                      order=order,
mode=mode, cval=cval, clip=clip, preserve_range=preserve_range, anti_aliasing=anti_aliasing, anti_aliasing_sigma=anti_aliasing_sigma)) expected = np.stack(expected).astype(np.float32) - self.assertTrue(np.allclose(resize({'img': self.imt})['img'], expected)) + self.assertTrue(np.allclose(resize({'img': self.imt[0]})['img'], expected)) if __name__ == '__main__': diff --git a/tests/test_rotate.py b/tests/test_rotate.py index 0c34f5809e..7d6d1b531b 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -15,38 +15,27 @@ import scipy.ndimage from parameterized import parameterized -from monai.transforms import Rotate, Rotated +from monai.transforms import Rotate from tests.utils import NumpyImageTestCase2D -TEST_CASES = [(90, (1, 2), True, 1, 'reflect', 0, True), - (-90, (2, 1), True, 3, 'constant', 0, True), - (180, (2, 3), False, 2, 'constant', 4, False)] +TEST_CASES = [(90, (0, 1), True, 1, 'reflect', 0, True), + (-90, (1, 0), True, 3, 'constant', 0, True), + (180, (1, 0), False, 2, 'constant', 4, False)] -class RotateTest(NumpyImageTestCase2D): +class TestRotate(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES) - def test_correct_results(self, angle, axes, reshape, + def test_correct_results(self, angle, spatial_axes, reshape, order, mode, cval, prefilter): - rotate_fn = Rotate(angle, axes, reshape, + rotate_fn = Rotate(angle, spatial_axes, reshape, order, mode, cval, prefilter) - rotated = rotate_fn(self.imt) - - expected = scipy.ndimage.rotate(self.imt, angle, axes, reshape, order=order, - mode=mode, cval=cval, prefilter=prefilter) + rotated = rotate_fn(self.imt[0]) + expected = list() + for channel in self.imt[0]: + expected.append(scipy.ndimage.rotate(channel, angle, spatial_axes, reshape, order=order, + mode=mode, cval=cval, prefilter=prefilter)) + expected = np.stack(expected).astype(np.float32) self.assertTrue(np.allclose(expected, rotated)) - @parameterized.expand(TEST_CASES) - def test_correct_results_dict(self, angle, axes, reshape, - order, mode, cval, prefilter): - key = 'img' - rotate_fn = Rotated(key, angle, axes, reshape, order, - mode, cval, prefilter) - rotated = rotate_fn({key: self.imt}) - - expected = scipy.ndimage.rotate(self.imt, angle, axes, reshape, order=order, - mode=mode, cval=cval, prefilter=prefilter) - self.assertTrue(np.allclose(expected, rotated[key])) - - if __name__ == '__main__': unittest.main() diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index 1b1aca78df..990e489cd9 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -17,30 +17,42 @@ from tests.utils import NumpyImageTestCase2D -class Rotate90Test(NumpyImageTestCase2D): +class TestRotate90(NumpyImageTestCase2D): def test_rotate90_default(self): rotate = Rotate90() - rotated = rotate(self.imt) - expected = np.rot90(self.imt, 1, (1, 2)) + rotated = rotate(self.imt[0]) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 1, (0, 1))) + expected = np.stack(expected) self.assertTrue(np.allclose(rotated, expected)) def test_k(self): rotate = Rotate90(k=2) - rotated = rotate(self.imt) - expected = np.rot90(self.imt, 2, (1, 2)) + rotated = rotate(self.imt[0]) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 2, (0, 1))) + expected = np.stack(expected) self.assertTrue(np.allclose(rotated, expected)) - def test_axes(self): - rotate = Rotate90(axes=(1, 2)) - rotated = rotate(self.imt) - expected = np.rot90(self.imt, 1, (1, 2)) + def test_spatial_axes(self): + rotate = Rotate90(spatial_axes=(0, 1)) + 
rotated = rotate(self.imt[0])
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(np.rot90(channel, 1, (0, 1)))
+        expected = np.stack(expected)
         self.assertTrue(np.allclose(rotated, expected))
 
-    def test_k_axes(self):
-        rotate = Rotate90(k=2, axes=(2, 3))
-        rotated = rotate(self.imt)
-        expected = np.rot90(self.imt, 2, (2, 3))
+    def test_k_spatial_axes(self):
+        rotate = Rotate90(k=2, spatial_axes=(0, 1))
+        rotated = rotate(self.imt[0])
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(np.rot90(channel, 2, (0, 1)))
+        expected = np.stack(expected)
         self.assertTrue(np.allclose(rotated, expected))
diff --git a/tests/test_rotate90d.py b/tests/test_rotate90d.py
index ccfb2380f0..4b54a9a296 100644
--- a/tests/test_rotate90d.py
+++ b/tests/test_rotate90d.py
@@ -17,41 +17,53 @@
 from tests.utils import NumpyImageTestCase2D
 
-class Rotate90Test(NumpyImageTestCase2D):
+class TestRotate90d(NumpyImageTestCase2D):
 
     def test_rotate90_default(self):
         key = 'test'
         rotate = Rotate90d(keys=key)
-        rotated = rotate({key: self.imt})
-        expected = np.rot90(self.imt, 1, (1, 2))
+        rotated = rotate({key: self.imt[0]})
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(np.rot90(channel, 1, (0, 1)))
+        expected = np.stack(expected)
         self.assertTrue(np.allclose(rotated[key], expected))
 
     def test_k(self):
         key = None
         rotate = Rotate90d(keys=key, k=2)
-        rotated = rotate({key: self.imt})
-        expected = np.rot90(self.imt, 2, (1, 2))
+        rotated = rotate({key: self.imt[0]})
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(np.rot90(channel, 2, (0, 1)))
+        expected = np.stack(expected)
         self.assertTrue(np.allclose(rotated[key], expected))
 
-    def test_axes(self):
-        key = ['test']
-        rotate = Rotate90d(keys=key, axes=(1, 2))
-        rotated = rotate({key[0]: self.imt})
-        expected = np.rot90(self.imt, 1, (1, 2))
-        self.assertTrue(np.allclose(rotated[key[0]], expected))
+    def test_spatial_axes(self):
+        key = 'test'
+        rotate = Rotate90d(keys=key, spatial_axes=(0, 1))
+        rotated = rotate({key: self.imt[0]})
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(np.rot90(channel, 1, (0, 1)))
+        expected = np.stack(expected)
+        self.assertTrue(np.allclose(rotated[key], expected))
 
-    def test_k_axes(self):
-        key = ('test',)
-        rotate = Rotate90d(keys=key, k=2, axes=(2, 3))
-        rotated = rotate({key[0]: self.imt})
-        expected = np.rot90(self.imt, 2, (2, 3))
-        self.assertTrue(np.allclose(rotated[key[0]], expected))
+    def test_k_spatial_axes(self):
+        key = 'test'
+        rotate = Rotate90d(keys=key, k=2, spatial_axes=(0, 1))
+        rotated = rotate({key: self.imt[0]})
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(np.rot90(channel, 2, (0, 1)))
+        expected = np.stack(expected)
+        self.assertTrue(np.allclose(rotated[key], expected))
 
     def test_no_key(self):
         key = 'unknown'
         rotate = Rotate90d(keys=key)
         with self.assertRaisesRegex(KeyError, ''):
-            rotate({'test': self.imt})
+            rotate({'test': self.imt[0]})
 
 if __name__ == '__main__':
diff --git a/tests/test_rotated.py b/tests/test_rotated.py
new file mode 100644
index 0000000000..af7a758d8d
--- /dev/null
+++ b/tests/test_rotated.py
@@ -0,0 +1,43 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +import scipy.ndimage +from parameterized import parameterized + +from monai.transforms import Rotated +from tests.utils import NumpyImageTestCase2D + +TEST_CASES = [(90, (0, 1), True, 1, 'reflect', 0, True), + (-90, (1, 0), True, 3, 'constant', 0, True), + (180, (1, 0), False, 2, 'constant', 4, False)] + +class TestRotated(NumpyImageTestCase2D): + + @parameterized.expand(TEST_CASES) + def test_correct_results(self, angle, spatial_axes, reshape, + order, mode, cval, prefilter): + key = 'img' + rotate_fn = Rotated(key, angle, spatial_axes, reshape, order, + mode, cval, prefilter) + rotated = rotate_fn({key: self.imt[0]}) + expected = list() + for channel in self.imt[0]: + expected.append(scipy.ndimage.rotate(channel, angle, spatial_axes, reshape, order=order, + mode=mode, cval=cval, prefilter=prefilter)) + expected = np.stack(expected).astype(np.float32) + self.assertTrue(np.allclose(expected, rotated[key])) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_uniform_rand_patch.py b/tests/test_uniform_rand_patch.py index f11c4b43f4..32b1077ef7 100644 --- a/tests/test_uniform_rand_patch.py +++ b/tests/test_uniform_rand_patch.py @@ -17,13 +17,13 @@ from tests.utils import NumpyImageTestCase2D -class UniformRandomPatchTest(NumpyImageTestCase2D): +class TestUniformRandomPatch(NumpyImageTestCase2D): def test_2d(self): - patch_size = (1, 10, 10) - patch_transform = UniformRandomPatch(patch_size=patch_size) - patch = patch_transform(self.imt) - self.assertTrue(np.allclose(patch.shape[:-2], patch_size[:-2])) + patch_spatial_size = (10, 10) + patch_transform = UniformRandomPatch(patch_spatial_size=patch_spatial_size) + patch = patch_transform(self.imt[0]) + self.assertTrue(np.allclose(patch.shape[1:], patch_spatial_size)) if __name__ == '__main__': diff --git a/tests/test_uniform_rand_patchd.py b/tests/test_uniform_rand_patchd.py index 1ab03b4b6f..a87438acce 100644 --- a/tests/test_uniform_rand_patchd.py +++ b/tests/test_uniform_rand_patchd.py @@ -17,20 +17,20 @@ from tests.utils import NumpyImageTestCase2D -class UniformRandomPatchdTest(NumpyImageTestCase2D): +class TestUniformRandomPatchd(NumpyImageTestCase2D): def test_2d(self): - patch_size = (1, 10, 10) + patch_spatial_size = (10, 10) key = 'test' - patch_transform = UniformRandomPatchd(keys='test', patch_size=patch_size) - patch = patch_transform({key: self.imt}) - self.assertTrue(np.allclose(patch[key].shape[:-2], patch_size[:-2])) + patch_transform = UniformRandomPatchd(keys='test', patch_spatial_size=patch_spatial_size) + patch = patch_transform({key: self.imt[0]}) + self.assertTrue(np.allclose(patch[key].shape[1:], patch_spatial_size)) def test_sync(self): - patch_size = (1, 4, 4) + patch_spatial_size = (4, 4) key_1, key_2 = 'foo', 'bar' rand_image = np.random.rand(3, 10, 10) - patch_transform = UniformRandomPatchd(keys=(key_1, key_2), patch_size=patch_size) + patch_transform = UniformRandomPatchd(keys=(key_1, key_2), patch_spatial_size=patch_spatial_size) patch = patch_transform({key_1: rand_image, key_2: rand_image}) self.assertTrue(np.allclose(patch[key_1], patch[key_2])) 
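A note on the recurring pattern in the test updates above: the transforms under test now take a single
channel-first array, which is why the tests index ``self.imt[0]`` to drop the batch dimension, and the
reference result is built by applying the underlying scipy/skimage call to each channel and re-stacking.
A minimal standalone sketch of that pattern (the shape and the rotate call are illustrative, not taken
from these tests):

    import numpy as np
    import scipy.ndimage

    img = np.random.rand(2, 32, 32).astype(np.float32)  # (num_channels, H, W), channel-first
    # apply the spatial operation channel by channel, then re-stack along the channel axis
    expected = np.stack([scipy.ndimage.rotate(channel, 90, axes=(0, 1), reshape=False)
                         for channel in img]).astype(np.float32)
    assert expected.shape == img.shape  # reshape=False keeps the spatial size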
diff --git a/tests/test_zoom.py b/tests/test_zoom.py index 83795542bc..cb0af47fef 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -17,7 +17,7 @@ from scipy.ndimage import zoom as zoom_scipy from parameterized import parameterized -from monai.transforms import Zoom, Zoomd +from monai.transforms import Zoom from tests.utils import NumpyImageTestCase2D VALID_CASES = [(1.1, 3, 'constant', 0, True, False, False), @@ -30,37 +30,31 @@ ("invalid_order", 0.9, 's', AssertionError)] -class ZoomTest(NumpyImageTestCase2D): +class TestZoom(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, zoom, order, mode, cval, prefilter, use_gpu, keep_size): zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval, prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size) zoomed = zoom_fn(self.imt[0]) - expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order, - cval=cval, prefilter=prefilter) + expected = list() + for channel in self.imt[0]: + expected.append(zoom_scipy(channel, zoom=zoom, mode=mode, order=order, + cval=cval, prefilter=prefilter)) + expected = np.stack(expected).astype(np.float32) self.assertTrue(np.allclose(expected, zoomed)) - @parameterized.expand(VALID_CASES) - def test_correct_results_dict(self, zoom, order, mode, cval, prefilter, use_gpu, keep_size): - key = 'img' - zoom_fn = Zoomd(key, zoom=zoom, order=order, mode=mode, cval=cval, - prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size) - zoomed = zoom_fn({key: self.imt[0]}) - - expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order, - cval=cval, prefilter=prefilter) - self.assertTrue(np.allclose(expected, zoomed[key])) - - @parameterized.expand(GPU_CASES) def test_gpu_zoom(self, _, zoom, order, mode, cval, prefilter): if importlib.util.find_spec('cupy'): zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval, prefilter=prefilter, use_gpu=True, keep_size=False) zoomed = zoom_fn(self.imt[0]) - expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order, - cval=cval, prefilter=prefilter) + expected = list() + for channel in self.imt[0]: + expected.append(zoom_scipy(channel, zoom=zoom, mode=mode, order=order, + cval=cval, prefilter=prefilter)) + expected = np.stack(expected).astype(np.float32) self.assertTrue(np.allclose(expected, zoomed)) def test_keep_size(self): @@ -76,7 +70,7 @@ def test_keep_size(self): def test_invalid_inputs(self, _, zoom, order, raises): with self.assertRaises(raises): zoom_fn = Zoom(zoom=zoom, order=order) - zoomed = zoom_fn(self.imt) + zoomed = zoom_fn(self.imt[0]) if __name__ == '__main__': diff --git a/tests/test_zoomd.py b/tests/test_zoomd.py new file mode 100644 index 0000000000..6ef85cb1fe --- /dev/null +++ b/tests/test_zoomd.py @@ -0,0 +1,82 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+
+import numpy as np
+import importlib
+
+from scipy.ndimage import zoom as zoom_scipy
+from parameterized import parameterized
+
+from monai.transforms import Zoomd
+from tests.utils import NumpyImageTestCase2D
+
+VALID_CASES = [(1.1, 3, 'constant', 0, True, False, False),
+               (0.9, 3, 'constant', 0, True, False, False),
+               (0.8, 1, 'reflect', 0, False, False, False)]
+
+GPU_CASES = [("gpu_zoom", 0.6, 1, 'constant', 0, True)]
+
+INVALID_CASES = [("no_zoom", None, 1, TypeError),
+                 ("invalid_order", 0.9, 's', AssertionError)]
+
+
+class TestZoomd(NumpyImageTestCase2D):
+
+    @parameterized.expand(VALID_CASES)
+    def test_correct_results(self, zoom, order, mode, cval, prefilter, use_gpu, keep_size):
+        key = 'img'
+        zoom_fn = Zoomd(key, zoom=zoom, order=order, mode=mode, cval=cval,
+                        prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size)
+        zoomed = zoom_fn({key: self.imt[0]})
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(zoom_scipy(channel, zoom=zoom, mode=mode, order=order,
+                                       cval=cval, prefilter=prefilter))
+        expected = np.stack(expected).astype(np.float32)
+        self.assertTrue(np.allclose(expected, zoomed[key]))
+
+
+    @parameterized.expand(GPU_CASES)
+    def test_gpu_zoom(self, _, zoom, order, mode, cval, prefilter):
+        key = 'img'
+        if importlib.util.find_spec('cupy'):
+            zoom_fn = Zoomd(key, zoom=zoom, order=order, mode=mode, cval=cval,
+                            prefilter=prefilter, use_gpu=True, keep_size=False)
+            zoomed = zoom_fn({key: self.imt[0]})
+            expected = list()
+            for channel in self.imt[0]:
+                expected.append(zoom_scipy(channel, zoom=zoom, mode=mode, order=order,
+                                           cval=cval, prefilter=prefilter))
+            expected = np.stack(expected).astype(np.float32)
+            self.assertTrue(np.allclose(expected, zoomed[key]))
+
+    def test_keep_size(self):
+        key = 'img'
+        zoom_fn = Zoomd(key, zoom=0.6, keep_size=True)
+        zoomed = zoom_fn({key: self.imt[0]})
+        self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:]))
+
+        zoom_fn = Zoomd(key, zoom=1.3, keep_size=True)
+        zoomed = zoom_fn({key: self.imt[0]})
+        self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:]))
+
+    @parameterized.expand(INVALID_CASES)
+    def test_invalid_inputs(self, _, zoom, order, raises):
+        key = 'img'
+        with self.assertRaises(raises):
+            zoom_fn = Zoomd(key, zoom=zoom, order=order)
+            zoomed = zoom_fn({key: self.imt[0]})
+
+
+if __name__ == '__main__':
+    unittest.main()

From f9241b7ae7a4e85c2ac410c84a14335f9440bac9 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Fri, 13 Mar 2020 00:18:36 +0000
Subject: [PATCH 6/7] update demos:
 - unet: num_classes -> out_channels (signature consistency)
 - segmentation demo changed to binary ground truth
 - changed segmentation training dict to have validation every n iterations
 - segmentation demo image sizes to 128, window size to 96, lr 1e-3 (good results)
 - attach classification saver finalize() to Events.COMPLETED

---
 .../densenet_evaluation_array.py | 3 +-
 .../densenet_evaluation_dict.py | 3 +-
 .../densenet_training_array.py | 2 +-
 .../densenet_training_dict.py | 2 +-
 examples/multi_gpu_test.ipynb | 2 +-
 .../segmentation_3d/unet_evaluation_array.py | 19 ++---
 .../segmentation_3d/unet_evaluation_dict.py | 10 +--
 .../segmentation_3d/unet_training_array.py | 57 ++++++++-------
 .../segmentation_3d/unet_training_dict.py | 72 ++++++++++---------
 examples/unet_segmentation_3d.ipynb | 2 +-
 monai/data/synthetic.py | 7 +-
 monai/handlers/classification_saver.py | 18 ++---
 monai/handlers/segmentation_saver.py | 6 +-
 monai/handlers/tensorboard_handlers.py | 17 +++--
 monai/networks/nets/unet.py | 6 +-
 monai/transforms/transforms.py | 21 +++++-
 tests/integration_sliding_window.py | 2 +-
 tests/integration_unet2d.py | 7 +-
 tests/test_handler_classification_saver.py | 1 -
 tests/test_unet.py | 6 +-
 20 files changed, 145 insertions(+), 118 deletions(-)

diff --git a/examples/classification_3d/densenet_evaluation_array.py b/examples/classification_3d/densenet_evaluation_array.py
index 7b1015ea4b..f131c79683 100644
--- a/examples/classification_3d/densenet_evaluation_array.py
+++ b/examples/classification_3d/densenet_evaluation_array.py
@@ -30,7 +30,7 @@
 monai.config.print_config()
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
-# demo dataset, user can easily change to own dataset
+# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
 images = [
     "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
     "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
@@ -94,4 +94,3 @@ def prepare_batch(batch, device=None, non_blocking=False):
 val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
 
 state = evaluator.run(val_loader)
-prediction_saver.finalize()
diff --git a/examples/classification_3d/densenet_evaluation_dict.py b/examples/classification_3d/densenet_evaluation_dict.py
index 99dd98f546..d764e1cc61 100644
--- a/examples/classification_3d/densenet_evaluation_dict.py
+++ b/examples/classification_3d/densenet_evaluation_dict.py
@@ -29,7 +29,7 @@
 monai.config.print_config()
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
-# demo dataset, user can easily change to own dataset
+# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
 images = [
     "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
     "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
@@ -95,4 +95,3 @@ def prepare_batch(batch, device=None, non_blocking=False):
 val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
 
 state = evaluator.run(val_loader)
-prediction_saver.finalize()
diff --git a/examples/classification_3d/densenet_training_array.py b/examples/classification_3d/densenet_training_array.py
index be8504adec..e6bd9c0d6c 100644
--- a/examples/classification_3d/densenet_training_array.py
+++ b/examples/classification_3d/densenet_training_array.py
@@ -31,7 +31,7 @@
 monai.config.print_config()
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
-# demo dataset, user can easily change to own dataset
+# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
 images = [
     "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
     "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
diff --git a/examples/classification_3d/densenet_training_dict.py b/examples/classification_3d/densenet_training_dict.py
index 9633459522..c8e73fc173 100644
--- a/examples/classification_3d/densenet_training_dict.py
+++ b/examples/classification_3d/densenet_training_dict.py
@@ -31,7 +31,7 @@
 monai.config.print_config()
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
-# demo dataset, user can easily change to own dataset
+# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
 images = [
     "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
     "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
diff --git a/examples/multi_gpu_test.ipynb b/examples/multi_gpu_test.ipynb
index 8f1827f15b..98911e12f2 100644
---
a/examples/multi_gpu_test.ipynb
+++ b/examples/multi_gpu_test.ipynb
@@ -53,7 +53,7 @@
     "net = monai.networks.nets.UNet(\n",
     "    dimensions=2,\n",
     "    in_channels=1,\n",
-    "    num_classes=1,\n",
+    "    out_channels=1,\n",
     "    channels=(16, 32, 64, 128, 256),\n",
     "    strides=(2, 2, 2, 2),\n",
     "    num_res_units=2,\n",
diff --git a/examples/segmentation_3d/unet_evaluation_array.py b/examples/segmentation_3d/unet_evaluation_array.py
index 005b9d0d85..f5c039a39e 100644
--- a/examples/segmentation_3d/unet_evaluation_array.py
+++ b/examples/segmentation_3d/unet_evaluation_array.py
@@ -42,8 +42,8 @@
 tempdir = tempfile.mkdtemp()  # tempdir = './temp'
 print('generating synthetic data to {} (this may take a while)'.format(tempdir))
-for i in range(50):
-    im, seg = create_test_image_3d(256, 256, 256)
+for i in range(5):
+    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
 
     n = nib.Nifti1Image(im, np.eye(4))
     nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))
@@ -63,7 +63,7 @@
 net = UNet(
     dimensions=3,
     in_channels=1,
-    num_classes=1,
+    out_channels=1,
     channels=(16, 32, 64, 128, 256),
     strides=(2, 2, 2, 2),
     num_res_units=2,
@@ -71,7 +71,7 @@
 net.to(device)
 
 # define sliding window size and batch size for sliding window inference
-roi_size = (64, 64, 64)
+roi_size = (96, 96, 96)
 sw_batch_size = 4
@@ -97,11 +97,14 @@ def _sliding_window_processor(engine, batch):
 val_stats_handler.attach(evaluator)
 
 # for the array data format, assume the 3rd item of batch data is the meta_data
-SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator',
-                  batch_transform=lambda x: x[2], output_transform=lambda output: predict_segmentation(output[0])
-                  ).attach(evaluator)
+file_saver = SegmentationSaver(
+    output_path='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator',
+    batch_transform=lambda x: x[2], output_transform=lambda output: predict_segmentation(output[0]))
+file_saver.attach(evaluator)
+
 # the model was trained by the "unet_training_array" example
-CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(evaluator)
+ckpt_loader = CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net})
+ckpt_loader.attach(evaluator)
 
 # sliding window inference needs to input 1 image in every iteration
 loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
diff --git a/examples/segmentation_3d/unet_evaluation_dict.py b/examples/segmentation_3d/unet_evaluation_dict.py
index 3f43299f76..e78abeac3f 100644
--- a/examples/segmentation_3d/unet_evaluation_dict.py
+++ b/examples/segmentation_3d/unet_evaluation_dict.py
@@ -43,8 +43,8 @@
 tempdir = tempfile.mkdtemp()  # tempdir = './temp'
 print('generating synthetic data to {} (this may take a while)'.format(tempdir))
-for i in range(50):
-    im, seg = create_test_image_3d(256, 256, 256, channel_dim=-1)
+for i in range(5):
+    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
 
     n = nib.Nifti1Image(im, np.eye(4))
     nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))
@@ -68,7 +68,7 @@
 net = UNet(
     dimensions=3,
     in_channels=1,
-    num_classes=1,
+    out_channels=1,
     channels=(16, 32, 64, 128, 256),
     strides=(2, 2, 2, 2),
     num_res_units=2,
@@ -76,7 +76,7 @@
 net.to(device)
 
 # define sliding window size and batch size for sliding window inference
-roi_size = (64, 64, 64)
+roi_size = (96, 96, 96)
 sw_batch_size = 4
@@ -108,7 +108,7 @@ def _sliding_window_processor(engine, batch):
     }, output_transform=lambda output:
predict_segmentation(output[0])).attach(evaluator)
 
 # the model was trained by the "unet_training_dict" example
-CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(evaluator)
+CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net}).attach(evaluator)
 
 # sliding window inference needs to input 1 image in every iteration
 val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate,
diff --git a/examples/segmentation_3d/unet_training_array.py b/examples/segmentation_3d/unet_training_array.py
index abf8ac8b1a..af96e10f93 100644
--- a/examples/segmentation_3d/unet_training_array.py
+++ b/examples/segmentation_3d/unet_training_array.py
@@ -43,7 +43,7 @@
 tempdir = tempfile.mkdtemp()
 print('generating synthetic data to {} (this may take a while)'.format(tempdir))
 for i in range(50):
-    im, seg = create_test_image_3d(128, 128, 128)
+    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
 
     n = nib.Nifti1Image(im, np.eye(4))
     nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))
@@ -80,20 +80,29 @@
 im, seg = monai.utils.misc.first(check_loader)
 print(im.shape, seg.shape)
 
+# create a training data loader
+train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
+train_loader = DataLoader(train_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
+# create a validation data loader
+val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
+val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
+
+
 # Create UNet, DiceLoss and Adam optimizer
 net = monai.networks.nets.UNet(
     dimensions=3,
     in_channels=1,
-    num_classes=1,
+    out_channels=1,
     channels=(16, 32, 64, 128, 256),
     strides=(2, 2, 2, 2),
     num_res_units=2,
 )
 loss = monai.losses.DiceLoss(do_sigmoid=True)
-lr = 1e-5
+lr = 1e-3
 opt = torch.optim.Adam(net.parameters(), lr)
 device = torch.device("cuda:0")
 
+
 # ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
 # user can add output_transform to return other values, like: y_pred, y, etc.
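 # for instance, an illustrative sketch (not part of this example script): output_transform
 # could return a dict so that downstream handlers also see the predictions, e.g.
 #     create_supervised_trainer(net, opt, loss, device, False,
 #                               output_transform=lambda x, y, y_pred, loss: {'loss': loss.item(), 'y_pred': y_pred})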
 trainer = create_supervised_trainer(net, opt, loss, device, False)
@@ -106,7 +115,7 @@
 # StatsHandler prints loss at every iteration and prints metrics at every epoch,
 # we don't set metrics for trainer here, so just print loss; users can also customize print functions
-# and can use output_transform to convert engine.state.output if it's not loss value
+# and can use output_transform to convert engine.state.output if it's not a loss value
 train_stats_handler = StatsHandler(name='trainer')
 train_stats_handler.attach(trainer)
@@ -114,9 +123,9 @@
 train_tensorboard_stats_handler = TensorBoardStatsHandler()
 train_tensorboard_stats_handler.attach(trainer)
 
-# Set parameters for validation
-validation_every_n_epochs = 1
+validation_every_n_epochs = 1
+# Set parameters for validation
 metric_name = 'Mean_Dice'
 # add evaluation metric to the evaluator engine
 val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}
@@ -125,6 +134,18 @@
 # user can add output_transform to return other values
 evaluator = create_supervised_evaluator(net, val_metrics, device, True)
+
+@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
+def run_validation(engine):
+    evaluator.run(val_loader)
+
+
+# Add early stopping handler to evaluator
+early_stopper = EarlyStopping(patience=4,
+                              score_function=stopping_fn_from_metric(metric_name),
+                              trainer=trainer)
+evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
+
 # Add stats event handler to print validation stats via evaluator
 val_stats_handler = StatsHandler(
     name='evaluator',
@@ -132,13 +153,14 @@
     global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
 val_stats_handler.attach(evaluator)
 
-# add handler to record metrics to TensorBoard at every epoch
+# add handler to record metrics to TensorBoard at every validation epoch
 val_tensorboard_stats_handler = TensorBoardStatsHandler(
     output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
     global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
 val_tensorboard_stats_handler.attach(evaluator)
+
 # add handler to draw the first image and the corresponding label and model output in the last batch
-# here we draw the 3D output as GIF format along Depth axis
+# here we draw the 3D output as GIF format along Depth axis, at every validation epoch
 val_tensorboard_image_handler = TensorBoardImageHandler(
     batch_transform=lambda batch: (batch[0], batch[1]),
     output_transform=lambda output: predict_segmentation(output[0]),
@@ -146,25 +168,6 @@
 )
 evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler)
 
-# Add early stopping handler to evaluator
-early_stopper = EarlyStopping(patience=4,
-                              score_function=stopping_fn_from_metric(metric_name),
-                              trainer=trainer)
-evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
-
-# create a validation data loader
-val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
-val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
-
-
-@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
-def run_validation(engine):
-    evaluator.run(val_loader)
-
-
-# create a training data loader
-train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
-train_loader = DataLoader(train_ds, batch_size=5, num_workers=8,
pin_memory=torch.cuda.is_available()) train_epochs = 30 state = trainer.run(train_loader, train_epochs) diff --git a/examples/segmentation_3d/unet_training_dict.py b/examples/segmentation_3d/unet_training_dict.py index e3b8d4fe88..092c552a8c 100644 --- a/examples/segmentation_3d/unet_training_dict.py +++ b/examples/segmentation_3d/unet_training_dict.py @@ -43,7 +43,7 @@ tempdir = tempfile.mkdtemp() print('generating synthetic data to {} (this may take a while)'.format(tempdir)) for i in range(50): - im, seg = create_test_image_3d(128, 128, 128, channel_dim=-1) + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) n = nib.Nifti1Image(im, np.eye(4)) nib.save(n, os.path.join(tempdir, 'img%i.nii.gz' % i)) @@ -53,8 +53,8 @@ images = sorted(glob(os.path.join(tempdir, 'img*.nii.gz'))) segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) -train_files = [{'img': img, 'seg': seg} for img, seg in zip(images[:40], segs[:40])] -val_files = [{'img': img, 'seg': seg} for img, seg in zip(images[-10:], segs[-10:])] +train_files = [{'img': img, 'seg': seg} for img, seg in zip(images[:20], segs[:20])] +val_files = [{'img': img, 'seg': seg} for img, seg in zip(images[-20:], segs[-20:])] # Define transforms for image and segmentation train_transforms = transforms.Compose([ @@ -78,17 +78,29 @@ check_data = monai.utils.misc.first(check_loader) print(check_data['img'].shape, check_data['seg'].shape) + +# create a training data loader +train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) +# use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training +train_loader = DataLoader(train_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, + pin_memory=torch.cuda.is_available()) +# create a validation data loader +val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) +val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, collate_fn=list_data_collate, + pin_memory=torch.cuda.is_available()) + + # Create UNet, DiceLoss and Adam optimizer net = monai.networks.nets.UNet( dimensions=3, in_channels=1, - num_classes=1, + out_channels=1, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, ) loss = monai.losses.DiceLoss(do_sigmoid=True) -lr = 1e-5 +lr = 1e-3 opt = torch.optim.Adam(net.parameters(), lr) device = torch.device("cuda:0") @@ -117,9 +129,9 @@ def prepare_batch(batch, device=None, non_blocking=False): train_tensorboard_stats_handler = TensorBoardStatsHandler() train_tensorboard_stats_handler.attach(trainer) -# Set parameters for validation -validation_every_n_epochs = 1 +validation_every_n_iters = 5 +# Set parameters for validation metric_name = 'Mean_Dice' # add evaluation metric to the evaluator engine val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)} @@ -128,6 +140,18 @@ def prepare_batch(batch, device=None, non_blocking=False): # user can add output_transform to return other values evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) + +@trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters)) +def run_validation(engine): + evaluator.run(val_loader) + + +# Add early stopping handler to evaluator +early_stopper = EarlyStopping(patience=4, + score_function=stopping_fn_from_metric(metric_name), + trainer=trainer) +evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) + # Add stats event handler to print validation stats via evaluator 
 val_stats_handler = StatsHandler(
     name='evaluator',
@@ -135,42 +159,22 @@ def prepare_batch(batch, device=None, non_blocking=False):
     global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
 val_stats_handler.attach(evaluator)
 
-# add handler to record metrics to TensorBoard at every epoch
+# add handler to record metrics to TensorBoard at every validation epoch
 val_tensorboard_stats_handler = TensorBoardStatsHandler(
     output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
-    global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
+    global_epoch_transform=lambda x: trainer.state.iteration)  # fetch global iteration number from trainer
 val_tensorboard_stats_handler.attach(evaluator)
+
 # add handler to draw the first image and the corresponding label and model output in the last batch
-# here we draw the 3D output as GIF format along Depth axis
+# here we draw the 3D output as GIF format along the depth axis, every 2 validation iterations.
 val_tensorboard_image_handler = TensorBoardImageHandler(
     batch_transform=lambda batch: (batch['img'], batch['seg']),
     output_transform=lambda output: predict_segmentation(output[0]),
     global_iter_transform=lambda x: trainer.state.epoch
 )
-evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler)
-
-# Add early stopping handler to evaluator
-early_stopper = EarlyStopping(patience=4,
-                              score_function=stopping_fn_from_metric(metric_name),
-                              trainer=trainer)
-evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
-
-# create a validation data loader
-val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
-val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, collate_fn=list_data_collate,
-                        pin_memory=torch.cuda.is_available())
-
-
-@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
-def run_validation(engine):
-    evaluator.run(val_loader)
+evaluator.add_event_handler(
+    event_name=Events.ITERATION_COMPLETED(every=2), handler=val_tensorboard_image_handler)
 
-# create a training data loader
-train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
-# use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
-train_loader = DataLoader(train_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate,
-                          pin_memory=torch.cuda.is_available())
-
-train_epochs = 30
+train_epochs = 5
 state = trainer.run(train_loader, train_epochs)
diff --git a/examples/unet_segmentation_3d.ipynb b/examples/unet_segmentation_3d.ipynb
index 0d49742f10..f42c33a2a5 100644
--- a/examples/unet_segmentation_3d.ipynb
+++ b/examples/unet_segmentation_3d.ipynb
@@ -145,7 +145,7 @@
     "net = monai.networks.nets.UNet(\n",
     "    dimensions=3,\n",
     "    in_channels=1,\n",
-    "    num_classes=1,\n",
+    "    out_channels=1,\n",
     "    channels=(16, 32, 64, 128, 256),\n",
     "    strides=(2, 2, 2, 2),\n",
     "    num_res_units=2,\n",
diff --git a/monai/data/synthetic.py b/monai/data/synthetic.py
index 4efd4fe393..063c16a965 100644
--- a/monai/data/synthetic.py
+++ b/monai/data/synthetic.py
@@ -20,7 +20,7 @@ def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0,
     ``rad_max``. The mask will have ``num_seg_classes`` number of classes for segmentations labeled
     sequentially from 1, plus a background class represented as 0.
If ``noise_max`` is greater than 0 then noise will be added to the image, taken from the uniform
     distribution on the range [0, noise_max). If ``channel_dim`` is None, will create an image without channel
-    dimemsion, otherwise create an image with channel dimension as first dim or last dim.
+    dimension, otherwise create an image with channel dimension as first dim or last dim.
     """
 
     image = np.zeros((width, height))
@@ -44,7 +44,7 @@ def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0,
     if channel_dim is not None:
         assert isinstance(channel_dim, int) and channel_dim in (-1, 0, 2), 'invalid channel dim.'
-        noisyimage, labels = noisyimage[None], labels[None] \
-            if channel_dim == 0 else noisyimage[..., None], labels[..., None]
+        noisyimage, labels = (noisyimage[None], labels[None]) \
+            if channel_dim == 0 else (noisyimage[..., None], labels[..., None])
 
     return noisyimage, labels
@@ -54,7 +54,8 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30,
     """
     Return a noisy 3D image and segmentation.
 
-    See also: create_test_image_2d
+    See also:
+        ``create_test_image_2d``
     """
 
     image = np.zeros((width, height, depth))
diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py
index d0c9f2126c..501dce816f 100644
--- a/monai/handlers/classification_saver.py
+++ b/monai/handlers/classification_saver.py
@@ -50,19 +50,21 @@ def __init__(self, output_dir='./', overwrite=True,
     def attach(self, engine):
         if self.logger is None:
             self.logger = engine.logger
-        return engine.add_event_handler(Events.ITERATION_COMPLETED, self)
+        if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
+            engine.add_event_handler(Events.ITERATION_COMPLETED, self)
+        if not engine.has_event_handler(self.finalize, Events.COMPLETED):
+            engine.add_event_handler(Events.COMPLETED, self.finalize)
 
-    def finalize(self):
+    def finalize(self, _engine=None):
         """
         Writes the prediction dict to a CSV file
 
         """
-        if not self.overwrite:
-            if os.path.exists(self._preds_filepath):
-                with open(self._preds_filepath, 'r') as f:
-                    reader = csv.reader(f)
-                    for row in reader:
-                        self._prediction_dict[row[0]] = np.array(row[1:]).astype(np.float32)
+        if not self.overwrite and os.path.exists(self._preds_filepath):
+            with open(self._preds_filepath, 'r') as f:
+                reader = csv.reader(f)
+                for row in reader:
+                    self._prediction_dict[row[0]] = np.array(row[1:]).astype(np.float32)
 
         if not os.path.exists(self.output_dir):
             os.makedirs(self.output_dir)
diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py
index 3a136ee3ec..98eb972d3a 100644
--- a/monai/handlers/segmentation_saver.py
+++ b/monai/handlers/segmentation_saver.py
@@ -50,7 +50,8 @@ def __init__(self, output_path='./', dtype='float32', output_postfix='seg', outp
     def attach(self, engine):
         if self.logger is None:
             self.logger = engine.logger
-        return engine.add_event_handler(Events.ITERATION_COMPLETED, self)
+        if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
+            engine.add_event_handler(Events.ITERATION_COMPLETED, self)
 
     @staticmethod
     def _create_file_basename(postfix, input_file_name, folder_path, data_root_dir=""):
@@ -115,5 +116,6 @@ def __call__(self, engine):
             output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path)
             output_filename = '{}{}'.format(output_filename, self.output_ext)
             # change output to "channel last" format and write to nifti format file
-            write_nifti(np.moveaxis(seg_output, 0, -1), affine_, output_filename, original_affine_, dtype=seg_output.dtype)
+            to_save = np.moveaxis(seg_output, 0, -1)
+
write_nifti(to_save, affine_, output_filename, original_affine_, dtype=seg_output.dtype)
             self.logger.info('saved: {}'.format(output_filename))
diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py
index 80d524b56e..e6e3604802 100644
--- a/monai/handlers/tensorboard_handlers.py
+++ b/monai/handlers/tensorboard_handlers.py
@@ -139,22 +139,21 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter):
                                   ' {}:{}'.format(name, type(value)))
                     continue  # not plot multi dimensional output
                 writer.add_scalar(name, value.item() if torch.is_tensor(value) else value, engine.state.iteration)
+        elif is_scalar(loss):  # not printing multi dimensional output
+            writer.add_scalar(name, loss.item() if torch.is_tensor(loss) else loss, engine.state.iteration)
         else:
-            if is_scalar(loss):  # not printing multi dimensional output
-                writer.add_scalar(name, loss.item() if torch.is_tensor(loss) else loss, engine.state.iteration)
-            else:
-                warnings.warn('ignoring non-scalar output in TensorBoardStatsHandler,'
-                              ' make sure `output_transform(engine.state.output)` returns'
-                              ' a scalar or a dictionary of key and scalar pairs to avoid this warning.'
-                              ' {}'.format(type(loss)))
+            warnings.warn('ignoring non-scalar output in TensorBoardStatsHandler,'
+                          ' make sure `output_transform(engine.state.output)` returns'
+                          ' a scalar or a dictionary of key and scalar pairs to avoid this warning.'
+                          ' {}'.format(type(loss)))
         writer.flush()
 
 
 class TensorBoardImageHandler(object):
     """TensorBoardImageHandler is an ignite Event handler that can visualise images, labels and outputs as 2D/3D images.
     2D output (shape in Batch, channel, H, W) will be shown as a simple image using the first element in the batch,
-    for 3D to ND output (shape in Batch, channel, H, W, D) input,
-    the last three dimensions will be shown as GIF image along the last axis (typically Depth).
+    for 3D to ND output (shape in Batch, channel, H, W, D) input, the last three dimensions of each of up to
+    ``self.max_channels`` images will be shown as an animated GIF along the last axis (typically Depth).
     It can be used for any Ignite Engine (trainer, validator and evaluator).
Users can easily add it to the engine for any expected Event, for example: ``EPOCH_COMPLETED``,
diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index 6018cd06e8..ad9b3ddbf4 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -21,13 +21,13 @@
 @alias("Unet", "unet")
 class UNet(nn.Module):
-    def __init__(self, dimensions, in_channels, num_classes, channels, strides, kernel_size=3, up_kernel_size=3,
+    def __init__(self, dimensions, in_channels, out_channels, channels, strides, kernel_size=3, up_kernel_size=3,
                  num_res_units=0, instance_norm=True, dropout=0):
         super().__init__()
         assert len(channels) == (len(strides) + 1)
         self.dimensions = dimensions
         self.in_channels = in_channels
-        self.num_classes = num_classes
+        self.out_channels = out_channels
         self.channels = channels
         self.strides = strides
         self.kernel_size = kernel_size
@@ -57,7 +57,7 @@ def _create_block(inc, outc, channels, strides, is_top):
 
             return nn.Sequential(down, SkipConnection(subblock), up)
 
-        self.model = _create_block(in_channels, num_classes, self.channels, self.strides, True)
+        self.model = _create_block(in_channels, out_channels, self.channels, self.strides, True)
 
     def _get_down_layer(self, in_channels, out_channels, strides, is_top):
         if self.num_res_units > 0:
diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py
index 8d9d18ccda..0e326c6f73 100644
--- a/monai/transforms/transforms.py
+++ b/monai/transforms/transforms.py
@@ -197,6 +197,16 @@ def __call__(self, filename):
 class AsChannelFirst:
     """
     Change the channel dimension of the image to the first dimension.
+
+    Most of the image transformations in ``monai.transforms``
+    assume the input image is in the channel-first format, which has the shape
+    (num_channels, spatial_dim_1[, spatial_dim_2, ...]).
+
+    This transform could be used to convert, for example, a channel-last image array in shape
+    (spatial_dim_1[, spatial_dim_2, ...], num_channels) into the channel-first format,
+    so that the multidimensional image array can be correctly interpreted by the other
+    transforms.
+
     Args:
         channel_dim (int): which dimension of input image is the channel, default is the last dimension.
     """
@@ -213,6 +223,15 @@ def __call__(self, img):
 class AddChannel:
     """
     Adds a 1-length channel dimension to the input image.
+
+    Most of the image transformations in ``monai.transforms``
+    assume the input image is in the channel-first format, which has the shape
+    (num_channels, spatial_dim_1[, spatial_dim_2, ...]).
+
+    This transform could be used, for example, to convert a (spatial_dim_1[, spatial_dim_2, ...])
+    spatial image into the channel-first format so that the
+    multidimensional image array can be correctly interpreted by the other
+    transforms.
     """
 
     def __call__(self, img):
@@ -1020,7 +1039,7 @@ def __init__(self,
         Args:
             rotate_params (float, list of floats): a rotation angle in radians,
-                a scalar for 2D image, a tuple of 2 floats for 3D. Defaults to no rotation.
+                a scalar for 2D image, a tuple of 3 floats for 3D. Defaults to no rotation.
             shear_params (list of floats): a tuple of 2 floats for 2D,
                 a tuple of 6 floats for 3D. Defaults to no shearing.
translate_params (list of floats): diff --git a/tests/integration_sliding_window.py b/tests/integration_sliding_window.py index e41f6574d3..b99bb5c681 100644 --- a/tests/integration_sliding_window.py +++ b/tests/integration_sliding_window.py @@ -40,7 +40,7 @@ def run_test(batch_size=2, device=torch.device("cpu:0")): net = UNet( dimensions=3, in_channels=1, - num_classes=1, + out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2, diff --git a/tests/integration_unet2d.py b/tests/integration_unet2d.py index 3dac697397..d437258c10 100644 --- a/tests/integration_unet2d.py +++ b/tests/integration_unet2d.py @@ -35,7 +35,7 @@ def __len__(self): net = UNet( dimensions=2, in_channels=1, - num_classes=1, + out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2, @@ -45,10 +45,7 @@ def __len__(self): opt = torch.optim.Adam(net.parameters(), 1e-4) src = DataLoader(_TestBatch(), batch_size=batch_size) - def loss_fn(pred, grnd): - return loss(pred, grnd) - - trainer = create_supervised_trainer(net, opt, loss_fn, device, False) + trainer = create_supervised_trainer(net, opt, loss, device, False) trainer.run(src, 1) loss = trainer.state.output diff --git a/tests/test_handler_classification_saver.py b/tests/test_handler_classification_saver.py index 78a773e092..3eea9d86ed 100644 --- a/tests/test_handler_classification_saver.py +++ b/tests/test_handler_classification_saver.py @@ -38,7 +38,6 @@ def _train_func(engine, batch): data = [{'filename_or_obj': ['testfile' + str(i) for i in range(8)]}] engine.run(data, epoch_length=2, max_epochs=1) - saver.finalize() filepath = os.path.join(default_dir, 'predictions.csv') self.assertTrue(os.path.exists(filepath)) with open(filepath, 'r') as f: diff --git a/tests/test_unet.py b/tests/test_unet.py index d64407bb4b..5b8e85f915 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -20,7 +20,7 @@ { 'dimensions': 2, 'in_channels': 1, - 'num_classes': 3, + 'out_channels': 3, 'channels': (16, 32, 64), 'strides': (2, 2), 'num_res_units': 1, @@ -33,7 +33,7 @@ { 'dimensions': 3, 'in_channels': 1, - 'num_classes': 3, + 'out_channels': 3, 'channels': (16, 32, 64), 'strides': (2, 2), 'num_res_units': 1, @@ -46,7 +46,7 @@ { 'dimensions': 3, 'in_channels': 4, - 'num_classes': 3, + 'out_channels': 3, 'channels': (16, 32, 64), 'strides': (2, 2), 'num_res_units': 1, From 42299599103d58424f549b563d8581c561a51632 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 13 Mar 2020 10:13:38 +0800 Subject: [PATCH 7/7] [DLMED] fix tag_name bug, fix DenseNet3D bug, add comments --- examples/classification_3d/densenet_evaluation_array.py | 1 + examples/classification_3d/densenet_evaluation_dict.py | 1 + examples/classification_3d/densenet_training_array.py | 1 + examples/classification_3d/densenet_training_dict.py | 1 + examples/segmentation_3d/unet_training_array.py | 4 ++-- examples/segmentation_3d/unet_training_dict.py | 4 ++-- monai/handlers/tensorboard_handlers.py | 9 +++++---- monai/networks/nets/densenet3d.py | 1 - 8 files changed, 13 insertions(+), 9 deletions(-) diff --git a/examples/classification_3d/densenet_evaluation_array.py b/examples/classification_3d/densenet_evaluation_array.py index f131c79683..9785605aad 100644 --- a/examples/classification_3d/densenet_evaluation_array.py +++ b/examples/classification_3d/densenet_evaluation_array.py @@ -43,6 +43,7 @@ "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz" ] +# 2 binary labels for gender classification: man and 
 labels = np.array([
     0, 0, 1, 0, 1, 0, 1, 0, 1, 0
 ])
diff --git a/examples/classification_3d/densenet_evaluation_dict.py b/examples/classification_3d/densenet_evaluation_dict.py
index d764e1cc61..d2e33cfe5e 100644
--- a/examples/classification_3d/densenet_evaluation_dict.py
+++ b/examples/classification_3d/densenet_evaluation_dict.py
@@ -42,6 +42,7 @@
     "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
     "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz"
 ]
+# binary labels for gender classification: man and woman
 labels = np.array([
     0, 0, 1, 0, 1, 0, 1, 0, 1, 0
 ])
diff --git a/examples/classification_3d/densenet_training_array.py b/examples/classification_3d/densenet_training_array.py
index e6bd9c0d6c..4993556404 100644
--- a/examples/classification_3d/densenet_training_array.py
+++ b/examples/classification_3d/densenet_training_array.py
@@ -54,6 +54,7 @@
     "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
     "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz"
 ]
+# binary labels for gender classification: man and woman
 labels = np.array([
     0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0
 ])
diff --git a/examples/classification_3d/densenet_training_dict.py b/examples/classification_3d/densenet_training_dict.py
index c8e73fc173..c0017df969 100644
--- a/examples/classification_3d/densenet_training_dict.py
+++ b/examples/classification_3d/densenet_training_dict.py
@@ -54,6 +54,7 @@
     "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
     "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz"
 ]
+# binary labels for gender classification: man and woman
 labels = np.array([
     0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0
 ])
diff --git a/examples/segmentation_3d/unet_training_array.py b/examples/segmentation_3d/unet_training_array.py
index af96e10f93..c9cb70875b 100644
--- a/examples/segmentation_3d/unet_training_array.py
+++ b/examples/segmentation_3d/unet_training_array.py
@@ -39,10 +39,10 @@
 monai.config.print_config()
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
-# Create a temporary directory and 50 random image, mask paris
+# Create a temporary directory and 40 random image, mask pairs
 tempdir = tempfile.mkdtemp()
 print('generating synthetic data to {} (this may take a while)'.format(tempdir))
-for i in range(50):
+for i in range(40):
     im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
 
     n = nib.Nifti1Image(im, np.eye(4))
diff --git a/examples/segmentation_3d/unet_training_dict.py b/examples/segmentation_3d/unet_training_dict.py
index 092c552a8c..018f7076a7 100644
--- a/examples/segmentation_3d/unet_training_dict.py
+++ b/examples/segmentation_3d/unet_training_dict.py
@@ -39,10 +39,10 @@
 monai.config.print_config()
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
-# Create a temporary directory and 50 random image, mask paris
+# Create a temporary directory and 40 random image, mask pairs
 tempdir = tempfile.mkdtemp()
 print('generating synthetic data to {} (this may take a while)'.format(tempdir))
-for i in range(50):
+for i in range(40):
     im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
 
     n = nib.Nifti1Image(im, np.eye(4))
diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py
index e6e3604802..411a53084b 100644
--- a/monai/handlers/tensorboard_handlers.py
+++ b/monai/handlers/tensorboard_handlers.py
@@ -38,7 +38,7 @@ def __init__(self,
                  summary_writer=None,
                  epoch_event_writer=None,
                 iteration_event_writer=None,
-                 output_transform=lambda x: {'Loss': x},
+                 output_transform=lambda x: x,
                  global_epoch_transform=lambda x: x,
                  tag_name=DEFAULT_TAG):
         """
@@ -50,8 +50,9 @@ def __init__(self,
             iteration_event_writer (Callable): custimized callable TensorBoard writer for iteration level.
                 must accept parameter "engine" and "summary_writer", use default event writer if None.
             output_transform (Callable): a callable that is used to transform the
-                ``ignite.engine.output`` into a dictionary of (tag_name: scalar) pairs to be plotted onto tensorboard.
-                by default this scalar plotting happens when every iteration completed.
+                ``ignite.engine.output`` into a scalar to plot, or a dictionary of {key: scalar} pairs.
+                in the latter case, each item is plotted as a separate scalar with its key as the tag.
+                by default this plotting happens at the end of every iteration.
             global_epoch_transform (Callable): a callable that is used to customize global epoch number.
                 For example, in evaluation, the evaluator engine might want to use trainer engines epoch number
                 when plotting epoch vs metric curves.
@@ -140,7 +141,7 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter):
                 continue  # not plot multi dimensional output
             writer.add_scalar(name, value.item() if torch.is_tensor(value) else value, engine.state.iteration)
         elif is_scalar(loss):  # not printing multi dimensional output
-            writer.add_scalar(name, loss.item() if torch.is_tensor(loss) else loss, engine.state.iteration)
+            writer.add_scalar(self.tag_name, loss.item() if torch.is_tensor(loss) else loss, engine.state.iteration)
         else:
             warnings.warn('ignoring non-scalar output in TensorBoardStatsHandler,'
                           ' make sure `output_transform(engine.state.output)` returns'
diff --git a/monai/networks/nets/densenet3d.py b/monai/networks/nets/densenet3d.py
index f5493f6c04..78fab167c4 100644
--- a/monai/networks/nets/densenet3d.py
+++ b/monai/networks/nets/densenet3d.py
@@ -146,7 +146,6 @@ def __init__(self,
             OrderedDict([
                 ('relu', nn.ReLU(inplace=True)),
                 ('norm', get_avgpooling_type(spatial_dims, is_adaptive=True)(1)),
-                ('relu', nn.ReLU(inplace=True)),
                 ('flatten', nn.Flatten(1)),
                 ('class', nn.Linear(in_channels, out_channels)),
             ]))
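
A minimal sketch of the channel-first convention documented in the new ``AsChannelFirst`` and ``AddChannel`` docstrings above; it assumes both classes are exported from ``monai.transforms`` in the same way as the other transforms used in the examples:

    import numpy as np
    from monai.transforms import AddChannel, AsChannelFirst

    # a spatial-only volume of shape (H, W, D) gains a length-1 channel axis
    img = np.zeros((64, 64, 32))
    print(AddChannel()(img).shape)  # (1, 64, 64, 32)

    # a channel-last volume (H, W, D, C) is moved to channel-first (C, H, W, D)
    img_last = np.zeros((64, 64, 32, 3))
    print(AsChannelFirst(channel_dim=-1)(img_last).shape)  # (3, 64, 64, 32)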
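
The renamed UNet argument can be exercised as below, a sketch reusing the channel/stride settings from the integration tests; the input size (2, 1, 32, 32, 32) is an arbitrary assumption chosen to be divisible by the product of the strides:

    import torch
    from monai.networks.nets.unet import UNet

    # `out_channels` replaces the old `num_classes` argument
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(4, 8, 16, 32),
        strides=(2, 2, 2),
        num_res_units=2,
    )

    x = torch.randn(2, 1, 32, 32, 32)  # (batch, channel, D, H, W)
    print(net(x).shape)  # expected: torch.Size([2, 1, 32, 32, 32])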
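
Similarly, a sketch of the revised ``output_transform``/``tag_name`` behaviour of ``TensorBoardStatsHandler``; the toy engine and its constant loss value are assumptions for illustration, and ``attach`` is assumed to mirror the ``StatsHandler.attach`` usage in the examples:

    from ignite.engine import Engine
    from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler

    # a toy engine whose iteration output is a plain scalar loss
    trainer = Engine(lambda engine, batch: 0.5)

    # with the fix above, the scalar is now written under `tag_name`
    stats_handler = TensorBoardStatsHandler(tag_name='train_loss',
                                            output_transform=lambda output: output)
    stats_handler.attach(trainer)

    # a dict output would instead plot each item under its own key, e.g.
    # output_transform=lambda output: {'loss': output[0], 'accuracy': output[1]}
    trainer.run([0] * 4, max_epochs=1)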