diff --git a/examples/classification_3d/densenet_evaluation_array.py b/examples/classification_3d/densenet_evaluation_array.py new file mode 100644 index 0000000000..9785605aad --- /dev/null +++ b/examples/classification_3d/densenet_evaluation_array.py @@ -0,0 +1,97 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import logging +import numpy as np +import torch +from ignite.engine import create_supervised_evaluator, _prepare_batch +from torch.utils.data import DataLoader + +# assumes the framework is found here, change as necessary +sys.path.append("../..") +import monai +import monai.transforms.compose as transforms +from monai.data.nifti_reader import NiftiDataset +from monai.transforms import (AddChannel, Rescale, Resize) +from monai.handlers.stats_handler import StatsHandler +from monai.handlers.classification_saver import ClassificationSaver +from monai.handlers.checkpoint_loader import CheckpointLoader +from ignite.metrics import Accuracy + +monai.config.print_config() +logging.basicConfig(stream=sys.stdout, level=logging.INFO) + +# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ +images = [ + "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz" +] +# binary labels for gender classification: man and woman +labels = np.array([ + 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 +]) + +# Define transforms for image +val_transforms = transforms.Compose([ + Rescale(), + AddChannel(), + Resize((96, 96, 96)) +]) +# Define nifti dataset +val_ds = NiftiDataset(image_files=images, labels=labels, transform=val_transforms, image_only=False) +# Create DenseNet121 +net = monai.networks.nets.densenet3d.densenet121( + in_channels=1, + out_channels=2, +) +device = torch.device("cuda:0") + +metric_name = 'Accuracy' +# add evaluation metric to the evaluator engine +val_metrics = {metric_name: Accuracy()} + + +def prepare_batch(batch, device=None, non_blocking=False): + return _prepare_batch((batch[0], batch[1]), device, non_blocking) + + +# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, +# user can add output_transform to return other values +evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) + +# Add stats event handler to print validation stats via evaluator +val_stats_handler = StatsHandler( + name='evaluator', + output_transform=lambda x: None # no need
to print loss value, so disable per iteration output ) +val_stats_handler.attach(evaluator) + +# for the array data format, assume the 3rd item of batch data is the meta_data +prediction_saver = ClassificationSaver(output_dir='tempdir', batch_transform=lambda batch: batch[2], + output_transform=lambda output: output[0].argmax(1)) +prediction_saver.attach(evaluator) + +# the model was trained by "densenet_training_array" example +CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator) + +# create a validation data loader +val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) + +state = evaluator.run(val_loader) diff --git a/examples/classification_3d/densenet_evaluation_dict.py b/examples/classification_3d/densenet_evaluation_dict.py new file mode 100644 index 0000000000..d2e33cfe5e --- /dev/null +++ b/examples/classification_3d/densenet_evaluation_dict.py @@ -0,0 +1,98 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ignite.metrics import Accuracy +import sys +import logging +import numpy as np +import torch +from ignite.engine import create_supervised_evaluator, _prepare_batch +from torch.utils.data import DataLoader + +# assumes the framework is found here, change as necessary +sys.path.append("../..") +from monai.handlers.classification_saver import ClassificationSaver +from monai.handlers.checkpoint_loader import CheckpointLoader +from monai.handlers.stats_handler import StatsHandler +from monai.transforms.composables import LoadNiftid, AddChanneld, Rescaled, Resized +import monai.transforms.compose as transforms +import monai + +monai.config.print_config() +logging.basicConfig(stream=sys.stdout, level=logging.INFO) + +# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ +images = [ + "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz" +] +# binary labels for gender classification: man and woman +labels = np.array([ + 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 +]) +val_files = [{'img': img, 'label': label} for img, label in zip(images, labels)] + +# Define transforms for image +val_transforms = transforms.Compose([ + LoadNiftid(keys=['img']), + AddChanneld(keys=['img']), + Rescaled(keys=['img']), + Resized(keys=['img'], output_spatial_shape=(96, 96, 96)) +]) + +# Create DenseNet121 +net = monai.networks.nets.densenet3d.densenet121( + in_channels=1, + 
out_channels=2, +) +device = torch.device("cuda:0") + + +def prepare_batch(batch, device=None, non_blocking=False): + return _prepare_batch((batch['img'], batch['label']), device, non_blocking) + + +metric_name = 'Accuracy' +# add evaluation metric to the evaluator engine +val_metrics = {metric_name: Accuracy()} +# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, +# user can add output_transform to return other values +evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) + +# Add stats event handler to print validation stats via evaluator +val_stats_handler = StatsHandler( + name='evaluator', + output_transform=lambda x: None # no need to print loss value, so disable per iteration output +) +val_stats_handler.attach(evaluator) + +# convert the necessary metadata from batch data +prediction_saver = ClassificationSaver(output_dir='tempdir', name='evaluator', + batch_transform=lambda batch: {'filename_or_obj': batch['img.filename_or_obj']}, + output_transform=lambda output: output[0].argmax(1)) +prediction_saver.attach(evaluator) + +# the model was trained by "densenet_training_dict" example +CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator) + +# create a validation data loader +val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) +val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) + +state = evaluator.run(val_loader) diff --git a/examples/densenet_classification_3d.py b/examples/classification_3d/densenet_training_array.py similarity index 62% rename from examples/densenet_classification_3d.py rename to examples/classification_3d/densenet_training_array.py index 753097a2a9..4993556404 100644 --- a/examples/densenet_classification_3d.py +++ b/examples/classification_3d/densenet_training_array.py @@ -18,19 +18,20 @@ from torch.utils.data import DataLoader # assumes the framework is found here, change as necessary -sys.path.append("..") +sys.path.append("../..") import monai import monai.transforms.compose as transforms - from monai.data.nifti_reader import NiftiDataset -from monai.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch) +from monai.transforms import (AddChannel, Rescale, Resize, RandRotate90) from monai.handlers.stats_handler import StatsHandler +from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler from ignite.metrics import Accuracy from monai.handlers.utils import stopping_fn_from_metric monai.config.print_config() +logging.basicConfig(stream=sys.stdout, level=logging.INFO) -# FIXME: temp test dataset, Wenqi will replace later +# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ images = [ "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz", @@ -53,37 +54,42 @@ "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz" ] +# binary labels for gender classification: man and woman labels = np.array([ 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 ]) -# Define transforms for image and segmentation -imtrans = transforms.Compose([ +# Define transforms +train_transforms = transforms.Compose([ + Rescale(), + AddChannel(), + Resize((96, 96, 96)), + RandRotate90() +]) +val_transforms = transforms.Compose([ + Rescale(), 
AddChannel(), - UniformRandomPatch((96, 96, 96)), - ToTensor() + Resize((96, 96, 96)) ]) -# Define nifti dataset, dataloader. -ds = NiftiDataset(image_files=images, labels=labels, transform=imtrans) -loader = DataLoader(ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) -im, label = monai.utils.misc.first(loader) +# Define nifti dataset, dataloader +check_ds = NiftiDataset(image_files=images, labels=labels, transform=train_transforms) +check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) +im, label = monai.utils.misc.first(check_loader) print(type(im), im.shape, label) -lr = 1e-5 - -# Create DenseNet121, CrossEntropyLoss and Adam optimizer. +# Create DenseNet121, CrossEntropyLoss and Adam optimizer net = monai.networks.nets.densenet3d.densenet121( in_channels=1, out_channels=2, ) - loss = torch.nn.CrossEntropyLoss() +lr = 1e-5 opt = torch.optim.Adam(net.parameters(), lr) - -# Create trainer device = torch.device("cuda:0") + +# ignite trainer expects batch=(img, label) and returns output=loss at every iteration, +# user can add output_transform to return other values, like: y_pred, y, etc. trainer = create_supervised_trainer(net, opt, loss, device, False) # adding checkpoint handler to save models (network params and optimizer stats) during training @@ -91,45 +97,58 @@ trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -train_stats_handler = StatsHandler(output_transform=lambda x: x[3]) + +# StatsHandler prints loss at every iteration and print metrics at every epoch, +# we don't set metrics for trainer here, so just print loss, user can also customize print functions +# and can use output_transform to convert engine.state.output if it's not loss value +train_stats_handler = StatsHandler(name='trainer') train_stats_handler.attach(trainer) -@trainer.on(Events.EPOCH_COMPLETED) -def log_training_loss(engine): - engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output) +# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler +train_tensorboard_stats_handler = TensorBoardStatsHandler() +train_tensorboard_stats_handler.attach(trainer) # Set parameters for validation validation_every_n_epochs = 1 -metric_name = 'Accuracy' +metric_name = 'Accuracy' # add evaluation metric to the evaluator engine val_metrics = {metric_name: Accuracy()} +# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, +# user can add output_transform to return other values evaluator = create_supervised_evaluator(net, val_metrics, device, True) # Add stats event handler to print validation stats via evaluator -logging.basicConfig(stream=sys.stdout, level=logging.INFO) -val_stats_handler = StatsHandler(output_transform=lambda x: None) +val_stats_handler = StatsHandler( + name='evaluator', + output_transform=lambda x: None, # no need to print loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer val_stats_handler.attach(evaluator) -# Add early stopping handler to evaluator. 
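+# (illustrative sketch, not part of this example) create_supervised_evaluator also
+# accepts an output_transform if values other than the default (y_pred, y) pair are
+# needed downstream, e.g. to keep softmax probabilities for custom handlers:
+#   evaluator = create_supervised_evaluator(
+#       net, val_metrics, device, True,
+#       output_transform=lambda x, y, y_pred: (y_pred.softmax(dim=1), y))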
+# add handler to record metrics to TensorBoard at every epoch +val_tensorboard_stats_handler = TensorBoardStatsHandler( + output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer +val_tensorboard_stats_handler.attach(evaluator) + +# Add early stopping handler to evaluator early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer) evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) # create a validation data loader -val_ds = NiftiDataset(image_files=images[-5:], labels=labels[-5:], transform=imtrans) -val_loader = DataLoader(ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) +val_ds = NiftiDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms) +val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) def run_validation(engine): evaluator.run(val_loader) -# create a training data loader -logging.basicConfig(stream=sys.stdout, level=logging.INFO) -train_ds = NiftiDataset(image_files=images[:15], labels=labels[:15], transform=imtrans) +# create a training data loader +train_ds = NiftiDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms) train_loader = DataLoader(train_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) train_epochs = 30 diff --git a/examples/classification_3d/densenet_training_dict.py b/examples/classification_3d/densenet_training_dict.py new file mode 100644 index 0000000000..c0017df969 --- /dev/null +++ b/examples/classification_3d/densenet_training_dict.py @@ -0,0 +1,164 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
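+# This script uses the dictionary-based data and transform API: every sample is a dict,
+# e.g. {'img': '/path/to/image.nii.gz', 'label': 0} (the path shown is a placeholder),
+# and each dictionary-based transform below operates only on the keys it is given.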
+ +import sys +import logging +import numpy as np +import torch +from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, _prepare_batch +from ignite.handlers import ModelCheckpoint, EarlyStopping +from torch.utils.data import DataLoader + +# assumes the framework is found here, change as necessary +sys.path.append("../..") +import monai +import monai.transforms.compose as transforms +from monai.transforms.composables import \ + LoadNiftid, AddChanneld, Rescaled, Resized, RandRotate90d +from monai.handlers.stats_handler import StatsHandler +from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler +from ignite.metrics import Accuracy +from monai.handlers.utils import stopping_fn_from_metric + +monai.config.print_config() +logging.basicConfig(stream=sys.stdout, level=logging.INFO) + +# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ +images = [ + "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz", + "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz" +] +# binary labels for gender classification: man and woman +labels = np.array([ + 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 +]) +train_files = [{'img': img, 'label': label} for img, label in zip(images[:10], labels[:10])] +val_files = [{'img': img, 'label': label} for img, label in zip(images[-10:], labels[-10:])] + +# Define transforms for image +train_transforms = transforms.Compose([ + LoadNiftid(keys=['img']), + AddChanneld(keys=['img']), + Rescaled(keys=['img']), + Resized(keys=['img'], output_spatial_shape=(96, 96, 96)), + RandRotate90d(keys=['img'], prob=0.8, spatial_axes=[0, 2]) +]) +val_transforms = transforms.Compose([ + LoadNiftid(keys=['img']), + AddChanneld(keys=['img']), + Rescaled(keys=['img']), + Resized(keys=['img'], output_spatial_shape=(96, 96, 96)) +]) + +# Define dataset, dataloader +check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) +check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) +check_data = monai.utils.misc.first(check_loader) +print(check_data['img'].shape, check_data['label']) + +# Create DenseNet121, CrossEntropyLoss and Adam optimizer +net = monai.networks.nets.densenet3d.densenet121( + in_channels=1, + out_channels=2, +) +loss = 
torch.nn.CrossEntropyLoss() +lr = 1e-5 +opt = torch.optim.Adam(net.parameters(), lr) +device = torch.device("cuda:0") + + +# ignite trainer expects batch=(img, label) and returns output=loss at every iteration, +# user can add output_transform to return other values, like: y_pred, y, etc. +def prepare_batch(batch, device=None, non_blocking=False): + return _prepare_batch((batch['img'], batch['label']), device, non_blocking) + + +trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch) + +# adding checkpoint handler to save models (network params and optimizer stats) during training +checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) +trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, + handler=checkpoint_handler, + to_save={'net': net, 'opt': opt}) + +# StatsHandler prints loss at every iteration and print metrics at every epoch, +# we don't set metrics for trainer here, so just print loss, user can also customize print functions +# and can use output_transform to convert engine.state.output if it's not loss value +train_stats_handler = StatsHandler(name='trainer') +train_stats_handler.attach(trainer) + +# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler +train_tensorboard_stats_handler = TensorBoardStatsHandler() +train_tensorboard_stats_handler.attach(trainer) + +# Set parameters for validation +validation_every_n_epochs = 1 + +metric_name = 'Accuracy' +# add evaluation metric to the evaluator engine +val_metrics = {metric_name: Accuracy()} +# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, +# user can add output_transform to return other values +evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) + +# Add stats event handler to print validation stats via evaluator +val_stats_handler = StatsHandler( + name='evaluator', + output_transform=lambda x: None, # no need to print loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer +val_stats_handler.attach(evaluator) + +# add handler to record metrics to TensorBoard at every epoch +val_tensorboard_stats_handler = TensorBoardStatsHandler( + output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer +val_tensorboard_stats_handler.attach(evaluator) + +# Add early stopping handler to evaluator +early_stopper = EarlyStopping(patience=4, + score_function=stopping_fn_from_metric(metric_name), + trainer=trainer) +evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) + +# create a validation data loader +val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) +val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) + + +@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) +def run_validation(engine): + evaluator.run(val_loader) + + +# create a training data loader +train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) +train_loader = DataLoader(train_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) + +train_epochs = 30 +state = trainer.run(train_loader, train_epochs) diff --git a/examples/multi_gpu_test.ipynb b/examples/multi_gpu_test.ipynb index 
8f1827f15b..98911e12f2 100644 --- a/examples/multi_gpu_test.ipynb +++ b/examples/multi_gpu_test.ipynb @@ -53,7 +53,7 @@ "net = monai.networks.nets.UNet(\n", " dimensions=2,\n", " in_channels=1,\n", - " num_classes=1,\n", + " out_channels=1,\n", " channels=(16, 32, 64, 128, 256),\n", " strides=(2, 2, 2, 2),\n", " num_res_units=2,\n", diff --git a/examples/unet_inference_3d_array.py b/examples/segmentation_3d/unet_evaluation_array.py similarity index 58% rename from examples/unet_inference_3d_array.py rename to examples/segmentation_3d/unet_evaluation_array.py index 8fe417c7dd..f5c039a39e 100644 --- a/examples/unet_inference_3d_array.py +++ b/examples/segmentation_3d/unet_evaluation_array.py @@ -13,7 +13,7 @@ import sys import tempfile from glob import glob - +import logging import nibabel as nib import numpy as np import torch @@ -21,26 +21,29 @@ from torch.utils.data import DataLoader # assumes the framework is found here, change as necessary -sys.path.append("..") +sys.path.append("../..") from monai import config from monai.handlers.checkpoint_loader import CheckpointLoader from monai.handlers.segmentation_saver import SegmentationSaver import monai.transforms.compose as transforms from monai.data.nifti_reader import NiftiDataset -from monai.transforms import AddChannel, Rescale, ToTensor +from monai.transforms import AddChannel, Rescale from monai.networks.nets.unet import UNet from monai.networks.utils import predict_segmentation from monai.data.synthetic import create_test_image_3d from monai.utils.sliding_window_inference import sliding_window_inference +from monai.handlers.stats_handler import StatsHandler +from monai.handlers.mean_dice import MeanDice config.print_config() +logging.basicConfig(stream=sys.stdout, level=logging.INFO) tempdir = tempfile.mkdtemp() # tempdir = './temp' print('generating synthetic data to {} (this may take a while)'.format(tempdir)) -for i in range(50): - im, seg = create_test_image_3d(256, 256, 256) +for i in range(5): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1) n = nib.Nifti1Image(im, np.eye(4)) nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i)) @@ -50,39 +53,59 @@ images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) -imtrans = transforms.Compose([Rescale(), AddChannel(), ToTensor()]) -segtrans = transforms.Compose([AddChannel(), ToTensor()]) + +# Define transforms for image and segmentation +imtrans = transforms.Compose([Rescale(), AddChannel()]) +segtrans = transforms.Compose([AddChannel()]) ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False) device = torch.device("cuda:0") -roi_size = (64, 64, 64) -sw_batch_size = 4 net = UNet( dimensions=3, in_channels=1, - num_classes=1, + out_channels=1, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, ) net.to(device) +# define sliding window size and batch size for windows inference +roi_size = (96, 96, 96) +sw_batch_size = 4 + def _sliding_window_processor(engine, batch): net.eval() img, seg, meta_data = batch with torch.no_grad(): - seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x)[0], device) - return predict_segmentation(seg_probs) + seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, net, device) + return seg_probs, seg.to(device) -infer_engine = Engine(_sliding_window_processor) +evaluator = Engine(_sliding_window_processor) + +# add evaluation metric to the evaluator engine +MeanDice(add_sigmoid=True, 
to_onehot_y=False).attach(evaluator, 'Mean_Dice') + +# StatsHandler prints loss at every iteration and prints metrics at every epoch, +# we don't need to print loss for the evaluator, so it just prints metrics; users can also customize the print functions +val_stats_handler = StatsHandler( + name='evaluator', + output_transform=lambda x: None # no need to print loss value, so disable per iteration output +) +val_stats_handler.attach(evaluator) # for the arrary data format, assume the 3rd item of batch data is the meta_data -SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg', - batch_transform=lambda x: x[2]).attach(infer_engine) -# the model was trained by "unet_segmentation_3d_array" exmple -CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(infer_engine) +file_saver = SegmentationSaver( + output_path='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator', + batch_transform=lambda x: x[2], output_transform=lambda output: predict_segmentation(output[0])) +file_saver.attach(evaluator) + +# the model was trained by "unet_training_array" example +ckpt_loader = CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net}) +ckpt_loader.attach(evaluator) +# sliding window inference needs to input 1 image in every iteration loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available()) -state = infer_engine.run(loader) +state = evaluator.run(loader) diff --git a/examples/unet_inference_3d_dict.py b/examples/segmentation_3d/unet_evaluation_dict.py similarity index 63% rename from examples/unet_inference_3d_dict.py rename to examples/segmentation_3d/unet_evaluation_dict.py index 405b49aa8d..e78abeac3f 100644 --- a/examples/unet_inference_3d_dict.py +++ b/examples/segmentation_3d/unet_evaluation_dict.py @@ -13,7 +13,7 @@ import sys import tempfile from glob import glob - +import logging import nibabel as nib import numpy as np import torch @@ -21,7 +21,7 @@ from torch.utils.data import DataLoader # assumes the framework is found here, change as necessary -sys.path.append("..") +sys.path.append("../..") import monai from monai.data.utils import list_data_collate @@ -29,19 +29,22 @@ from monai.data.synthetic import create_test_image_3d from monai.networks.utils import predict_segmentation from monai.networks.nets.unet import UNet -from monai.transforms.composables import LoadNiftid, AsChannelFirstd +from monai.transforms.composables import LoadNiftid, AsChannelFirstd, Rescaled import monai.transforms.compose as transforms from monai.handlers.segmentation_saver import SegmentationSaver from monai.handlers.checkpoint_loader import CheckpointLoader +from monai.handlers.stats_handler import StatsHandler +from monai.handlers.mean_dice import MeanDice from monai import config config.print_config() +logging.basicConfig(stream=sys.stdout, level=logging.INFO) tempdir = tempfile.mkdtemp() # tempdir = './temp' print('generating synthetic data to {} (this may take a while)'.format(tempdir)) -for i in range(50): - im, seg = create_test_image_3d(256, 256, 256, channel_dim=-1) +for i in range(5): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) n = nib.Nifti1Image(im, np.eye(4)) nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i)) @@ -52,44 +55,62 @@ images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) val_files = [{'img': img, 'seg': seg} for img, seg in zip(images, segs)] + +# Define transforms for 
image and segmentation val_transforms = transforms.Compose([ LoadNiftid(keys=['img', 'seg']), - AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1) + AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1), + Rescaled(keys=['img', 'seg']) ]) val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) device = torch.device("cuda:0") -roi_size = (64, 64, 64) -sw_batch_size = 4 net = UNet( dimensions=3, in_channels=1, - num_classes=1, + out_channels=1, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, ) net.to(device) +# define sliding window size and batch size for window inference +roi_size = (96, 96, 96) +sw_batch_size = 4 + def _sliding_window_processor(engine, batch): net.eval() with torch.no_grad(): - seg_probs = sliding_window_inference(batch['img'], roi_size, sw_batch_size, lambda x: net(x)[0], device) - return predict_segmentation(seg_probs) + seg_probs = sliding_window_inference(batch['img'], roi_size, sw_batch_size, net, device) + return seg_probs, batch['seg'].to(device) + +evaluator = Engine(_sliding_window_processor) -infer_engine = Engine(_sliding_window_processor) +# add evaluation metric to the evaluator engine +MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice') + +# StatsHandler prints loss at every iteration and prints metrics at every epoch, +# we don't need to print loss for the evaluator, so it just prints metrics; users can also customize the print functions +val_stats_handler = StatsHandler( + name='evaluator', + output_transform=lambda x: None # no need to print loss value, so disable per iteration output +) +val_stats_handler.attach(evaluator) -# for the arrary data format, assume the 3rd item of batch data is the meta_data -SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg', +# convert the necessary metadata from batch data +SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator', batch_transform=lambda batch: {'filename_or_obj': batch['img.filename_or_obj'], 'original_affine': batch['img.original_affine'], 'affine': batch['img.affine'], - }).attach(infer_engine) -# the model was trained by "unet_segmentation_3d_array" exmple -CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(infer_engine) + }, + output_transform=lambda output: predict_segmentation(output[0])).attach(evaluator) +# the model was trained by "unet_training_dict" example +CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net}).attach(evaluator) +# sliding window inference needs to input 1 image in every iteration val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()) -state = infer_engine.run(val_loader) +state = evaluator.run(val_loader) diff --git a/examples/unet_segmentation_3d_array.py b/examples/segmentation_3d/unet_training_array.py similarity index 57% rename from examples/unet_segmentation_3d_array.py rename to examples/segmentation_3d/unet_training_array.py index 3b0d880f10..c9cb70875b 100644 --- a/examples/unet_segmentation_3d_array.py +++ b/examples/segmentation_3d/unet_training_array.py @@ -14,7 +14,6 @@ import tempfile from glob import glob import logging - import nibabel as nib import numpy as np import torch @@ -23,26 +22,28 @@ from torch.utils.data import DataLoader # assumes the framework is found here, change as necessary -sys.path.append("..") +sys.path.append("../..") import monai import monai.transforms.compose as transforms 
from monai.data.nifti_reader import NiftiDataset -from monai.transforms import AddChannel, Rescale, ToTensor, UniformRandomPatch +from monai.transforms import AddChannel, Rescale, UniformRandomPatch, Resize from monai.handlers.stats_handler import StatsHandler from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler, TensorBoardImageHandler from monai.handlers.mean_dice import MeanDice from monai.data.synthetic import create_test_image_3d from monai.handlers.utils import stopping_fn_from_metric +from monai.networks.utils import predict_segmentation monai.config.print_config() +logging.basicConfig(stream=sys.stdout, level=logging.INFO) -# Create a temporary directory and 50 random image, mask paris +# Create a temporary directory and 40 random image, mask pairs tempdir = tempfile.mkdtemp() print('generating synthetic data to {} (this may take a while)'.format(tempdir)) -for i in range(50): - im, seg = create_test_image_3d(128, 128, 128) +for i in range(40): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1) n = nib.Nifti1Image(im, np.eye(4)) nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i)) @@ -54,124 +55,119 @@ segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) # Define transforms for image and segmentation -imtrans = transforms.Compose([ +train_imtrans = transforms.Compose([ + Rescale(), + AddChannel(), + UniformRandomPatch((96, 96, 96)) +]) +train_segtrans = transforms.Compose([ + AddChannel(), + UniformRandomPatch((96, 96, 96)) +]) +val_imtrans = transforms.Compose([ Rescale(), AddChannel(), - UniformRandomPatch((96, 96, 96)), - ToTensor() + Resize((96, 96, 96)) ]) -segtrans = transforms.Compose([ +val_segtrans = transforms.Compose([ AddChannel(), - UniformRandomPatch((96, 96, 96)), - ToTensor() + Resize((96, 96, 96)) ]) -# Define nifti dataset, dataloader. -ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans) -loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) -im, seg = monai.utils.misc.first(loader) +# Define nifti dataset, dataloader +check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans) +check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) +im, seg = monai.utils.misc.first(check_loader) print(im.shape, seg.shape) -lr = 1e-5 +# create a training data loader +train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans) +train_loader = DataLoader(train_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available()) +# create a validation data loader +val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans) +val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available()) -# Create UNet, DiceLoss and Adam optimizer. + +# Create UNet, DiceLoss and Adam optimizer net = monai.networks.nets.UNet( dimensions=3, in_channels=1, - num_classes=1, + out_channels=1, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, ) - loss = monai.losses.DiceLoss(do_sigmoid=True) +lr = 1e-3 opt = torch.optim.Adam(net.parameters(), lr) +device = torch.device("cuda:0") -# Since network outputs logits and segmentation, we need a custom function. 
-def _loss_fn(i, j): - return loss(i[0], j) - - -# Create trainer -device = torch.device("cpu:0") -trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, - output_transform=lambda x, y, y_pred, loss: [y_pred[1], loss.item(), y]) +# ignite trainer expects batch=(img, seg) and returns output=loss at every iteration, +# user can add output_transform to return other values, like: y_pred, y, etc. +trainer = create_supervised_trainer(net, opt, loss, device, False) # adding checkpoint handler to save models (network params and optimizer stats) during training checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -logging.basicConfig(stream=sys.stdout, level=logging.INFO) -# print training loss to commandline -train_stats_handler = StatsHandler(output_transform=lambda x: x[1]) +# StatsHandler prints loss at every iteration and print metrics at every epoch, +# we don't set metrics for trainer here, so just print loss, user can also customize print functions +# and can use output_transform to convert engine.state.output if it's not a loss value +train_stats_handler = StatsHandler(name='trainer') train_stats_handler.attach(trainer) -# record training loss to TensorBoard at every iteration -train_tensorboard_stats_handler = TensorBoardStatsHandler( - output_transform=lambda x: {'training_dice_loss': x[1]}, # plot under tag name taining_dice_loss - global_epoch_transform=lambda x: trainer.state.epoch) +# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler +train_tensorboard_stats_handler = TensorBoardStatsHandler() train_tensorboard_stats_handler.attach(trainer) -@trainer.on(Events.EPOCH_COMPLETED) -def log_training_loss(engine): - engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1]) - - -# Set parameters for validation validation_every_n_epochs = 1 +# Set parameters for validation metric_name = 'Mean_Dice' - # add evaluation metric to the evaluator engine -val_metrics = {metric_name: MeanDice( - add_sigmoid=True, to_onehot_y=False, output_transform=lambda output: (output[0][0], output[1])) -} +val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)} + +# ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration, +# user can add output_transform to return other values evaluator = create_supervised_evaluator(net, val_metrics, device, True) + +@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) +def run_validation(engine): + evaluator.run(val_loader) + + +# Add early stopping handler to evaluator +early_stopper = EarlyStopping(patience=4, + score_function=stopping_fn_from_metric(metric_name), + trainer=trainer) +evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) + # Add stats event handler to print validation stats via evaluator val_stats_handler = StatsHandler( - output_transform=lambda x: None, # disable per iteration output - global_epoch_transform=lambda x: trainer.state.epoch) + name='evaluator', + output_transform=lambda x: None, # no need to print loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer val_stats_handler.attach(evaluator) -# add handler to record metrics to TensorBoard at every epoch +# add handler to record metrics to TensorBoard at every 
validation epoch val_tensorboard_stats_handler = TensorBoardStatsHandler( - output_transform=lambda x: None, # no iteration plot - global_epoch_transform=lambda x: trainer.state.epoch) # use epoch number from trainer + output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer val_tensorboard_stats_handler.attach(evaluator) -# add handler to draw several images and the corresponding labels and model outputs -# here we draw the first 3 images(draw the first channel) as GIF format along Depth axis + +# add handler to draw the first image and the corresponding label and model output in the last batch +# here we draw the 3D output as GIF format along Depth axis, at every validation epoch val_tensorboard_image_handler = TensorBoardImageHandler( batch_transform=lambda batch: (batch[0], batch[1]), - output_transform=lambda output: output[0][1], + output_transform=lambda output: predict_segmentation(output[0]), global_iter_transform=lambda x: trainer.state.epoch ) evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler) -# Add early stopping handler to evaluator -early_stopper = EarlyStopping(patience=4, - score_function=stopping_fn_from_metric(metric_name), - trainer=trainer) -evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) - -# create a validation data loader -val_ds = NiftiDataset(images[-20:], segs[-20:], transform=imtrans, seg_transform=segtrans) -val_loader = DataLoader(ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available()) - - -@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) -def run_validation(engine): - evaluator.run(val_loader) - - -# create a training data loader -logging.basicConfig(stream=sys.stdout, level=logging.INFO) - -train_ds = NiftiDataset(images[:20], segs[:20], transform=imtrans, seg_transform=segtrans) -train_loader = DataLoader(train_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available()) train_epochs = 30 state = trainer.run(train_loader, train_epochs) diff --git a/examples/unet_segmentation_3d_dict.py b/examples/segmentation_3d/unet_training_dict.py similarity index 50% rename from examples/unet_segmentation_3d_dict.py rename to examples/segmentation_3d/unet_training_dict.py index 0e1e78811b..018f7076a7 100644 --- a/examples/unet_segmentation_3d_dict.py +++ b/examples/segmentation_3d/unet_training_dict.py @@ -14,36 +14,36 @@ import tempfile from glob import glob import logging - import nibabel as nib import numpy as np import torch -from torch.utils.tensorboard import SummaryWriter from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, _prepare_batch from ignite.handlers import ModelCheckpoint, EarlyStopping from torch.utils.data import DataLoader # assumes the framework is found here, change as necessary -sys.path.append("..") +sys.path.append("../..") import monai import monai.transforms.compose as transforms from monai.transforms.composables import \ - LoadNiftid, AsChannelFirstd, RandCropByPosNegLabeld, RandRotate90d + LoadNiftid, AsChannelFirstd, Rescaled, RandCropByPosNegLabeld, RandRotate90d from monai.handlers.stats_handler import StatsHandler +from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler, TensorBoardImageHandler from monai.handlers.mean_dice import MeanDice -from monai.visualize import img2tensorboard from monai.data.synthetic import 
create_test_image_3d from monai.handlers.utils import stopping_fn_from_metric from monai.data.utils import list_data_collate +from monai.networks.utils import predict_segmentation monai.config.print_config() +logging.basicConfig(stream=sys.stdout, level=logging.INFO) -# Create a temporary directory and 50 random image, mask paris +# Create a temporary directory and 40 random image, mask pairs tempdir = tempfile.mkdtemp() print('generating synthetic data to {} (this may take a while)'.format(tempdir)) -for i in range(50): - im, seg = create_test_image_3d(128, 128, 128, channel_dim=-1) +for i in range(40): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) n = nib.Nifti1Image(im, np.eye(4)) nib.save(n, os.path.join(tempdir, 'img%i.nii.gz' % i)) @@ -53,145 +53,128 @@ images = sorted(glob(os.path.join(tempdir, 'img*.nii.gz'))) segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) -train_files = [{'img': img, 'seg': seg} for img, seg in zip(images[:40], segs[:40])] -val_files = [{'img': img, 'seg': seg} for img, seg in zip(images[-10:], segs[-10:])] +train_files = [{'img': img, 'seg': seg} for img, seg in zip(images[:20], segs[:20])] +val_files = [{'img': img, 'seg': seg} for img, seg in zip(images[-20:], segs[-20:])] # Define transforms for image and segmentation train_transforms = transforms.Compose([ LoadNiftid(keys=['img', 'seg']), AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1), + Rescaled(keys=['img', 'seg']), RandCropByPosNegLabeld(keys=['img', 'seg'], label_key='seg', size=[96, 96, 96], pos=1, neg=1, num_samples=4), - RandRotate90d(keys=['img', 'seg'], prob=0.8, axes=[1, 3]) + RandRotate90d(keys=['img', 'seg'], prob=0.8, spatial_axes=[0, 2]) ]) val_transforms = transforms.Compose([ LoadNiftid(keys=['img', 'seg']), - AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1) + AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1), + Rescaled(keys=['img', 'seg']) ]) -# Define nifti dataset, dataloader. -ds = monai.data.Dataset(data=train_files, transform=train_transforms) -loader = DataLoader(ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, - pin_memory=torch.cuda.is_available()) -check_data = monai.utils.misc.first(loader) +# Define dataset, dataloader +check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) +# use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training +check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, + pin_memory=torch.cuda.is_available()) +check_data = monai.utils.misc.first(check_loader) print(check_data['img'].shape, check_data['seg'].shape) -lr = 1e-5 -# Create UNet, DiceLoss and Adam optimizer. 
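+# (illustrative) RandCropByPosNegLabeld with num_samples=4 returns a list of 4 patches
+# per loaded sample and list_data_collate flattens those lists into the batch, so the
+# shapes printed above should have a leading dimension of 2 x 4 = 8, roughly:
+#   assert check_data['img'].shape == (8, 1, 96, 96, 96)  # assumes the settings above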
+# create a training data loader +train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) +# use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training +train_loader = DataLoader(train_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, + pin_memory=torch.cuda.is_available()) +# create a validation data loader +val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) +val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, collate_fn=list_data_collate, + pin_memory=torch.cuda.is_available()) + + +# Create UNet, DiceLoss and Adam optimizer net = monai.networks.nets.UNet( dimensions=3, in_channels=1, - num_classes=1, + out_channels=1, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, ) - loss = monai.losses.DiceLoss(do_sigmoid=True) +lr = 1e-3 opt = torch.optim.Adam(net.parameters(), lr) +device = torch.device("cuda:0") -# Since network outputs logits and segmentation, we need a custom function. -def _loss_fn(i, j): - return loss(i[0], j) - - -# Create trainer +# ignite trainer expects batch=(img, seg) and returns output=loss at every iteration, +# user can add output_transform to return other values, like: y_pred, y, etc. def prepare_batch(batch, device=None, non_blocking=False): return _prepare_batch((batch['img'], batch['seg']), device, non_blocking) -device = torch.device("cuda:0") -trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, - prepare_batch=prepare_batch, - output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y]) +trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch) # adding checkpoint handler to save models (network params and optimizer stats) during training checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -train_stats_handler = StatsHandler(output_transform=lambda x: x[1]) + +# StatsHandler prints loss at every iteration and print metrics at every epoch, +# we don't set metrics for trainer here, so just print loss, user can also customize print functions +# and can use output_transform to convert engine.state.output if it's not loss value +train_stats_handler = StatsHandler(name='trainer') train_stats_handler.attach(trainer) +# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler +train_tensorboard_stats_handler = TensorBoardStatsHandler() +train_tensorboard_stats_handler.attach(trainer) -@trainer.on(Events.EPOCH_COMPLETED) -def log_training_loss(engine): - # log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform - writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch) - - # tensor of ones to use where for converting labels to zero and ones - ones = torch.ones(engine.state.batch['seg'][0].shape, dtype=torch.int32) - first_output_tensor = engine.state.output[0][1][0].detach().cpu() - # log model output to tensorboard, as three dimensional tensor with no channels dimension - img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64, - 255, engine.state.epoch) - # get label tensor and convert to single class - first_label_tensor = torch.where(engine.state.batch['seg'][0] > 0, ones, engine.state.batch['seg'][0]) - # log label tensor to tensorboard, there is a 
channel dimension when getting label from batch - img2tensorboard.add_animated_gif(writer, "first_label_final_batch", first_label_tensor, 64, - 255, engine.state.epoch) - second_output_tensor = engine.state.output[0][1][1].detach().cpu() - img2tensorboard.add_animated_gif_no_channels(writer, "second_output_final_batch", second_output_tensor, 64, - 255, engine.state.epoch) - second_label_tensor = torch.where(engine.state.batch['seg'][1] > 0, ones, engine.state.batch['seg'][1]) - img2tensorboard.add_animated_gif(writer, "second_label_final_batch", second_label_tensor, 64, - 255, engine.state.epoch) - third_output_tensor = engine.state.output[0][1][2].detach().cpu() - img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64, - 255, engine.state.epoch) - third_label_tensor = torch.where(engine.state.batch['seg'][2] > 0, ones, engine.state.batch['seg'][2]) - img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64, - 255, engine.state.epoch) - engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1]) - - -writer = SummaryWriter() +validation_every_n_iters = 5 # Set parameters for validation -validation_every_n_epochs = 1 metric_name = 'Mean_Dice' - # add evaluation metric to the evaluator engine val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)} -evaluator = create_supervised_evaluator(net, val_metrics, device, True, - prepare_batch=prepare_batch, - output_transform=lambda x, y, y_pred: (y_pred[0], y)) - -# Add stats event handler to print validation stats via evaluator -logging.basicConfig(stream=sys.stdout, level=logging.INFO) -val_stats_handler = StatsHandler(output_transform=lambda x: None) -val_stats_handler.attach(evaluator) -# Add early stopping handler to evaluator. 
-early_stopper = EarlyStopping(patience=4, - score_function=stopping_fn_from_metric(metric_name), - trainer=trainer) -evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) +# ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration, +# user can add output_transform to return other values +evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) -# create a validation data loader -val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) -val_loader = DataLoader(ds, batch_size=5, num_workers=8, collate_fn=list_data_collate, - pin_memory=torch.cuda.is_available()) - -@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) +@trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters)) def run_validation(engine): evaluator.run(val_loader) -@evaluator.on(Events.EPOCH_COMPLETED) -def log_metrics_to_tensorboard(engine): - for _, value in engine.state.metrics.items(): - writer.add_scalar('Metrics/' + metric_name, value, trainer.state.epoch) +# Add early stopping handler to evaluator +early_stopper = EarlyStopping(patience=4, + score_function=stopping_fn_from_metric(metric_name), + trainer=trainer) +evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) +# Add stats event handler to print validation stats via evaluator +val_stats_handler = StatsHandler( + name='evaluator', + output_transform=lambda x: None, # no need to print loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer +val_stats_handler.attach(evaluator) -# create a training data loader -logging.basicConfig(stream=sys.stdout, level=logging.INFO) +# add handler to record metrics to TensorBoard at every validation epoch +val_tensorboard_stats_handler = TensorBoardStatsHandler( + output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.iteration) # fetch global iteration number from trainer +val_tensorboard_stats_handler.attach(evaluator) + +# add handler to draw the first image and the corresponding label and model output in the last batch +# here we draw the 3D output as GIF format along the depth axis, every 2 validation iterations. 
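+# output[0] below is y_pred from the evaluator; predict_segmentation turns those logits
+# into a discrete mask, which for this single-channel network is roughly equivalent to
+# thresholding at zero (a sketch of the behaviour, not the exact implementation):
+#   mask = (output[0] >= 0.0).int()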
+val_tensorboard_image_handler = TensorBoardImageHandler( + batch_transform=lambda batch: (batch['img'], batch['seg']), + output_transform=lambda output: predict_segmentation(output[0]), + global_iter_transform=lambda x: trainer.state.epoch +) +evaluator.add_event_handler( + event_name=Events.ITERATION_COMPLETED(every=2), handler=val_tensorboard_image_handler) -train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) -train_loader = DataLoader(train_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, - pin_memory=torch.cuda.is_available()) -train_epochs = 30 +train_epochs = 5 state = trainer.run(train_loader, train_epochs) diff --git a/examples/unet_segmentation_3d.ipynb b/examples/unet_segmentation_3d.ipynb index 0d49742f10..f42c33a2a5 100644 --- a/examples/unet_segmentation_3d.ipynb +++ b/examples/unet_segmentation_3d.ipynb @@ -145,7 +145,7 @@ "net = monai.networks.nets.UNet(\n", " dimensions=3,\n", " in_channels=1,\n", - " num_classes=1,\n", + " out_channels=1,\n", " channels=(16, 32, 64, 128, 256),\n", " strides=(2, 2, 2, 2),\n", " num_res_units=2,\n", diff --git a/monai/data/synthetic.py b/monai/data/synthetic.py index 4efd4fe393..063c16a965 100644 --- a/monai/data/synthetic.py +++ b/monai/data/synthetic.py @@ -20,7 +20,7 @@ def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0, `radMax'. The mask will have `numSegClasses' number of classes for segmentations labeled sequentially from 1, plus a background class represented as 0. If `noiseMax' is greater than 0 then noise will be added to the image taken from the uniform distribution on range [0,noiseMax). If `channel_dim' is None, will create an image without channel - dimemsion, otherwise create an image with channel dimension as first dim or last dim. + dimension, otherwise create an image with channel dimension as first dim or last dim. """ image = np.zeros((width, height)) @@ -44,7 +44,7 @@ def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0, if channel_dim is not None: assert isinstance(channel_dim, int) and channel_dim in (-1, 0, 2), 'invalid channel dim.' -        noisyimage, labels = noisyimage[None], labels[None] \ -            if channel_dim == 0 else noisyimage[..., None], labels[..., None] +        noisyimage, labels = (noisyimage[None], labels[None]) \ +            if channel_dim == 0 else (noisyimage[..., None], labels[..., None]) return noisyimage, labels @@ -54,7 +54,8 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, """ Return a noisy 3D image and segmentation. - See also: create_test_image_2d + See also: + ``create_test_image_2d`` """ image = np.zeros((width, height, depth)) diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py new file mode 100644 index 0000000000..501dce816f --- /dev/null +++ b/monai/handlers/classification_saver.py @@ -0,0 +1,95 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
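+# Example usage (illustrative, mirroring the evaluation scripts in this change):
+#   saver = ClassificationSaver(output_dir='tempdir',
+#                               batch_transform=lambda batch: batch[2],
+#                               output_transform=lambda output: output[0].argmax(1))
+#   saver.attach(evaluator)  # results are written to tempdir/predictions.csv on COMPLETED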
+ +import os +import csv +import numpy as np +import torch +from ignite.engine import Events +import logging + + +class ClassificationSaver: + """ + Event handler triggered on completing every iteration to save the classification predictions as a CSV file. + """ + + def __init__(self, output_dir='./', overwrite=True, + batch_transform=lambda x: x, output_transform=lambda x: x, name=None): + """ + Args: + output_dir (str): output CSV file directory. + overwrite (bool): whether to overwrite the existing CSV file content. If not overwriting, + then any previously saved results are first loaded into the prediction_dict. + batch_transform (Callable): a callable that is used to transform the + ignite.engine.batch into the expected format to extract the meta_data dictionary. + output_transform (Callable): a callable that is used to transform the + ignite.engine.output into the expected model prediction data. + The first dimension of this transform's output will be treated as the + batch dimension. Each item in the batch will be saved individually. + name (str): identifier of logging.logger to use, defaulting to `engine.logger`. + + """ + self.output_dir = output_dir + self._prediction_dict = {} + self._preds_filepath = os.path.join(output_dir, 'predictions.csv') + self.overwrite = overwrite + self.batch_transform = batch_transform + self.output_transform = output_transform + + self.logger = None if name is None else logging.getLogger(name) + + def attach(self, engine): + if self.logger is None: + self.logger = engine.logger + if not engine.has_event_handler(self, Events.ITERATION_COMPLETED): + engine.add_event_handler(Events.ITERATION_COMPLETED, self) + if not engine.has_event_handler(self.finalize, Events.COMPLETED): + engine.add_event_handler(Events.COMPLETED, self.finalize) + + def finalize(self, _engine=None): + """ + Writes the accumulated prediction dictionary to a CSV file. + + """ + if not self.overwrite and os.path.exists(self._preds_filepath): + with open(self._preds_filepath, 'r') as f: + reader = csv.reader(f) + for row in reader: + self._prediction_dict[row[0]] = np.array(row[1:]).astype(np.float32) + + if not os.path.exists(self.output_dir): + os.makedirs(self.output_dir) + with open(self._preds_filepath, 'w') as f: + for k, v in sorted(self._prediction_dict.items()): + f.write(k) + for result in v.flatten(): + f.write("," + str(result)) + f.write("\n") + self.logger.info('saved classification predictions into: {}'.format(self._preds_filepath)) + + def __call__(self, engine): + """ + This method assumes self.batch_transform will extract metadata from the input batch. + The metadata should have the following keys: + + - ``'filename_or_obj'`` -- save the prediction under the corresponding file name.
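Note (the docstring closes below): for concreteness, finalize() above serializes the accumulated dictionary as one CSV row per input file; a pure-Python illustration of the resulting format:

    import numpy as np

    prediction_dict = {'img1.nii.gz': np.array([0.1, 0.9], dtype=np.float32)}
    for k, v in sorted(prediction_dict.items()):
        print(k + ''.join(',' + str(result) for result in v.flatten()))
    # -> img1.nii.gz,0.1,0.9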
+ + """ + meta_data = self.batch_transform(engine.state.batch) + filenames = meta_data['filename_or_obj'] + + engine_output = self.output_transform(engine.state.output) + for batch_id, filename in enumerate(filenames): # save a batch of files + output = engine_output[batch_id] + if isinstance(output, torch.Tensor): + output = output.detach().cpu().numpy() + self._prediction_dict[filename] = output.astype(np.float32) diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index 1f3fe2615d..98eb972d3a 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -13,7 +13,7 @@ import numpy as np import torch from ignite.engine import Events - +import logging from monai.data.nifti_writer import write_nifti @@ -50,7 +50,8 @@ def attach(self, engine): if self.logger is None: self.logger = engine.logger - return engine.add_event_handler(Events.ITERATION_COMPLETED, self) + if not engine.has_event_handler(self, Events.ITERATION_COMPLETED): + engine.add_event_handler(Events.ITERATION_COMPLETED, self) @staticmethod def _create_file_basename(postfix, input_file_name, folder_path, data_root_dir=""): @@ -115,5 +116,6 @@ def __call__(self, engine): output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path) output_filename = '{}{}'.format(output_filename, self.output_ext) # change output to "channel last" format and write to nifti format file - write_nifti(np.moveaxis(seg_output, 0, -1), affine_, output_filename, original_affine_, dtype=seg_output.dtype) + to_save = np.moveaxis(seg_output, 0, -1) + write_nifti(to_save, affine_, output_filename, original_affine_, dtype=seg_output.dtype) self.logger.info('saved: {}'.format(output_filename)) diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index 2d5116bd59..411a53084b 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -18,6 +18,8 @@ from monai.utils.misc import is_scalar from monai.transforms.utils import rescale_array +DEFAULT_TAG = 'Loss' + class TensorBoardStatsHandler(object): """TensorBoardStatsHandler defines a set of Ignite Event-handlers for all the TensorBoard logics. @@ -36,8 +38,9 @@ def __init__(self, summary_writer=None, epoch_event_writer=None, iteration_event_writer=None, - output_transform=lambda x: {'Loss': x}, - global_epoch_transform=lambda x: x): + output_transform=lambda x: x, + global_epoch_transform=lambda x: x, + tag_name=DEFAULT_TAG): """ Args: summary_writer (SummaryWriter): user can specify TensorBoard SummaryWriter, @@ -47,17 +50,20 @@ def __init__(self, iteration_event_writer (Callable): customized callable TensorBoard writer for iteration level. must accept parameter "engine" and "summary_writer", use default event writer if None. output_transform (Callable): a callable that is used to transform the - ``ignite.engine.output`` into a dictionary of (tag_name: scalar) pairs to be plotted onto tensorboard. - by default this scalar plotting happens when every iteration completed. + ``ignite.engine.output`` into a scalar to plot, or a dictionary of {key: scalar}. + in the latter case, the output string will be formatted as key: value. + by default this value is plotted when every iteration completes. global_epoch_transform (Callable): a callable that is used to customize global epoch number.
For example, in evaluation, the evaluator engine might want to use the trainer engine's epoch number when plotting epoch vs metric curves. + tag_name (string): when the iteration output is a scalar, tag_name is used as its plot tag, defaults to ``'Loss'``. """ self._writer = SummaryWriter() if summary_writer is None else summary_writer self.epoch_event_writer = epoch_event_writer self.iteration_event_writer = iteration_event_writer self.output_transform = output_transform self.global_epoch_transform = global_epoch_transform + self.tag_name = tag_name def attach(self, engine: Engine): """Register a set of Ignite Event-Handlers to a specified Ignite engine. @@ -121,31 +127,34@@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter): writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler. """ - loss_dict = self.output_transform(engine.state.output) - if loss_dict is None: + loss = self.output_transform(engine.state.output) + if loss is None: return # do nothing if output is empty - if not isinstance(loss_dict, dict): - raise ValueError('TensorBoardStatsHandler requires' - ' output_transform(engine.state.output) returning a dictionary' - ' of key and scalar pairs to plot' - ' got {}.'.format(type(loss_dict))) - for name, value in loss_dict.items(): - if not is_scalar(value): - warnings.warn('ignoring non-scalar output in tensorboard curve plotting,' - ' make sure `output_transform(engine.state.output)` returns' - ' a dictionary of key and scalar pairs to avoid this warning.' - ' Got {}:{}'.format(name, type(value))) - continue - plot_value = value.item() if torch.is_tensor(value) else value - writer.add_scalar(name, plot_value, engine.state.iteration) + if isinstance(loss, dict): + for name in sorted(loss): + value = loss[name] + if not is_scalar(value): + warnings.warn('ignoring non-scalar output in TensorBoardStatsHandler,' + ' make sure `output_transform(engine.state.output)` returns' + ' a scalar or dictionary of key and scalar pairs to avoid this warning.' + ' {}:{}'.format(name, type(value))) + continue # do not plot multi-dimensional output + writer.add_scalar(name, value.item() if torch.is_tensor(value) else value, engine.state.iteration) + elif is_scalar(loss): # plot a bare scalar under self.tag_name + writer.add_scalar(self.tag_name, loss.item() if torch.is_tensor(loss) else loss, engine.state.iteration) + else: + warnings.warn('ignoring non-scalar output in TensorBoardStatsHandler,' + ' make sure `output_transform(engine.state.output)` returns' + ' a scalar or a dictionary of key and scalar pairs to avoid this warning.' + ' {}'.format(type(loss))) writer.flush() class TensorBoardImageHandler(object): """TensorBoardImageHandler is an ignite Event handler that can visualise images, labels and outputs as 2D/3D images. 2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch, - for 3D to ND output (shape in Batch, channel, H, W, D) input, - the last three dimensions will be shown as GIF image along the last axis (typically Depth). + for 3D to ND output (shape in Batch, channel, H, W, D) input, the last three dimensions of up to + ``self.max_channels`` images will be shown as animated GIFs along the last axis (typically Depth). It can be used with any Ignite Engine (trainer, validator and evaluator).
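Note: with the rewrite above, output_transform may return either a bare scalar (plotted under tag_name, default 'Loss') or a {key: scalar} dictionary (one curve per key). A minimal sketch, assuming engine.state.output is already the loss value:

    from ignite.engine import Engine
    from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler

    trainer = Engine(lambda engine, batch: 0.5)  # stub: output is the loss scalar

    stats_handler = TensorBoardStatsHandler(output_transform=lambda output: output)
    stats_handler.attach(trainer)
    trainer.run(range(4), max_epochs=1)  # adds a 'Loss' point at iterations 1..4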
Users can easily add it to an engine for any expected Event, for example: ``EPOCH_COMPLETED``, @@ -230,9 +239,9 @@ def _add_2_or_3_d(self, data, step, tag='output'): return if d.ndim == 3: - if d.shape[0] == 3 and self.max_channels == 3: # rgb? + if d.shape[0] == 3 and self.max_channels == 3: # RGB dataformats = 'CHW' - self._writer.add_image('{}_{}'.format(tag, dataformats), d, step, dataformats='CHW') + self._writer.add_image('{}_{}'.format(tag, dataformats), d, step, dataformats=dataformats) return for j, d2 in enumerate(d[:self.max_channels]): d2 = rescale_array(d2, 0, 1) diff --git a/monai/networks/nets/densenet3d.py b/monai/networks/nets/densenet3d.py index f5493f6c04..78fab167c4 100644 --- a/monai/networks/nets/densenet3d.py +++ b/monai/networks/nets/densenet3d.py @@ -146,7 +146,6 @@ def __init__(self, OrderedDict([ ('relu', nn.ReLU(inplace=True)), ('norm', get_avgpooling_type(spatial_dims, is_adaptive=True)(1)), - ('relu', nn.ReLU(inplace=True)), ('flatten', nn.Flatten(1)), ('class', nn.Linear(in_channels, out_channels)), ])) diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index b0d42612eb..ad9b3ddbf4 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -13,7 +13,6 @@ from monai.networks.blocks.convolutions import Convolution, ResidualUnit from monai.networks.layers.simplelayers import SkipConnection -from monai.networks.utils import predict_segmentation from monai.utils import export from monai.utils.aliases import alias @@ -22,13 +21,13 @@ @alias("Unet", "unet") class UNet(nn.Module): - def __init__(self, dimensions, in_channels, num_classes, channels, strides, kernel_size=3, up_kernel_size=3, + def __init__(self, dimensions, in_channels, out_channels, channels, strides, kernel_size=3, up_kernel_size=3, num_res_units=0, instance_norm=True, dropout=0): super().__init__() assert len(channels) == (len(strides) + 1) self.dimensions = dimensions self.in_channels = in_channels - self.num_classes = num_classes + self.out_channels = out_channels self.channels = channels self.strides = strides self.kernel_size = kernel_size @@ -58,7 +57,7 @@ def _create_block(inc, outc, channels, strides, is_top): return nn.Sequential(down, SkipConnection(subblock), up) - self.model = _create_block(in_channels, num_classes, self.channels, self.strides, True) + self.model = _create_block(in_channels, out_channels, self.channels, self.strides, True) def _get_down_layer(self, in_channels, out_channels, strides, is_top): if self.num_res_units > 0: @@ -98,4 +97,4 @@ def _get_up_layer(self, in_channels, out_channels, strides, is_top): def forward(self, x): x = self.model(x) - return x, predict_segmentation(x) + return x diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py index c19c3f7df9..f86afd546e 100644 --- a/monai/transforms/composables.py +++ b/monai/transforms/composables.py @@ -14,8 +14,8 @@ """ import torch +import numpy as np from collections.abc import Hashable - import monai from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers.simplelayers import GaussianFilter @@ -23,7 +23,7 @@ from monai.transforms.transforms import (LoadNifti, AsChannelFirst, Orientation, AddChannel, Spacing, Rotate90, SpatialCrop, RandAffine, Rand2DElastic, Rand3DElastic, - Flip, Rotate, Zoom) + Rescale, Resize, Flip, Rotate, Zoom) from monai.utils.misc import ensure_tuple from monai.transforms.utils import generate_pos_neg_label_crop_centers, create_grid from monai.utils.aliases import alias @@ -230,17 +230,18 @@
class Rotate90d(MapTransform): dictionary-based wrapper of Rotate90. """ - def __init__(self, keys, k=1, axes=(1, 2)): + def __init__(self, keys, k=1, spatial_axes=(0, 1)): """ Args: k (int): number of times to rotate by 90 degrees. - axes (2 ints): defines the plane to rotate with 2 axes. + spatial_axes (2 ints): defines the plane to rotate with 2 spatial axes. + Default: (0, 1), these are the first two axes in spatial dimensions. """ MapTransform.__init__(self, keys) self.k = k - self.plane_axes = axes + self.spatial_axes = spatial_axes - self.rotator = Rotate90(self.k, self.plane_axes) + self.rotator = Rotate90(self.k, self.spatial_axes) def __call__(self, data): d = dict(data) @@ -249,17 +250,73 @@ def __call__(self, data): return d +@export +@alias('RescaleD', 'RescaleDict') +class Rescaled(MapTransform): + """ + dictionary-based wrapper of Rescale. + """ + + def __init__(self, keys, minv=0.0, maxv=1.0, dtype=np.float32): + MapTransform.__init__(self, keys) + self.rescaler = Rescale(minv, maxv, dtype) + + def __call__(self, data): + d = dict(data) + for key in self.keys: + d[key] = self.rescaler(d[key]) + return d + + +@export +@alias('ResizeD', 'ResizeDict') +class Resized(MapTransform): + """ + dictionary-based wrapper of Resize. + Args: + keys (hashable items): keys of the corresponding items to be transformed. + See also: monai.transforms.composables.MapTransform + output_spatial_shape (tuple or list): expected shape of spatial dimensions after resize operation. + order (int): Order of spline interpolation. Default=1. + mode (str): Points outside boundaries are filled according to given mode. + Options are 'constant', 'edge', 'symmetric', 'reflect', 'wrap'. + cval (float): Used with mode 'constant', the value outside image boundaries. + clip (bool): Whether to clip the range of output values after interpolation. Default: True. + preserve_range (bool): Whether to keep original range of values. Default is True. + If False, input is converted according to conventions of img_as_float. See + https://scikit-image.org/docs/dev/user_guide/data_types.html. + anti_aliasing (bool): Whether to apply a Gaussian filter to image before down-scaling. + Default is True. + anti_aliasing_sigma (float, tuple of floats): Standard deviation for Gaussian filtering. + """ + + def __init__(self, keys, output_spatial_shape, order=1, mode='reflect', cval=0, + clip=True, preserve_range=True, anti_aliasing=True, anti_aliasing_sigma=None): + MapTransform.__init__(self, keys) + self.resizer = Resize(output_spatial_shape, order, mode, cval, clip, preserve_range, + anti_aliasing, anti_aliasing_sigma) + + def __call__(self, data): + d = dict(data) + for key in self.keys: + d[key] = self.resizer(d[key]) + return d + + @export +@alias('UniformRandomPatchD', 'UniformRandomPatchDict') class UniformRandomPatchd(Randomizable, MapTransform): """ Selects a patch of the given size chosen at a uniformly random position in the image. + + Args: + patch_spatial_size (tuple or list): Expected patch size of spatial dimensions.
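Note (the docstring closes below): the dictionary-based wrappers above exist so one call applies the same array transform to several keys, keeping e.g. an image and its segmentation in sync. A hedged sketch using the new Rescaled and Resized wrappers (array shapes are illustrative):

    import numpy as np
    from monai.transforms.composables import Rescaled, Resized

    data = {'img': np.random.rand(1, 64, 64, 64), 'seg': np.random.rand(1, 64, 64, 64)}
    data = Rescaled(keys=['img', 'seg'])(data)  # both arrays rescaled to [0, 1]
    data = Resized(keys=['img', 'seg'], output_spatial_shape=(96, 96, 96))(data)
    print(data['img'].shape)  # (1, 96, 96, 96): channel dim kept, spatial dims resized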
""" - def __init__(self, keys, patch_size): + def __init__(self, keys, patch_spatial_size): MapTransform.__init__(self, keys) - self.patch_size = (None,) + tuple(patch_size) + self.patch_spatial_size = (None,) + tuple(patch_spatial_size) self._slices = None @@ -270,8 +327,8 @@ def __call__(self, data): d = dict(data) image_shape = d[self.keys[0]].shape # image shape from the first data key - patch_size = get_valid_patch_size(image_shape, self.patch_size) - self.randomize(image_shape, patch_size) + patch_spatial_size = get_valid_patch_size(image_shape, self.patch_spatial_size) + self.randomize(image_shape, patch_spatial_size) for key in self.keys: d[key] = d[key][self._slices] return d @@ -282,10 +339,10 @@ def __call__(self, data): class RandRotate90d(Randomizable, MapTransform): """ With probability `prob`, input arrays are rotated by 90 degrees - in the plane specified by `axes`. + in the plane specified by `spatial_axes`. """ - def __init__(self, keys, prob=0.1, max_k=3, axes=(1, 2)): + def __init__(self, keys, prob=0.1, max_k=3, spatial_axes=(0, 1)): """ Args: keys (hashable items): keys of the corresponding items to be transformed. @@ -294,14 +351,14 @@ def __init__(self, keys, prob=0.1, max_k=3, axes=(1, 2)): (Default 0.1, with 10% probability it returns a rotated array.) max_k (int): number of rotations will be sampled from `np.random.randint(max_k) + 1`. (Default 3) - axes (2 ints): defines the plane to rotate with 2 axes. - (Default to (1, 2)) + spatial_axes (2 ints): defines the plane to rotate with 2 spatial axes. + Default: (0, 1), this is the first two axis in spatial dimensions. """ MapTransform.__init__(self, keys) self.prob = min(max(prob, 0.0), 1.0) self.max_k = max_k - self.axes = axes + self.spatial_axes = spatial_axes self._do_transform = False self._rand_k = 0 @@ -315,7 +372,7 @@ def __call__(self, data): if not self._do_transform: return data - rotator = Rotate90(self._rand_k, self.axes) + rotator = Rotate90(self._rand_k, self.spatial_axes) d = dict(data) for key in self.keys: d[key] = rotator(d[key]) @@ -599,15 +656,17 @@ def __call__(self, data): @alias('FlipD', 'FlipDict') class Flipd(MapTransform): """Dictionary-based wrapper of Flip. + See numpy.flip for additional details. + https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html Args: keys (dict): Keys to pick data for transformation. - axis (None, int or tuple of ints): Axes along which to flip over. Default is None. + spatial_axis (None, int or tuple of ints): Spatial axes along which to flip over. Default is None. """ - def __init__(self, keys, axis=None): + def __init__(self, keys, spatial_axis=None): MapTransform.__init__(self, keys) - self.flipper = Flip(axis=axis) + self.flipper = Flip(spatial_axis=spatial_axis) def __call__(self, data): d = dict(data) @@ -620,19 +679,21 @@ def __call__(self, data): @alias('RandFlipD', 'RandFlipDict') class RandFlipd(Randomizable, MapTransform): """Dict-based wrapper of RandFlip. + See numpy.flip for additional details. + https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html Args: prob (float): Probability of flipping. - axis (None, int or tuple of ints): Axes along which to flip over. Default is None. + spatial_axis (None, int or tuple of ints): Spatial axes along which to flip over. Default is None. 
""" - def __init__(self, keys, prob=0.1, axis=None): + def __init__(self, keys, prob=0.1, spatial_axis=None): MapTransform.__init__(self, keys) - self.axis = axis + self.spatial_axis = spatial_axis self.prob = prob self._do_transform = False - self.flipper = Flip(axis=axis) + self.flipper = Flip(spatial_axis=spatial_axis) def randomize(self): self._do_transform = self.R.random_sample() < self.prob @@ -655,8 +716,8 @@ class Rotated(MapTransform): Args: keys (dict): Keys to pick data for transformation. angle (float): Rotation angle in degrees. - axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two - axis in spatial dimensions according to MONAI channel first shape assumption. + spatial_axes (tuple of 2 ints): Spatial axes of rotation. Default: (0, 1). + This is the first two axis in spatial dimensions. reshape (bool): If true, output shape is made same as input. Default: True. order (int): Order of spline interpolation. Range 0-5. Default: 1. This is different from scipy where default interpolation is 3. @@ -666,10 +727,10 @@ class Rotated(MapTransform): prefiter (bool): Apply spline_filter before interpolation. Default: True. """ - def __init__(self, keys, angle, axes=(1, 2), reshape=True, order=1, + def __init__(self, keys, angle, spatial_axes=(0, 1), reshape=True, order=1, mode='constant', cval=0, prefilter=True): MapTransform.__init__(self, keys) - self.rotator = Rotate(angle=angle, axes=axes, reshape=reshape, + self.rotator = Rotate(angle=angle, spatial_axes=spatial_axes, reshape=reshape, order=order, mode=mode, cval=cval, prefilter=prefilter) def __call__(self, data): @@ -688,8 +749,8 @@ class RandRotated(Randomizable, MapTransform): prob (float): Probability of rotation. degrees (tuple of float or float): Range of rotation in degrees. If single number, angle is picked from (-degrees, degrees). - axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two - axis in spatial dimensions according to MONAI channel first shape assumption. + spatial_axes (tuple of 2 ints): Spatial axes of rotation. Default: (0, 1). + This is the first two axis in spatial dimensions. reshape (bool): If true, output shape is made same as input. Default: True. order (int): Order of spline interpolation. Range 0-5. Default: 1. This is different from scipy where default interpolation is 3. @@ -698,7 +759,7 @@ class RandRotated(Randomizable, MapTransform): cval (scalar): Value to fill outside boundary. Default: 0. prefiter (bool): Apply spline_filter before interpolation. Default: True. 
""" - def __init__(self, keys, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1, + def __init__(self, keys, degrees, prob=0.1, spatial_axes=(0, 1), reshape=True, order=1, mode='constant', cval=0, prefilter=True): MapTransform.__init__(self, keys) self.prob = prob @@ -708,7 +769,7 @@ def __init__(self, keys, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1, self.mode = mode self.cval = cval self.prefilter = prefilter - self.axes = axes + self.spatial_axes = spatial_axes if not hasattr(self.degrees, '__iter__'): self.degrees = (-self.degrees, self.degrees) @@ -726,10 +787,10 @@ def __call__(self, data): d = dict(data) if not self._do_transform: return d - rotator = Rotate(self.angle, self.axes, self.reshape, self.order, + rotator = Rotate(self.angle, self.spatial_axes, self.reshape, self.order, self.mode, self.cval, self.prefilter) for key in self.keys: - d[key] = self.flipper(d[key]) + d[key] = rotator(d[key]) return d @@ -772,7 +833,11 @@ class RandZoomd(Randomizable, MapTransform): keys (dict): Keys to pick data for transformation. prob (float): Probability of zooming. min_zoom (float or sequence): Min zoom factor. Can be float or sequence same size as image. + If a float, min_zoom is the same for each spatial axis. + If a sequence, min_zoom should contain one value for each spatial axis. max_zoom (float or sequence): Max zoom factor. Can be float or sequence same size as image. + If a float, max_zoom is the same for each spatial axis. + If a sequence, max_zoom should contain one value for each spatial axis. order (int): order of interpolation. Default=3. mode ('reflect', 'constant', 'nearest', 'mirror', 'wrap'): Determines how input is extended beyond boundaries. Default: 'constant'. diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index 8f140972f6..0e326c6f73 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -197,6 +197,16 @@ def __call__(self, filename): class AsChannelFirst: """ Change the channel dimension of the image to the first dimension. + + Most of the image transformations in ``monai.transforms`` + assumes the input image is in the channel-first format, which has the shape + (num_channels, spatial_dim_1[, spatial_dim_2, ...]). + + This transform could be used to convert, for example, a channel-last image array in shape + (spatial_dim_1[, spatial_dim_2, ...], num_channels) into the channel-first format, + so that the multidimensional image array can be correctly interpreted by the other + transforms. + Args: channel_dim (int): which dimension of input image is the channel, default is the last dimension. """ @@ -213,6 +223,15 @@ def __call__(self, img): class AddChannel: """ Adds a 1-length channel dimension to the input image. + + Most of the image transformations in ``monai.transforms`` + assumes the input image is in the channel-first format, which has the shape + (num_channels, spatial_dim_1[, spatial_dim_2, ...]). + + This transform could be used, for example, to convert a (spatial_dim_1[, spatial_dim_2, ...]) + spatial image into the channel-first format so that the + multidimensional image array can be correctly interpreted by the other + transforms. """ def __call__(self, img): @@ -253,8 +272,7 @@ class GaussianNoise(Randomizable): Args: mean (float or array of floats): Mean or “centre” of the distribution. - scale (float): Standard deviation (spread) of distribution. - size (int or tuple of ints): Output shape. Default: None (single value is returned). 
+ std (float): Standard deviation (spread) of distribution. """ def __init__(self, mean=0.0, std=0.1): @@ -267,19 +285,28 @@ def __call__(self, img): @export class Flip: - """Reverses the order of elements along the given axis. Preserves shape. + """Reverses the order of elements along the given spatial axis. Preserves shape. Uses ``np.flip`` in practice. See numpy.flip for additional details. https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html Args: - axis (None, int or tuple of ints): Axes along which to flip over. Default is None. + spatial_axis (None, int or tuple of ints): spatial axes along which to flip over. Default is None. """ - def __init__(self, axis=None): - self.axis = axis + def __init__(self, spatial_axis=None): + self.spatial_axis = spatial_axis def __call__(self, img): - return np.flip(img, self.axis) + """ + Args: + img (ndarray): channel first array, must have shape: (num_channels, H[, W, ..., ]), + """ + flipped = list() + for channel in img: + flipped.append( + np.flip(channel, self.spatial_axis) + ) + return np.stack(flipped) @export @@ -289,6 +316,7 @@ class Resize: For additional details, see https://scikit-image.org/docs/dev/api/skimage.transform.html#skimage.transform.resize. Args: + output_spatial_shape (tuple or list): expected shape of spatial dimensions after resize operation. order (int): Order of spline interpolation. Default=1. mode (str): Points outside boundaries are filled according to given mode. Options are 'constant', 'edge', 'symmetric', 'reflect', 'wrap'. @@ -302,10 +330,10 @@ class Resize: anti_aliasing_sigma (float, tuple of floats): Standard deviation for gaussian filtering. """ - def __init__(self, output_shape, order=1, mode='reflect', cval=0, + def __init__(self, output_spatial_shape, order=1, mode='reflect', cval=0, clip=True, preserve_range=True, anti_aliasing=True, anti_aliasing_sigma=None): assert isinstance(order, int), "order must be integer." - self.output_shape = output_shape + self.output_spatial_shape = output_spatial_shape self.order = order self.mode = mode self.cval = cval @@ -315,11 +343,20 @@ def __init__(self, output_shape, order=1, mode='reflect', cval=0, self.anti_aliasing_sigma = anti_aliasing_sigma def __call__(self, img): - return resize(img, self.output_shape, order=self.order, - mode=self.mode, cval=self.cval, - clip=self.clip, preserve_range=self.preserve_range, - anti_aliasing=self.anti_aliasing, - anti_aliasing_sigma=self.anti_aliasing_sigma) + """ + Args: + img (ndarray): channel first array, must have shape: (num_channels, H[, W, ..., ]), + """ + resized = list() + for channel in img: + resized.append( + resize(channel, self.output_spatial_shape, order=self.order, + mode=self.mode, cval=self.cval, + clip=self.clip, preserve_range=self.preserve_range, + anti_aliasing=self.anti_aliasing, + anti_aliasing_sigma=self.anti_aliasing_sigma) + ) + return np.stack(resized).astype(np.float32) @export @@ -330,8 +367,8 @@ class Rotate: Args: angle (float): Rotation angle in degrees. - axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two - axis in spatial dimensions according to MONAI channel first shape assumption. + spatial_axes (tuple of 2 ints): Spatial axes of rotation. Default: (0, 1). + This is the first two axis in spatial dimensions. reshape (bool): If true, output shape is made same as input. Default: True. order (int): Order of spline interpolation. Range 0-5. Default: 1. This is different from scipy where default interpolation is 3. 
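Note: like the reworked Flip and Resize above, Rotate now loops over the channel dimension and rotates only the spatial axes; a standalone sketch of that per-channel computation (shapes and angle are illustrative):

    import numpy as np
    import scipy.ndimage

    img = np.random.rand(2, 32, 32).astype(np.float32)  # channel-first: (C, H, W)
    rotated = np.stack([
        scipy.ndimage.rotate(channel, 30, (0, 1), reshape=False, order=1,
                             mode='constant', cval=0, prefilter=True)
        for channel in img]).astype(np.float32)
    print(rotated.shape)  # (2, 32, 32): reshape=False preserves the spatial shape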
@@ -341,19 +378,27 @@ class Rotate: prefiter (bool): Apply spline_filter before interpolation. Default: True. """ - def __init__(self, angle, axes=(1, 2), reshape=True, order=1, mode='constant', cval=0, prefilter=True): + def __init__(self, angle, spatial_axes=(0, 1), reshape=True, order=1, mode='constant', cval=0, prefilter=True): self.angle = angle self.reshape = reshape self.order = order self.mode = mode self.cval = cval self.prefilter = prefilter - self.axes = axes + self.spatial_axes = spatial_axes def __call__(self, img): - return scipy.ndimage.rotate(img, self.angle, self.axes, - reshape=self.reshape, order=self.order, mode=self.mode, cval=self.cval, - prefilter=self.prefilter) + """ + Args: + img (ndarray): channel first array, must have shape: (num_channels, H[, W, ..., ]), + """ + rotated = list() + for channel in img: + rotated.append( + scipy.ndimage.rotate(channel, self.angle, self.spatial_axes, reshape=self.reshape, + order=self.order, mode=self.mode, cval=self.cval, prefilter=self.prefilter) + ) + return np.stack(rotated).astype(np.float32) @export @@ -400,7 +445,7 @@ def __call__(self, img): Args: img (ndarray): channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - zoomed = [] + zoomed = list() if self.use_gpu: import cupy for channel in cupy.array(img): @@ -420,7 +465,7 @@ def __call__(self, img): mode=self.mode, cval=self.cval, prefilter=self.prefilter)) - zoomed = np.stack(zoomed) + zoomed = np.stack(zoomed).astype(np.float32) if not self.keep_size or np.allclose(img.shape, zoomed.shape): return zoomed @@ -452,10 +497,13 @@ def __call__(self, img): class UniformRandomPatch(Randomizable): """ Selects a patch of the given size chosen at a uniformly random position in the image. + + Args: + patch_spatial_size (tuple or list): Expected patch size of spatial dimensions. """ - def __init__(self, patch_size): - self.patch_size = (None,) + tuple(patch_size) + def __init__(self, patch_spatial_size): + self.patch_spatial_size = (None,) + tuple(patch_spatial_size) self._slices = None @@ -463,8 +511,8 @@ def randomize(self, image_shape, patch_shape): self._slices = get_random_patch(image_shape, patch_shape, self.R) def __call__(self, img): - patch_size = get_valid_patch_size(img.shape, self.patch_size) - self.randomize(img.shape, patch_size) + patch_spatial_size = get_valid_patch_size(img.shape, self.patch_spatial_size) + self.randomize(img.shape, patch_spatial_size) return img[self._slices] @@ -478,16 +526,14 @@ class IntensityNormalizer: Args: subtrahend (ndarray): the amount to subtract by (usually the mean) divisor (ndarray): the amount to divide by (usually the standard deviation) - dtype: output data format """ - def __init__(self, subtrahend=None, divisor=None, dtype=np.float32): + def __init__(self, subtrahend=None, divisor=None): if subtrahend is not None or divisor is not None: assert isinstance(subtrahend, np.ndarray) and isinstance(divisor, np.ndarray), \ 'subtrahend and divisor must be set in pair and in numpy array.' self.subtrahend = subtrahend self.divisor = divisor - self.dtype = dtype def __call__(self, img): if self.subtrahend is not None and self.divisor is not None: @@ -497,8 +543,6 @@ def __call__(self, img): img -= np.mean(img) img /= np.std(img) - if self.dtype != img.dtype: - img = img.astype(self.dtype) return img @@ -511,15 +555,13 @@ class ImageEndPadder: Args: out_size (list): the size of region of interest at the end of the operation. mode (string): a portion from numpy.lib.arraypad.pad is copied below. - dtype: output data format. 
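Note: ImageEndPadder (whose pad-width computation appears just below) pads only at the end of each dimension, leaving the original data at the start of every axis. A numpy illustration of the intended behaviour:

    import numpy as np

    out_size = [4, 6]
    img = np.ones((3, 5))
    pad_width = [(0, max(out_size[i] - img.shape[i], 0)) for i in range(len(out_size))]
    padded = np.pad(img, pad_width, mode='constant')
    print(padded.shape)  # (4, 6)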
""" - def __init__(self, out_size, mode, dtype=np.float32): - assert out_size is not None and isinstance(out_size, (list, tuple)), 'out_size must be list or tuple' + def __init__(self, out_size, mode): + assert out_size is not None and isinstance(out_size, (list, tuple)), 'out_size must be list or tuple.' self.out_size = out_size - assert isinstance(mode, str), 'mode must be str' + assert isinstance(mode, str), 'mode must be str.' self.mode = mode - self.dtype = dtype def _determine_data_pad_width(self, data_shape): return [(0, max(self.out_size[i] - data_shape[i], 0)) for i in range(len(self.out_size))] @@ -537,39 +579,49 @@ class Rotate90: Rotate an array by 90 degrees in the plane specified by `axes`. """ - def __init__(self, k=1, axes=(1, 2)): + def __init__(self, k=1, spatial_axes=(0, 1)): """ Args: k (int): number of times to rotate by 90 degrees. - axes (2 ints): defines the plane to rotate with 2 axes. + spatial_axes (2 ints): defines the plane to rotate with 2 spatial axes. + Default: (0, 1), this is the first two axis in spatial dimensions. """ self.k = k - self.plane_axes = axes + self.spatial_axes = spatial_axes def __call__(self, img): - return np.ascontiguousarray(np.rot90(img, self.k, self.plane_axes)) + """ + Args: + img (ndarray): channel first array, must have shape: (num_channels, H[, W, ..., ]), + """ + rotated = list() + for channel in img: + rotated.append( + np.rot90(channel, self.k, self.spatial_axes) + ) + return np.stack(rotated) @export class RandRotate90(Randomizable): """ With probability `prob`, input arrays are rotated by 90 degrees - in the plane specified by `axes`. + in the plane specified by `spatial_axes`. """ - def __init__(self, prob=0.1, max_k=3, axes=(1, 2)): + def __init__(self, prob=0.1, max_k=3, spatial_axes=(0, 1)): """ Args: prob (float): probability of rotating. (Default 0.1, with 10% probability it returns a rotated array) max_k (int): number of rotations will be sampled from `np.random.randint(max_k) + 1`. (Default 3) - axes (2 ints): defines the plane to rotate with 2 axes. - (Default (1, 2)) + spatial_axes (2 ints): defines the plane to rotate with 2 spatial axes. + Default: (0, 1), this is the first two axis in spatial dimensions. """ self.prob = min(max(prob, 0.0), 1.0) self.max_k = max_k - self.axes = axes + self.spatial_axes = spatial_axes self._do_transform = False self._rand_k = 0 @@ -582,7 +634,7 @@ def __call__(self, img): self.randomize() if not self._do_transform: return img - rotator = Rotate90(self._rand_k, self.axes) + rotator = Rotate90(self._rand_k, self.spatial_axes) return rotator(img) @@ -590,7 +642,7 @@ def __call__(self, img): class SpatialCrop: """General purpose cropper to produce sub-volume region of interest (ROI). It can support to crop ND spatial (channel-first) data. - Either a center and size must be provided, or alternatively if center and size + Either a spatial center and size must be provided, or alternatively if center and size are not provided, the start and end coordinates of the ROI must be provided. The sub-volume must sit the within original image. @@ -638,8 +690,8 @@ class RandRotate(Randomizable): prob (float): Probability of rotation. degrees (tuple of float or float): Range of rotation in degrees. If single number, angle is picked from (-degrees, degrees). - axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two - axis in spatial dimensions according to MONAI channel first shape assumption. + spatial_axes (tuple of 2 ints): Spatial axes of rotation. Default: (0, 1). 
+ This is the first two axis in spatial dimensions. reshape (bool): If true, output shape is made same as input. Default: True. order (int): Order of spline interpolation. Range 0-5. Default: 1. This is different from scipy where default interpolation is 3. @@ -649,7 +701,7 @@ class RandRotate(Randomizable): prefiter (bool): Apply spline_filter before interpolation. Default: True. """ - def __init__(self, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1, + def __init__(self, degrees, prob=0.1, spatial_axes=(0, 1), reshape=True, order=1, mode='constant', cval=0, prefilter=True): self.prob = prob self.degrees = degrees @@ -658,7 +710,7 @@ def __init__(self, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1, self.mode = mode self.cval = cval self.prefilter = prefilter - self.axes = axes + self.spatial_axes = spatial_axes if not hasattr(self.degrees, '__iter__'): self.degrees = (-self.degrees, self.degrees) @@ -675,7 +727,7 @@ def __call__(self, img): self.randomize() if not self._do_transform: return img - rotator = Rotate(self.angle, self.axes, self.reshape, self.order, + rotator = Rotate(self.angle, self.spatial_axes, self.reshape, self.order, self.mode, self.cval, self.prefilter) return rotator(img) @@ -688,15 +740,13 @@ class RandFlip(Randomizable): Args: prob (float): Probability of flipping. - axis (None, int or tuple of ints): Axes along which to flip over. Default is None. + spatial_axis (None, int or tuple of ints): Spatial axes along which to flip over. Default is None. """ - def __init__(self, prob=0.1, axis=None): + def __init__(self, prob=0.1, spatial_axis=None): self.prob = prob - self.flipper = Flip(axis=axis) - + self.flipper = Flip(spatial_axis=spatial_axis) self._do_transform = False - self.flipper = Flip(axis=axis) def randomize(self): self._do_transform = self.R.random_sample() < self.prob @@ -715,7 +765,11 @@ class RandZoom(Randomizable): Args: prob (float): Probability of zooming. min_zoom (float or sequence): Min zoom factor. Can be float or sequence same size as image. + If a float, min_zoom is the same for each spatial axis. + If a sequence, min_zoom should contain one value for each spatial axis. max_zoom (float or sequence): Max zoom factor. Can be float or sequence same size as image. + If a float, max_zoom is the same for each spatial axis. + If a sequence, max_zoom should contain one value for each spatial axis. order (int): order of interpolation. Default=3. mode ('reflect', 'constant', 'nearest', 'mirror', 'wrap'): Determines how input is extended beyond boundaries. Default: 'constant'. @@ -985,7 +1039,7 @@ def __init__(self, Args: rotate_params (float, list of floats): a rotation angle in radians, - a scalar for 2D image, a tuple of 2 floats for 3D. Defaults to no rotation. + a scalar for 2D image, a tuple of 3 floats for 3D. Defaults to no rotation. shear_params (list of floats): a tuple of 2 floats for 2D, a tuple of 6 floats for 3D. Defaults to no shearing. 
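Note: the float-or-sequence contract documented above for min_zoom/max_zoom implies one zoom factor per spatial axis; a hedged sketch of per-axis sampling consistent with that contract (the seed and ranges are illustrative, and the exact sampling inside RandZoom is not shown in this hunk):

    import numpy as np

    rng = np.random.RandomState(0)
    min_zoom, max_zoom = (0.8, 0.9), (1.1, 1.2)  # one (min, max) pair per spatial axis
    zoom = [rng.uniform(lo, hi) for lo, hi in zip(min_zoom, max_zoom)]
    print(zoom)  # two independent factors, one per spatial axis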
translate_params (list of floats): diff --git a/tests/integration_sliding_window.py b/tests/integration_sliding_window.py index 31c7a1248a..b99bb5c681 100644 --- a/tests/integration_sliding_window.py +++ b/tests/integration_sliding_window.py @@ -40,7 +40,7 @@ def run_test(batch_size=2, device=torch.device("cpu:0")): net = UNet( dimensions=3, in_channels=1, - num_classes=1, + out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2, @@ -52,7 +52,7 @@ def _sliding_window_processor(_engine, batch): net.eval() img, seg, meta_data = batch with torch.no_grad(): - seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x)[0], device) + seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, net, device) return predict_segmentation(seg_probs) infer_engine = Engine(_sliding_window_processor) diff --git a/tests/integration_unet2d.py b/tests/integration_unet2d.py index 7b0f116b77..d437258c10 100644 --- a/tests/integration_unet2d.py +++ b/tests/integration_unet2d.py @@ -35,7 +35,7 @@ def __len__(self): net = UNet( dimensions=2, in_channels=1, - num_classes=1, + out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2, @@ -45,10 +45,7 @@ def __len__(self): opt = torch.optim.Adam(net.parameters(), 1e-4) src = DataLoader(_TestBatch(), batch_size=batch_size) - def loss_fn(pred, grnd): - return loss(pred[0], grnd) - - trainer = create_supervised_trainer(net, opt, loss_fn, device, False) + trainer = create_supervised_trainer(net, opt, loss, device, False) trainer.run(src, 1) loss = trainer.state.output diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index c5640a5660..e937185f91 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -82,7 +82,7 @@ TEST_CASE_6 = [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) { - 'include_background': False, + 'include_background': True, 'do_sigmoid': True, }, { diff --git a/tests/test_flip.py b/tests/test_flip.py index a261c315e2..050d66d8db 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -14,7 +14,7 @@ import numpy as np from parameterized import parameterized -from monai.transforms import Flip, Flipd +from monai.transforms import Flip from tests.utils import NumpyImageTestCase2D INVALID_CASES = [("wrong_axis", ['s', 1], TypeError), @@ -22,35 +22,25 @@ VALID_CASES = [("no_axis", None), ("one_axis", 1), - ("many_axis", [0, 1, 2])] + ("many_axis", [0, 1])] -class FlipTest(NumpyImageTestCase2D): +class TestFlip(NumpyImageTestCase2D): @parameterized.expand(INVALID_CASES) - def test_invalid_inputs(self, _, axis, raises): + def test_invalid_inputs(self, _, spatial_axis, raises): with self.assertRaises(raises): - flip = Flip(axis) - flip(self.imt) - - @parameterized.expand(INVALID_CASES) - def test_invalid_cases_dict(self, _, axis, raises): - with self.assertRaises(raises): - flip = Flipd(keys='img', axis=axis) - flip({'img': self.imt}) - - @parameterized.expand(VALID_CASES) - def test_correct_results(self, _, axis): - flip = Flip(axis=axis) - expected = np.flip(self.imt, axis) - self.assertTrue(np.allclose(expected, flip(self.imt))) + flip = Flip(spatial_axis) + flip(self.imt[0]) @parameterized.expand(VALID_CASES) - def test_correct_results_dict(self, _, axis): - flip = Flipd(keys='img', axis=axis) - expected = np.flip(self.imt, axis) - res = flip({'img': self.imt}) - assert np.allclose(expected, res['img']) + def test_correct_results(self, _, spatial_axis): + flip = Flip(spatial_axis=spatial_axis) + expected = list() + for channel in self.imt[0]: + 
expected.append(np.flip(channel, spatial_axis)) + expected = np.stack(expected) + self.assertTrue(np.allclose(expected, flip(self.imt[0]))) if __name__ == '__main__': diff --git a/tests/test_flipd.py b/tests/test_flipd.py new file mode 100644 index 0000000000..e2fcb6b915 --- /dev/null +++ b/tests/test_flipd.py @@ -0,0 +1,48 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import Flipd +from tests.utils import NumpyImageTestCase2D + +INVALID_CASES = [("wrong_axis", ['s', 1], TypeError), + ("not_numbers", 's', TypeError)] + +VALID_CASES = [("no_axis", None), + ("one_axis", 1), + ("many_axis", [0, 1])] + + +class TestFlipd(NumpyImageTestCase2D): + + @parameterized.expand(INVALID_CASES) + def test_invalid_cases(self, _, spatial_axis, raises): + with self.assertRaises(raises): + flip = Flipd(keys='img', spatial_axis=spatial_axis) + flip({'img': self.imt[0]}) + + @parameterized.expand(VALID_CASES) + def test_correct_results(self, _, spatial_axis): + flip = Flipd(keys='img', spatial_axis=spatial_axis) + expected = list() + for channel in self.imt[0]: + expected.append(np.flip(channel, spatial_axis)) + expected = np.stack(expected) + res = flip({'img': self.imt[0]}) + assert np.allclose(expected, res['img']) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index e08ff1d296..b2ce96169e 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -94,7 +94,7 @@ TEST_CASE_6 = [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) { - 'include_background': False, + 'include_background': True, 'do_sigmoid': True, }, { diff --git a/tests/test_handler_classification_saver.py b/tests/test_handler_classification_saver.py new file mode 100644 index 0000000000..3eea9d86ed --- /dev/null +++ b/tests/test_handler_classification_saver.py @@ -0,0 +1,55 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import csv +import shutil +import unittest +import numpy as np +import torch +from ignite.engine import Engine + +from monai.handlers.classification_saver import ClassificationSaver + + +class TestHandlerClassificationSaver(unittest.TestCase): + + def test_saved_content(self): + default_dir = os.path.join('.', 'tempdir') + shutil.rmtree(default_dir, ignore_errors=True) + + # set up engine + def _train_func(engine, batch): + return torch.zeros(8) + + engine = Engine(_train_func) + + # set up testing handler + saver = ClassificationSaver(output_dir=default_dir) + saver.attach(engine) + + data = [{'filename_or_obj': ['testfile' + str(i) for i in range(8)]}] + engine.run(data, epoch_length=2, max_epochs=1) + filepath = os.path.join(default_dir, 'predictions.csv') + self.assertTrue(os.path.exists(filepath)) + with open(filepath, 'r') as f: + reader = csv.reader(f) + i = 0 + for row in reader: + self.assertEqual(row[0], 'testfile' + str(i)) + self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0) + i += 1 + self.assertEqual(i, 8) + shutil.rmtree(default_dir) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_random_affine.py b/tests/test_rand_affine.py similarity index 100% rename from tests/test_random_affine.py rename to tests/test_rand_affine.py diff --git a/tests/test_random_affine_grid.py b/tests/test_rand_affine_grid.py similarity index 100% rename from tests/test_random_affine_grid.py rename to tests/test_rand_affine_grid.py diff --git a/tests/test_random_affined.py b/tests/test_rand_affined.py similarity index 100% rename from tests/test_random_affined.py rename to tests/test_rand_affined.py diff --git a/tests/test_random_deform_grid.py b/tests/test_rand_deform_grid.py similarity index 100% rename from tests/test_random_deform_grid.py rename to tests/test_rand_deform_grid.py diff --git a/tests/test_random_elastic_2d.py b/tests/test_rand_elastic_2d.py similarity index 100% rename from tests/test_random_elastic_2d.py rename to tests/test_rand_elastic_2d.py diff --git a/tests/test_random_elastic_3d.py b/tests/test_rand_elastic_3d.py similarity index 100% rename from tests/test_random_elastic_3d.py rename to tests/test_rand_elastic_3d.py diff --git a/tests/test_random_elasticd_2d.py b/tests/test_rand_elasticd_2d.py similarity index 100% rename from tests/test_random_elasticd_2d.py rename to tests/test_rand_elasticd_2d.py diff --git a/tests/test_random_elasticd_3d.py b/tests/test_rand_elasticd_3d.py similarity index 100% rename from tests/test_random_elasticd_3d.py rename to tests/test_rand_elasticd_3d.py diff --git a/tests/test_rand_flip.py b/tests/test_rand_flip.py index be03ff5a28..1206c85571 100644 --- a/tests/test_rand_flip.py +++ b/tests/test_rand_flip.py @@ -14,7 +14,7 @@ import numpy as np from parameterized import parameterized -from monai.transforms import RandFlip, RandFlipd +from monai.transforms import RandFlip from tests.utils import NumpyImageTestCase2D INVALID_CASES = [("wrong_axis", ['s', 1], TypeError), @@ -22,29 +22,24 @@ VALID_CASES = [("no_axis", None), ("one_axis", 1), - ("many_axis", [0, 1, 2])] + ("many_axis", [0, 1])] -class RandFlipTest(NumpyImageTestCase2D): +class TestRandFlip(NumpyImageTestCase2D): @parameterized.expand(INVALID_CASES) - def test_invalid_inputs(self, _, axis, raises): + def test_invalid_inputs(self, _, spatial_axis, raises): with self.assertRaises(raises): - flip = RandFlip(prob=1.0, axis=axis) - flip(self.imt) + flip = RandFlip(prob=1.0, spatial_axis=spatial_axis) + flip(self.imt[0]) 
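Note: the reworked tests now feed a single channel-first image (self.imt[0]) and build the expected result channel by channel, matching the transforms' new channel-first contract. The recurring pattern, in isolation:

    import numpy as np

    img = np.random.rand(1, 8, 8)  # one channel-first test image
    spatial_axis = 0
    expected = np.stack([np.flip(channel, spatial_axis) for channel in img])
    # flipping spatial axis 0 of every channel equals flipping array axis 1:
    assert np.array_equal(expected, np.flip(img, 1))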
@parameterized.expand(VALID_CASES) - def test_correct_results(self, _, axis): - flip = RandFlip(prob=1.0, axis=axis) - expected = np.flip(self.imt, axis) - self.assertTrue(np.allclose(expected, flip(self.imt))) - - @parameterized.expand(VALID_CASES) - def test_correct_results_dict(self, _, axis): - flip = RandFlipd(keys='img', prob=1.0, axis=axis) - res = flip({'img': self.imt}) - - expected = np.flip(self.imt, axis) - self.assertTrue(np.allclose(expected, res['img'])) + def test_correct_results(self, _, spatial_axis): + flip = RandFlip(prob=1.0, spatial_axis=spatial_axis) + expected = list() + for channel in self.imt[0]: + expected.append(np.flip(channel, spatial_axis)) + expected = np.stack(expected) + self.assertTrue(np.allclose(expected, flip(self.imt[0]))) if __name__ == '__main__': diff --git a/tests/test_rand_flipd.py b/tests/test_rand_flipd.py new file mode 100644 index 0000000000..bcda54eecd --- /dev/null +++ b/tests/test_rand_flipd.py @@ -0,0 +1,38 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandFlipd +from tests.utils import NumpyImageTestCase2D + +VALID_CASES = [("no_axis", None), + ("one_axis", 1), + ("many_axis", [0, 1])] + +class TestRandFlipd(NumpyImageTestCase2D): + + @parameterized.expand(VALID_CASES) + def test_correct_results(self, _, spatial_axis): + flip = RandFlipd(keys='img', prob=1.0, spatial_axis=spatial_axis) + res = flip({'img': self.imt[0]}) + expected = list() + for channel in self.imt[0]: + expected.append(np.flip(channel, spatial_axis)) + expected = np.stack(expected) + self.assertTrue(np.allclose(expected, res['img'])) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py index 29036663af..1e5a18bfc8 100644 --- a/tests/test_rand_rotate.py +++ b/tests/test_rand_rotate.py @@ -19,23 +19,26 @@ from tests.utils import NumpyImageTestCase2D -class RandomRotateTest(NumpyImageTestCase2D): +class TestRandRotate(NumpyImageTestCase2D): @parameterized.expand([ - (90, (1, 2), True, 1, 'reflect', 0, True), - ((-45, 45), (2, 1), True, 3, 'constant', 0, True), - (180, (2, 3), False, 2, 'constant', 4, False), + (90, (0, 1), True, 1, 'reflect', 0, True), + ((-45, 45), (1, 0), True, 3, 'constant', 0, True), + (180, (1, 0), False, 2, 'constant', 4, False), ]) - def test_correct_results(self, degrees, axes, reshape, + def test_correct_results(self, degrees, spatial_axes, reshape, order, mode, cval, prefilter): - rotate_fn = RandRotate(degrees, prob=1.0, axes=axes, reshape=reshape, + rotate_fn = RandRotate(degrees, prob=1.0, spatial_axes=spatial_axes, reshape=reshape, order=order, mode=mode, cval=cval, prefilter=prefilter) rotate_fn.set_random_state(243) - rotated = rotate_fn(self.imt) + rotated = rotate_fn(self.imt[0]) angle = rotate_fn.angle - expected = scipy.ndimage.rotate(self.imt, angle, axes, reshape, order=order, - mode=mode, cval=cval, prefilter=prefilter) + expected = 
list() + for channel in self.imt[0]: + expected.append(scipy.ndimage.rotate(channel, angle, spatial_axes, reshape, order=order, + mode=mode, cval=cval, prefilter=prefilter)) + expected = np.stack(expected).astype(np.float32) self.assertTrue(np.allclose(expected, rotated)) diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py index 4b291d8cf0..e50c3e0c67 100644 --- a/tests/test_rand_rotate90.py +++ b/tests/test_rand_rotate90.py @@ -17,34 +17,46 @@ from tests.utils import NumpyImageTestCase2D -class Rotate90Test(NumpyImageTestCase2D): +class TestRandRotate90(NumpyImageTestCase2D): def test_default(self): rotate = RandRotate90() rotate.set_random_state(123) - rotated = rotate(self.imt) - expected = np.rot90(self.imt, 0, (1, 2)) + rotated = rotate(self.imt[0]) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) self.assertTrue(np.allclose(rotated, expected)) def test_k(self): rotate = RandRotate90(max_k=2) rotate.set_random_state(234) - rotated = rotate(self.imt) - expected = np.rot90(self.imt, 0, (1, 2)) + rotated = rotate(self.imt[0]) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) self.assertTrue(np.allclose(rotated, expected)) - def test_axes(self): - rotate = RandRotate90(axes=(1, 2)) + def test_spatial_axes(self): + rotate = RandRotate90(spatial_axes=(0, 1)) rotate.set_random_state(234) - rotated = rotate(self.imt) - expected = np.rot90(self.imt, 0, (1, 2)) + rotated = rotate(self.imt[0]) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) self.assertTrue(np.allclose(rotated, expected)) - def test_prob_k_axes(self): - rotate = RandRotate90(prob=1.0, max_k=2, axes=(2, 3)) + def test_prob_k_spatial_axes(self): + rotate = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1)) rotate.set_random_state(234) - rotated = rotate(self.imt) - expected = np.rot90(self.imt, 1, (2, 3)) + rotated = rotate(self.imt[0]) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 1, (0, 1))) + expected = np.stack(expected) self.assertTrue(np.allclose(rotated, expected)) diff --git a/tests/test_rand_rotate90d.py b/tests/test_rand_rotate90d.py index c52a82389f..193627fef1 100644 --- a/tests/test_rand_rotate90d.py +++ b/tests/test_rand_rotate90d.py @@ -17,45 +17,57 @@ from tests.utils import NumpyImageTestCase2D -class Rotate90Test(NumpyImageTestCase2D): +class TestRandRotate90d(NumpyImageTestCase2D): def test_default(self): key = None rotate = RandRotate90d(keys=key) rotate.set_random_state(123) - rotated = rotate({key: self.imt}) - expected = np.rot90(self.imt, 0, (1, 2)) + rotated = rotate({key: self.imt[0]}) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) self.assertTrue(np.allclose(rotated[key], expected)) def test_k(self): key = 'test' rotate = RandRotate90d(keys=key, max_k=2) rotate.set_random_state(234) - rotated = rotate({key: self.imt}) - expected = np.rot90(self.imt, 0, (1, 2)) + rotated = rotate({key: self.imt[0]}) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) self.assertTrue(np.allclose(rotated[key], expected)) - def test_axes(self): - key = ['test'] - rotate = RandRotate90d(keys=key, axes=(1, 2)) + def test_spatial_axes(self): + key = 'test' + rotate = 
RandRotate90d(keys=key, spatial_axes=(0, 1)) rotate.set_random_state(234) - rotated = rotate({key[0]: self.imt}) - expected = np.rot90(self.imt, 0, (1, 2)) - self.assertTrue(np.allclose(rotated[key[0]], expected)) + rotated = rotate({key: self.imt[0]}) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) + self.assertTrue(np.allclose(rotated[key], expected)) - def test_prob_k_axes(self): - key = ('test',) - rotate = RandRotate90d(keys=key, prob=1.0, max_k=2, axes=(2, 3)) + def test_prob_k_spatial_axes(self): + key = 'test' + rotate = RandRotate90d(keys=key, prob=1.0, max_k=2, spatial_axes=(0, 1)) rotate.set_random_state(234) - rotated = rotate({key[0]: self.imt}) - expected = np.rot90(self.imt, 1, (2, 3)) - self.assertTrue(np.allclose(rotated[key[0]], expected)) + rotated = rotate({key: self.imt[0]}) + expected = list() + for channel in self.imt[0]: + expected.append(np.rot90(channel, 1, (0, 1))) + expected = np.stack(expected) + self.assertTrue(np.allclose(rotated[key], expected)) def test_no_key(self): key = 'unknown' - rotate = RandRotate90d(keys=key, prob=1.0, max_k=2, axes=(2, 3)) + rotate = RandRotate90d(keys=key, prob=1.0, max_k=2, spatial_axes=(0, 1)) with self.assertRaisesRegex(KeyError, ''): - rotated = rotate({'test': self.imt}) + rotated = rotate({'test': self.imt[0]}) if __name__ == '__main__': diff --git a/tests/test_rand_rotated.py b/tests/test_rand_rotated.py new file mode 100644 index 0000000000..1c9d98e83e --- /dev/null +++ b/tests/test_rand_rotated.py @@ -0,0 +1,46 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
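Note: as in the other new tests, set_random_state below fixes the transform's internal random state so the sampled parameters are reproducible; a hedged sketch of the idea using RandRotate90 from this PR (seed and shapes are illustrative):

    import numpy as np
    from monai.transforms import RandRotate90

    img = np.random.rand(1, 8, 8)
    rotate = RandRotate90(prob=1.0)
    rotate.set_random_state(234)
    first = rotate(img)
    rotate.set_random_state(234)
    assert np.allclose(first, rotate(img))  # same seed, same rotation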
diff --git a/tests/test_rand_rotated.py b/tests/test_rand_rotated.py
new file mode 100644
index 0000000000..1c9d98e83e
--- /dev/null
+++ b/tests/test_rand_rotated.py
@@ -0,0 +1,46 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+
+import scipy.ndimage
+from parameterized import parameterized
+
+from monai.transforms import RandRotated
+from tests.utils import NumpyImageTestCase2D
+
+
+class TestRandRotated(NumpyImageTestCase2D):
+
+    @parameterized.expand([
+        (90, (0, 1), True, 1, 'reflect', 0, True),
+        ((-45, 45), (1, 0), True, 3, 'constant', 0, True),
+        (180, (1, 0), False, 2, 'constant', 4, False),
+    ])
+    def test_correct_results(self, degrees, spatial_axes, reshape,
+                             order, mode, cval, prefilter):
+        rotate_fn = RandRotated('img', degrees, prob=1.0, spatial_axes=spatial_axes, reshape=reshape,
+                                order=order, mode=mode, cval=cval, prefilter=prefilter)
+        rotate_fn.set_random_state(243)
+        rotated = rotate_fn({'img': self.imt[0]})
+
+        angle = rotate_fn.angle
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(scipy.ndimage.rotate(channel, angle, spatial_axes, reshape, order=order,
+                                                 mode=mode, cval=cval, prefilter=prefilter))
+        expected = np.stack(expected).astype(np.float32)
+        self.assertTrue(np.allclose(expected, rotated['img']))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_rand_zoom.py b/tests/test_rand_zoom.py
index 530504b887..7dfdb7a522 100644
--- a/tests/test_rand_zoom.py
+++ b/tests/test_rand_zoom.py
@@ -17,12 +17,12 @@
 from scipy.ndimage import zoom as zoom_scipy
 from parameterized import parameterized

-from monai.transforms import RandZoom, RandZoomd
+from monai.transforms import RandZoom
 from tests.utils import NumpyImageTestCase2D

 VALID_CASES = [(0.9, 1.1, 3, 'constant', 0, True, False, False)]


-class ZoomTest(NumpyImageTestCase2D):
+class TestRandZoom(NumpyImageTestCase2D):

     @parameterized.expand(VALID_CASES)
     def test_correct_results(self, min_zoom, max_zoom, order, mode,
@@ -31,28 +31,14 @@ def test_correct_results(self, min_zoom, max_zoom, order, mode,
                             mode=mode, cval=cval, prefilter=prefilter, use_gpu=use_gpu,
                             keep_size=keep_size)
         random_zoom.set_random_state(234)
-
-        zoomed = random_zoom(self.imt)
-        expected = zoom_scipy(self.imt, zoom=random_zoom._zoom, mode=mode,
-                              order=order, cval=cval, prefilter=prefilter)
-
+        zoomed = random_zoom(self.imt[0])
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode=mode, order=order,
+                                       cval=cval, prefilter=prefilter))
+        expected = np.stack(expected).astype(np.float32)
         self.assertTrue(np.allclose(expected, zoomed))

-    @parameterized.expand(VALID_CASES)
-    def test_correct_results_dict(self, min_zoom, max_zoom, order, mode,
-                                  cval, prefilter, use_gpu, keep_size):
-        keys = 'img'
-        random_zoom = RandZoomd(keys, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order,
-                                mode=mode, cval=cval, prefilter=prefilter, use_gpu=use_gpu,
-                                keep_size=keep_size)
-        random_zoom.set_random_state(234)
-
-        zoomed = random_zoom({keys: self.imt})
-        expected = zoom_scipy(self.imt, zoom=random_zoom._zoom, mode=mode,
-                              order=order, cval=cval, prefilter=prefilter)
-
-        self.assertTrue(np.allclose(expected, zoomed[keys]))
-
     @parameterized.expand([
         (0.8, 1.2, 1, 'constant', 0, True)
     ])
@@ -64,17 +50,20 @@ def test_gpu_zoom(self, min_zoom, max_zoom, order, mode, cval, prefilter):
                 keep_size=False)
             random_zoom.set_random_state(234)

-            zoomed = random_zoom(self.imt)
-            expected = zoom_scipy(self.imt, zoom=random_zoom._zoom, mode=mode, order=order,
-                                  cval=cval, prefilter=prefilter)
+            zoomed = random_zoom(self.imt[0])
+            expected = list()
+            for channel in self.imt[0]:
+                expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode=mode, order=order,
+                                           cval=cval, prefilter=prefilter))
+            expected = np.stack(expected).astype(np.float32)
             self.assertTrue(np.allclose(expected, zoomed))

     def test_keep_size(self):
         random_zoom = RandZoom(prob=1.0, min_zoom=0.6,
                                max_zoom=0.7, keep_size=True)
-        zoomed = random_zoom(self.imt)
-        self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape))
+        zoomed = random_zoom(self.imt[0])
+        self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))

     @parameterized.expand([
         ("no_min_zoom", None, 1.1, 1, TypeError),
@@ -83,7 +72,7 @@ def test_keep_size(self):
     def test_invalid_inputs(self, _, min_zoom, max_zoom, order, raises):
         with self.assertRaises(raises):
             random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order)
-            zoomed = random_zoom(self.imt)
+            zoomed = random_zoom(self.imt[0])


 if __name__ == '__main__':
diff --git a/tests/test_rand_zoomd.py b/tests/test_rand_zoomd.py
new file mode 100644
index 0000000000..9a5838da4b
--- /dev/null
+++ b/tests/test_rand_zoomd.py
@@ -0,0 +1,83 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import importlib
+
+from scipy.ndimage import zoom as zoom_scipy
+from parameterized import parameterized
+
+from monai.transforms import RandZoomd
+from tests.utils import NumpyImageTestCase2D
+
+VALID_CASES = [(0.9, 1.1, 3, 'constant', 0, True, False, False)]
+
+
+class TestRandZoomd(NumpyImageTestCase2D):
+
+    @parameterized.expand(VALID_CASES)
+    def test_correct_results(self, min_zoom, max_zoom, order, mode,
+                             cval, prefilter, use_gpu, keep_size):
+        key = 'img'
+        random_zoom = RandZoomd(key, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order,
+                                mode=mode, cval=cval, prefilter=prefilter, use_gpu=use_gpu,
+                                keep_size=keep_size)
+        random_zoom.set_random_state(234)
+
+        zoomed = random_zoom({key: self.imt[0]})
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode=mode, order=order,
+                                       cval=cval, prefilter=prefilter))
+        expected = np.stack(expected).astype(np.float32)
+        self.assertTrue(np.allclose(expected, zoomed[key]))
+
+    @parameterized.expand([
+        (0.8, 1.2, 1, 'constant', 0, True)
+    ])
+    def test_gpu_zoom(self, min_zoom, max_zoom, order, mode, cval, prefilter):
+        key = 'img'
+        if importlib.util.find_spec('cupy'):
+            random_zoom = RandZoomd(
+                key, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order,
+                mode=mode, cval=cval, prefilter=prefilter, use_gpu=True,
+                keep_size=False)
+            random_zoom.set_random_state(234)
+
+            zoomed = random_zoom({key: self.imt[0]})
+            expected = list()
+            for channel in self.imt[0]:
+                expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode=mode, order=order,
+                                           cval=cval, prefilter=prefilter))
+            expected = np.stack(expected).astype(np.float32)
+            self.assertTrue(np.allclose(expected, zoomed[key]))
+
+    def test_keep_size(self):
+        key = 'img'
+        random_zoom = RandZoomd(key, prob=1.0, min_zoom=0.6,
+                                max_zoom=0.7, keep_size=True)
+        zoomed = random_zoom({key: self.imt[0]})
+        self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:]))
+
+    @parameterized.expand([
+        ("no_min_zoom", None, 1.1, 1, TypeError),
+        ("invalid_order", 0.9, 1.1, 's', AssertionError)
+    ])
+    def test_invalid_inputs(self, _, min_zoom, max_zoom, order, raises):
+        key = 'img'
+        with self.assertRaises(raises):
+            random_zoom = RandZoomd(key, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order)
+            zoomed = random_zoom({key: self.imt[0]})
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_resize.py b/tests/test_resize.py
index 7feaf9f634..30f8101baa 100644
--- a/tests/test_resize.py
+++ b/tests/test_resize.py
@@ -19,34 +19,37 @@
 from tests.utils import NumpyImageTestCase2D


-class ResizeTest(NumpyImageTestCase2D):
+class TestResize(NumpyImageTestCase2D):

     @parameterized.expand([
         ("invalid_order", "order", AssertionError)
     ])
     def test_invalid_inputs(self, _, order, raises):
         with self.assertRaises(raises):
-            resize = Resize(output_shape=(128, 128, 3), order=order)
-            resize(self.imt)
+            resize = Resize(output_spatial_shape=(128, 128, 3), order=order)
+            resize(self.imt[0])

     @parameterized.expand([
-        ((1, 1, 64, 64), 1, 'reflect', 0, True, True, True, None),
-        ((1, 1, 32, 32), 2, 'constant', 3, False, False, False, None),
-        ((1, 1, 256, 256), 3, 'constant', 3, False, False, False, None),
+        ((64, 64), 1, 'reflect', 0, True, True, True, None),
+        ((32, 32), 2, 'constant', 3, False, False, False, None),
+        ((256, 256), 3, 'constant', 3, False, False, False, None),
     ])
-    def test_correct_results(self, output_shape, order, mode,
+    def test_correct_results(self, output_spatial_shape, order, mode,
                              cval, clip, preserve_range,
                              anti_aliasing, anti_aliasing_sigma):
-        resize = Resize(output_shape, order, mode, cval, clip,
+        resize = Resize(output_spatial_shape, order, mode, cval, clip,
                         preserve_range, anti_aliasing, anti_aliasing_sigma)
-        expected = skimage.transform.resize(self.imt, output_shape,
-                                            order=order, mode=mode,
-                                            cval=cval, clip=clip,
-                                            preserve_range=preserve_range,
-                                            anti_aliasing=anti_aliasing,
-                                            anti_aliasing_sigma=anti_aliasing_sigma)
-        self.assertTrue(np.allclose(resize(self.imt), expected))
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(skimage.transform.resize(channel, output_spatial_shape,
+                                                     order=order, mode=mode,
+                                                     cval=cval, clip=clip,
+                                                     preserve_range=preserve_range,
+                                                     anti_aliasing=anti_aliasing,
+                                                     anti_aliasing_sigma=anti_aliasing_sigma))
+        expected = np.stack(expected).astype(np.float32)
+        self.assertTrue(np.allclose(resize(self.imt[0]), expected))


 if __name__ == '__main__':
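Note: the dictionary-based variants (the d-suffixed transforms) are now covered in dedicated modules, test_rand_rotated.py, test_rand_zoomd.py, test_resized.py, test_rotated.py and test_zoomd.py, rather than as extra methods inside the array-transform test files. Each wraps its array counterpart and applies it to the entries named by keys. A usage sketch matching the constructor arguments these tests pass (the 'img' key and data shape are illustrative):

    import numpy as np
    from monai.transforms import RandZoomd

    data = {'img': np.random.rand(1, 64, 64).astype(np.float32)}
    random_zoom = RandZoomd('img', prob=1.0, min_zoom=0.9, max_zoom=1.1, keep_size=True)
    random_zoom.set_random_state(234)  # fix the seed so the sampled zoom factor is reproducible
    zoomed = random_zoom(data)
    print(zoomed['img'].shape)  # keep_size=True preserves the input spatial shape
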
diff --git a/tests/test_resized.py b/tests/test_resized.py
new file mode 100644
index 0000000000..d7830d3e1d
--- /dev/null
+++ b/tests/test_resized.py
@@ -0,0 +1,56 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import skimage.transform
+from parameterized import parameterized
+
+from monai.transforms import Resized
+from tests.utils import NumpyImageTestCase2D
+
+
+class TestResized(NumpyImageTestCase2D):
+
+    @parameterized.expand([
+        ("invalid_order", "order", AssertionError)
+    ])
+    def test_invalid_inputs(self, _, order, raises):
+        with self.assertRaises(raises):
+            resize = Resized(keys='img', output_spatial_shape=(128, 128, 3), order=order)
+            resize({'img': self.imt[0]})
+
+    @parameterized.expand([
+        ((64, 64), 1, 'reflect', 0, True, True, True, None),
+        ((32, 32), 2, 'constant', 3, False, False, False, None),
+        ((256, 256), 3, 'constant', 3, False, False, False, None),
+    ])
+    def test_correct_results(self, output_spatial_shape, order, mode,
+                             cval, clip, preserve_range,
+                             anti_aliasing, anti_aliasing_sigma):
+        resize = Resized('img', output_spatial_shape, order, mode, cval, clip,
+                         preserve_range, anti_aliasing,
+                         anti_aliasing_sigma)
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(skimage.transform.resize(channel, output_spatial_shape,
+                                                     order=order, mode=mode,
+                                                     cval=cval, clip=clip,
+                                                     preserve_range=preserve_range,
+                                                     anti_aliasing=anti_aliasing,
+                                                     anti_aliasing_sigma=anti_aliasing_sigma))
+        expected = np.stack(expected).astype(np.float32)
+        self.assertTrue(np.allclose(resize({'img': self.imt[0]})['img'], expected))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_rotate.py b/tests/test_rotate.py
index 0c34f5809e..7d6d1b531b 100644
--- a/tests/test_rotate.py
+++ b/tests/test_rotate.py
@@ -15,38 +15,27 @@
 import scipy.ndimage
 from parameterized import parameterized

-from monai.transforms import Rotate, Rotated
+from monai.transforms import Rotate
 from tests.utils import NumpyImageTestCase2D

-TEST_CASES = [(90, (1, 2), True, 1, 'reflect', 0, True),
-              (-90, (2, 1), True, 3, 'constant', 0, True),
-              (180, (2, 3), False, 2, 'constant', 4, False)]
+TEST_CASES = [(90, (0, 1), True, 1, 'reflect', 0, True),
+              (-90, (1, 0), True, 3, 'constant', 0, True),
+              (180, (1, 0), False, 2, 'constant', 4, False)]


-class RotateTest(NumpyImageTestCase2D):
+class TestRotate(NumpyImageTestCase2D):

     @parameterized.expand(TEST_CASES)
-    def test_correct_results(self, angle, axes, reshape,
+    def test_correct_results(self, angle, spatial_axes, reshape,
                              order, mode, cval, prefilter):
-        rotate_fn = Rotate(angle, axes, reshape,
+        rotate_fn = Rotate(angle, spatial_axes, reshape,
                            order, mode, cval, prefilter)
-        rotated = rotate_fn(self.imt)
-
-        expected = scipy.ndimage.rotate(self.imt, angle, axes, reshape, order=order,
-                                        mode=mode, cval=cval, prefilter=prefilter)
+        rotated = rotate_fn(self.imt[0])
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(scipy.ndimage.rotate(channel, angle, spatial_axes, reshape, order=order,
+                                                 mode=mode, cval=cval, prefilter=prefilter))
+        expected = np.stack(expected).astype(np.float32)
         self.assertTrue(np.allclose(expected, rotated))

-    @parameterized.expand(TEST_CASES)
-    def test_correct_results_dict(self, angle, axes, reshape,
-                                  order, mode, cval, prefilter):
-        key = 'img'
-        rotate_fn = Rotated(key, angle, axes, reshape, order,
-                            mode, cval, prefilter)
-        rotated = rotate_fn({key: self.imt})
-
-        expected = scipy.ndimage.rotate(self.imt, angle, axes, reshape, order=order,
-                                        mode=mode, cval=cval, prefilter=prefilter)
-        self.assertTrue(np.allclose(expected, rotated[key]))
-
-
 if __name__ == '__main__':
     unittest.main()
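Note: Rotate is now verified against scipy.ndimage.rotate applied channel by channel, rather than by one call over the whole batched array. The equivalence the tests assert, sketched with plain NumPy/SciPy (shapes illustrative):

    import numpy as np
    import scipy.ndimage

    img = np.random.rand(1, 64, 64).astype(np.float32)  # (channels, H, W)
    angle, spatial_axes = 90, (0, 1)

    # Rotate each channel independently within the spatial plane;
    # the third positional argument of scipy.ndimage.rotate is the axes pair.
    expected = np.stack([
        scipy.ndimage.rotate(channel, angle, spatial_axes, reshape=True,
                             order=1, mode='reflect', cval=0, prefilter=True)
        for channel in img
    ]).astype(np.float32)
    print(expected.shape)
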
diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py
index 1b1aca78df..990e489cd9 100644
--- a/tests/test_rotate90.py
+++ b/tests/test_rotate90.py
@@ -17,30 +17,42 @@
 from tests.utils import NumpyImageTestCase2D


-class Rotate90Test(NumpyImageTestCase2D):
+class TestRotate90(NumpyImageTestCase2D):
     def test_rotate90_default(self):
         rotate = Rotate90()
-        rotated = rotate(self.imt)
-        expected = np.rot90(self.imt, 1, (1, 2))
+        rotated = rotate(self.imt[0])
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(np.rot90(channel, 1, (0, 1)))
+        expected = np.stack(expected)
         self.assertTrue(np.allclose(rotated, expected))

     def test_k(self):
         rotate = Rotate90(k=2)
-        rotated = rotate(self.imt)
-        expected = np.rot90(self.imt, 2, (1, 2))
+        rotated = rotate(self.imt[0])
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(np.rot90(channel, 2, (0, 1)))
+        expected = np.stack(expected)
         self.assertTrue(np.allclose(rotated, expected))

-    def test_axes(self):
-        rotate = Rotate90(axes=(1, 2))
-        rotated = rotate(self.imt)
-        expected = np.rot90(self.imt, 1, (1, 2))
+    def test_spatial_axes(self):
+        rotate = Rotate90(spatial_axes=(0, 1))
+        rotated = rotate(self.imt[0])
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(np.rot90(channel, 1, (0, 1)))
+        expected = np.stack(expected)
         self.assertTrue(np.allclose(rotated, expected))

-    def test_k_axes(self):
-        rotate = Rotate90(k=2, axes=(2, 3))
-        rotated = rotate(self.imt)
-        expected = np.rot90(self.imt, 2, (2, 3))
+    def test_k_spatial_axes(self):
+        rotate = Rotate90(k=2, spatial_axes=(0, 1))
+        rotated = rotate(self.imt[0])
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(np.rot90(channel, 2, (0, 1)))
+        expected = np.stack(expected)
         self.assertTrue(np.allclose(rotated, expected))
diff --git a/tests/test_rotate90d.py b/tests/test_rotate90d.py
index ccfb2380f0..4b54a9a296 100644
--- a/tests/test_rotate90d.py
+++ b/tests/test_rotate90d.py
@@ -17,41 +17,53 @@
 from tests.utils import NumpyImageTestCase2D


-class Rotate90Test(NumpyImageTestCase2D):
+class TestRotate90d(NumpyImageTestCase2D):
     def test_rotate90_default(self):
         key = 'test'
         rotate = Rotate90d(keys=key)
-        rotated = rotate({key: self.imt})
-        expected = np.rot90(self.imt, 1, (1, 2))
+        rotated = rotate({key: self.imt[0]})
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(np.rot90(channel, 1, (0, 1)))
+        expected = np.stack(expected)
         self.assertTrue(np.allclose(rotated[key], expected))

     def test_k(self):
         key = None
         rotate = Rotate90d(keys=key, k=2)
-        rotated = rotate({key: self.imt})
-        expected = np.rot90(self.imt, 2, (1, 2))
+        rotated = rotate({key: self.imt[0]})
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(np.rot90(channel, 2, (0, 1)))
+        expected = np.stack(expected)
         self.assertTrue(np.allclose(rotated[key], expected))

-    def test_axes(self):
-        key = ['test']
-        rotate = Rotate90d(keys=key, axes=(1, 2))
-        rotated = rotate({key[0]: self.imt})
-        expected = np.rot90(self.imt, 1, (1, 2))
-        self.assertTrue(np.allclose(rotated[key[0]], expected))
+    def test_spatial_axes(self):
+        key = 'test'
+        rotate = Rotate90d(keys=key, spatial_axes=(0, 1))
+        rotated = rotate({key: self.imt[0]})
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(np.rot90(channel, 1, (0, 1)))
+        expected = np.stack(expected)
+        self.assertTrue(np.allclose(rotated[key], expected))

-    def test_k_axes(self):
-        key = ('test',)
-        rotate = Rotate90d(keys=key, k=2, axes=(2, 3))
-        rotated = rotate({key[0]: self.imt})
-        expected = np.rot90(self.imt, 2, (2, 3))
-        self.assertTrue(np.allclose(rotated[key[0]], expected))
+    def test_k_spatial_axes(self):
+        key = 'test'
+        rotate = Rotate90d(keys=key, k=2, spatial_axes=(0, 1))
+        rotated = rotate({key: self.imt[0]})
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(np.rot90(channel, 2, (0, 1)))
+        expected = np.stack(expected)
+        self.assertTrue(np.allclose(rotated[key], expected))

     def test_no_key(self):
         key = 'unknown'
         rotate = Rotate90d(keys=key)
         with self.assertRaisesRegex(KeyError, ''):
-            rotate({'test': self.imt})
+            rotate({'test': self.imt[0]})


 if __name__ == '__main__':
diff --git a/tests/test_rotated.py b/tests/test_rotated.py
new file mode 100644
index 0000000000..af7a758d8d
--- /dev/null
+++ b/tests/test_rotated.py
@@ -0,0 +1,43 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+
+import scipy.ndimage
+from parameterized import parameterized
+
+from monai.transforms import Rotated
+from tests.utils import NumpyImageTestCase2D
+
+TEST_CASES = [(90, (0, 1), True, 1, 'reflect', 0, True),
+              (-90, (1, 0), True, 3, 'constant', 0, True),
+              (180, (1, 0), False, 2, 'constant', 4, False)]
+
+
+class TestRotated(NumpyImageTestCase2D):
+
+    @parameterized.expand(TEST_CASES)
+    def test_correct_results(self, angle, spatial_axes, reshape,
+                             order, mode, cval, prefilter):
+        key = 'img'
+        rotate_fn = Rotated(key, angle, spatial_axes, reshape, order,
+                            mode, cval, prefilter)
+        rotated = rotate_fn({key: self.imt[0]})
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(scipy.ndimage.rotate(channel, angle, spatial_axes, reshape, order=order,
+                                                 mode=mode, cval=cval, prefilter=prefilter))
+        expected = np.stack(expected).astype(np.float32)
+        self.assertTrue(np.allclose(expected, rotated[key]))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_unet.py b/tests/test_unet.py
index c1e838c284..5b8e85f915 100644
--- a/tests/test_unet.py
+++ b/tests/test_unet.py
@@ -20,39 +20,39 @@
     {
         'dimensions': 2,
         'in_channels': 1,
-        'num_classes': 3,
+        'out_channels': 3,
         'channels': (16, 32, 64),
         'strides': (2, 2),
         'num_res_units': 1,
     },
     torch.randn(16, 1, 32, 32),
-    (16, 1, 32, 32),
+    (16, 3, 32, 32),
 ]

 TEST_CASE_2 = [  # single channel 3D, batch 16
     {
         'dimensions': 3,
         'in_channels': 1,
-        'num_classes': 3,
+        'out_channels': 3,
         'channels': (16, 32, 64),
         'strides': (2, 2),
         'num_res_units': 1,
     },
     torch.randn(16, 1, 32, 24, 48),
-    (16, 1, 32, 24, 48),
+    (16, 3, 32, 24, 48),
 ]

 TEST_CASE_3 = [  # 4-channel 3D, batch 16
     {
         'dimensions': 3,
         'in_channels': 4,
-        'num_classes': 3,
+        'out_channels': 3,
         'channels': (16, 32, 64),
         'strides': (2, 2),
         'num_res_units': 1,
     },
     torch.randn(16, 4, 32, 64, 48),
-    (16, 1, 32, 64, 48),
+    (16, 3, 32, 64, 48),
 ]


@@ -63,7 +63,7 @@ def test_shape(self, input_param, input_data, expected_shape):
         net = UNet(**input_param)
         net.eval()
         with torch.no_grad():
-            result = net.forward(input_data)[1]
+            result = net.forward(input_data)
         self.assertEqual(result.shape, expected_shape)
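Note: two UNet API changes drive these test updates: the num_classes argument is renamed out_channels (so the expected output carries that many channels instead of one), and forward now returns the prediction tensor directly instead of a tuple that had to be indexed with [1]. A sketch built from TEST_CASE_1; the import path is an assumption, since the test module's import lines are not part of this diff:

    import torch
    from monai.networks.nets import UNet  # import path assumed; the tests construct UNet directly

    net = UNet(dimensions=2, in_channels=1, out_channels=3,
               channels=(16, 32, 64), strides=(2, 2), num_res_units=1)
    net.eval()
    with torch.no_grad():
        result = net(torch.randn(16, 1, 32, 32))  # forward returns a single tensor
    print(result.shape)  # torch.Size([16, 3, 32, 32])
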
diff --git a/tests/test_uniform_rand_patch.py b/tests/test_uniform_rand_patch.py
index f11c4b43f4..32b1077ef7 100644
--- a/tests/test_uniform_rand_patch.py
+++ b/tests/test_uniform_rand_patch.py
@@ -17,13 +17,13 @@
 from tests.utils import NumpyImageTestCase2D


-class UniformRandomPatchTest(NumpyImageTestCase2D):
+class TestUniformRandomPatch(NumpyImageTestCase2D):

     def test_2d(self):
-        patch_size = (1, 10, 10)
-        patch_transform = UniformRandomPatch(patch_size=patch_size)
-        patch = patch_transform(self.imt)
-        self.assertTrue(np.allclose(patch.shape[:-2], patch_size[:-2]))
+        patch_spatial_size = (10, 10)
+        patch_transform = UniformRandomPatch(patch_spatial_size=patch_spatial_size)
+        patch = patch_transform(self.imt[0])
+        self.assertTrue(np.allclose(patch.shape[1:], patch_spatial_size))


 if __name__ == '__main__':
diff --git a/tests/test_uniform_rand_patchd.py b/tests/test_uniform_rand_patchd.py
index 1ab03b4b6f..a87438acce 100644
--- a/tests/test_uniform_rand_patchd.py
+++ b/tests/test_uniform_rand_patchd.py
@@ -17,20 +17,20 @@
 from tests.utils import NumpyImageTestCase2D


-class UniformRandomPatchdTest(NumpyImageTestCase2D):
+class TestUniformRandomPatchd(NumpyImageTestCase2D):

     def test_2d(self):
-        patch_size = (1, 10, 10)
+        patch_spatial_size = (10, 10)
         key = 'test'
-        patch_transform = UniformRandomPatchd(keys='test', patch_size=patch_size)
-        patch = patch_transform({key: self.imt})
-        self.assertTrue(np.allclose(patch[key].shape[:-2], patch_size[:-2]))
+        patch_transform = UniformRandomPatchd(keys='test', patch_spatial_size=patch_spatial_size)
+        patch = patch_transform({key: self.imt[0]})
+        self.assertTrue(np.allclose(patch[key].shape[1:], patch_spatial_size))

     def test_sync(self):
-        patch_size = (1, 4, 4)
+        patch_spatial_size = (4, 4)
         key_1, key_2 = 'foo', 'bar'
         rand_image = np.random.rand(3, 10, 10)
-        patch_transform = UniformRandomPatchd(keys=(key_1, key_2), patch_size=patch_size)
+        patch_transform = UniformRandomPatchd(keys=(key_1, key_2), patch_spatial_size=patch_spatial_size)
         patch = patch_transform({key_1: rand_image, key_2: rand_image})
         self.assertTrue(np.allclose(patch[key_1], patch[key_2]))
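Note: UniformRandomPatch follows the same convention shift: patch_size, which had to include a channel entry such as (1, 10, 10), becomes patch_spatial_size and covers the spatial dimensions only. A sketch of the behaviour the updated tests assert; the import path is assumed from the sibling test modules, and the array shape is illustrative:

    import numpy as np
    from monai.transforms import UniformRandomPatch  # import path assumed

    img = np.random.rand(3, 32, 32).astype(np.float32)  # (channels, H, W)
    patch_transform = UniformRandomPatch(patch_spatial_size=(10, 10))
    patch = patch_transform(img)
    # Spatial dims match patch_spatial_size; the channel dim is expected
    # to pass through untouched, giving (3, 10, 10) here.
    print(patch.shape)
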
diff --git a/tests/test_zoom.py b/tests/test_zoom.py
index 83795542bc..cb0af47fef 100644
--- a/tests/test_zoom.py
+++ b/tests/test_zoom.py
@@ -17,7 +17,7 @@
 from scipy.ndimage import zoom as zoom_scipy
 from parameterized import parameterized

-from monai.transforms import Zoom, Zoomd
+from monai.transforms import Zoom
 from tests.utils import NumpyImageTestCase2D

 VALID_CASES = [(1.1, 3, 'constant', 0, True, False, False),
@@ -30,37 +30,31 @@
                  ("invalid_order", 0.9, 's', AssertionError)]


-class ZoomTest(NumpyImageTestCase2D):
+class TestZoom(NumpyImageTestCase2D):

     @parameterized.expand(VALID_CASES)
     def test_correct_results(self, zoom, order, mode, cval, prefilter, use_gpu, keep_size):
         zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval,
                        prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size)
         zoomed = zoom_fn(self.imt[0])
-        expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order,
-                              cval=cval, prefilter=prefilter)
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(zoom_scipy(channel, zoom=zoom, mode=mode, order=order,
+                                       cval=cval, prefilter=prefilter))
+        expected = np.stack(expected).astype(np.float32)
         self.assertTrue(np.allclose(expected, zoomed))

-    @parameterized.expand(VALID_CASES)
-    def test_correct_results_dict(self, zoom, order, mode, cval, prefilter, use_gpu, keep_size):
-        key = 'img'
-        zoom_fn = Zoomd(key, zoom=zoom, order=order, mode=mode, cval=cval,
-                        prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size)
-        zoomed = zoom_fn({key: self.imt[0]})
-
-        expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order,
-                              cval=cval, prefilter=prefilter)
-        self.assertTrue(np.allclose(expected, zoomed[key]))
-
-
     @parameterized.expand(GPU_CASES)
     def test_gpu_zoom(self, _, zoom, order, mode, cval, prefilter):
         if importlib.util.find_spec('cupy'):
             zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval,
                            prefilter=prefilter, use_gpu=True, keep_size=False)
             zoomed = zoom_fn(self.imt[0])
-            expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order,
-                                  cval=cval, prefilter=prefilter)
+            expected = list()
+            for channel in self.imt[0]:
+                expected.append(zoom_scipy(channel, zoom=zoom, mode=mode, order=order,
+                                           cval=cval, prefilter=prefilter))
+            expected = np.stack(expected).astype(np.float32)
             self.assertTrue(np.allclose(expected, zoomed))

     def test_keep_size(self):
@@ -76,7 +70,7 @@ def test_keep_size(self):
     def test_invalid_inputs(self, _, zoom, order, raises):
         with self.assertRaises(raises):
             zoom_fn = Zoom(zoom=zoom, order=order)
-            zoomed = zoom_fn(self.imt)
+            zoomed = zoom_fn(self.imt[0])


 if __name__ == '__main__':
diff --git a/tests/test_zoomd.py b/tests/test_zoomd.py
new file mode 100644
index 0000000000..6ef85cb1fe
--- /dev/null
+++ b/tests/test_zoomd.py
@@ -0,0 +1,82 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import importlib
+
+from scipy.ndimage import zoom as zoom_scipy
+from parameterized import parameterized
+
+from monai.transforms import Zoomd
+from tests.utils import NumpyImageTestCase2D
+
+VALID_CASES = [(1.1, 3, 'constant', 0, True, False, False),
+               (0.9, 3, 'constant', 0, True, False, False),
+               (0.8, 1, 'reflect', 0, False, False, False)]
+
+GPU_CASES = [("gpu_zoom", 0.6, 1, 'constant', 0, True)]
+
+INVALID_CASES = [("no_zoom", None, 1, TypeError),
+                 ("invalid_order", 0.9, 's', AssertionError)]
+
+
+class TestZoomd(NumpyImageTestCase2D):
+
+    @parameterized.expand(VALID_CASES)
+    def test_correct_results(self, zoom, order, mode, cval, prefilter, use_gpu, keep_size):
+        key = 'img'
+        zoom_fn = Zoomd(key, zoom=zoom, order=order, mode=mode, cval=cval,
+                        prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size)
+        zoomed = zoom_fn({key: self.imt[0]})
+        expected = list()
+        for channel in self.imt[0]:
+            expected.append(zoom_scipy(channel, zoom=zoom, mode=mode, order=order,
+                                       cval=cval, prefilter=prefilter))
+        expected = np.stack(expected).astype(np.float32)
+        self.assertTrue(np.allclose(expected, zoomed[key]))
+
+    @parameterized.expand(GPU_CASES)
+    def test_gpu_zoom(self, _, zoom, order, mode, cval, prefilter):
+        key = 'img'
+        if importlib.util.find_spec('cupy'):
+            zoom_fn = Zoomd(key, zoom=zoom, order=order, mode=mode, cval=cval,
+                            prefilter=prefilter, use_gpu=True, keep_size=False)
+            zoomed = zoom_fn({key: self.imt[0]})
+            expected = list()
+            for channel in self.imt[0]:
+                expected.append(zoom_scipy(channel, zoom=zoom, mode=mode, order=order,
+                                           cval=cval, prefilter=prefilter))
+            expected = np.stack(expected).astype(np.float32)
+            self.assertTrue(np.allclose(expected, zoomed[key]))
+
+    def test_keep_size(self):
+        key = 'img'
+        zoom_fn = Zoomd(key, zoom=0.6, keep_size=True)
+        zoomed = zoom_fn({key: self.imt[0]})
+        self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:]))
+
+        zoom_fn = Zoomd(key, zoom=1.3, keep_size=True)
+        zoomed = zoom_fn({key: self.imt[0]})
+        self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:]))
+
+    @parameterized.expand(INVALID_CASES)
+    def test_invalid_inputs(self, _, zoom, order, raises):
+        key = 'img'
+        with self.assertRaises(raises):
+            zoom_fn = Zoomd(key, zoom=zoom, order=order)
+            zoomed = zoom_fn({key: self.imt[0]})
+
+
+if __name__ == '__main__':
+    unittest.main()
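Note: the GPU zoom tests in test_zoom.py, test_zoomd.py, test_rand_zoom.py and test_rand_zoomd.py guard the use_gpu=True path behind importlib.util.find_spec('cupy'), so on a CPU-only runner the method body is skipped and the test silently passes. An alternative worth considering (a suggestion, not part of this change) is unittest.skipUnless, which reports such cases as skipped rather than passed:

    import importlib
    import unittest

    HAS_CUPY = importlib.util.find_spec('cupy') is not None


    class TestGpuZoom(unittest.TestCase):

        @unittest.skipUnless(HAS_CUPY, 'cupy is not installed')
        def test_gpu_zoom(self):
            # The use_gpu=True assertions would go here; without cupy the
            # runner reports a skip instead of a silent pass.
            self.assertTrue(True)


    if __name__ == '__main__':
        unittest.main()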