From a3e419cdad5c704f1642ca96abeddb5c7cdcd001 Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Wed, 23 Dec 2020 12:32:11 -0800 Subject: [PATCH 01/20] Support to train/run Deepgrow 2D/3D models Signed-off-by: Sachidanand Alle --- monai/apps/deepgrow/__init__.py | 10 + monai/apps/deepgrow/dataset.py | 267 ++++++++++++++ monai/apps/deepgrow/handler.py | 287 +++++++++++++++ monai/apps/deepgrow/interaction.py | 66 ++++ monai/apps/deepgrow/transforms.py | 550 +++++++++++++++++++++++++++++ 5 files changed, 1180 insertions(+) create mode 100644 monai/apps/deepgrow/__init__.py create mode 100644 monai/apps/deepgrow/dataset.py create mode 100644 monai/apps/deepgrow/handler.py create mode 100644 monai/apps/deepgrow/interaction.py create mode 100644 monai/apps/deepgrow/transforms.py diff --git a/monai/apps/deepgrow/__init__.py b/monai/apps/deepgrow/__init__.py new file mode 100644 index 0000000000..d0044e3563 --- /dev/null +++ b/monai/apps/deepgrow/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py new file mode 100644 index 0000000000..dc4c1a059d --- /dev/null +++ b/monai/apps/deepgrow/dataset.py @@ -0,0 +1,267 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
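For orientation before the code: a minimal usage sketch of the dataset utilities this file defines. It is illustrative only and not part of the patch; the NIfTI paths are placeholders, and it assumes the patch is applied so that monai.apps.deepgrow.dataset is importable.

from monai.apps.deepgrow.dataset import create_dataset

# Flatten a Decathlon-style datalist into slice-wise 2D training samples:
# every labelled region of every non-empty slice becomes one
# {"image", "label", "region"} entry backed by .npy files under output_dir.
new_datalist = create_dataset(
    datalist=[{"image": "spleen_10.nii.gz", "label": "spleen_10_seg.nii.gz"}],  # placeholder files
    output_dir="deepgrow_train_2d",
    dimension=2,        # 2 => per-slice samples, 3 => whole-volume samples
    pixdim=(1.0, 1.0),  # target spacing handed to Spacingd
)
print(new_datalist[0])  # {"image": ".../images/vol_idx_0000_slice_040.npy", "label": ..., "region": 1}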
+ +import json +import logging +import os +import sys +from typing import Callable, Dict, List, Sequence, Union + +import numpy as np + +from monai.apps.datasets import DecathlonDataset +from monai.transforms import AsChannelFirstd, Compose, GridSampleMode, LoadNiftid, Orientationd, Spacingd + + +# TODO:: Test basic functionality +# TODO:: Unit Test +class DeepgrowDataset(DecathlonDataset): + def __init__( + self, + dimension: int, + pixdim: Sequence[float], + root_dir: str, + task: str, + section: str, + transform: Union[Sequence[Callable], Callable] = (), + download: bool = False, + seed: int = 0, + val_frac: float = 0.2, + cache_num: int = sys.maxsize, + cache_rate: float = 1.0, + num_workers: int = 0, + limit: int = 0, + ) -> None: + self.dimension = dimension + self.pixdim = pixdim + self.limit = limit + + super().__init__( + root_dir=root_dir, + task=task, + section=section, + transform=transform, + download=download, + seed=seed, + val_frac=val_frac, + cache_num=cache_num, + cache_rate=cache_rate, + num_workers=num_workers, + ) + + def _generate_data_list(self, dataset_dir: str) -> List[Dict]: + dataset = super()._generate_data_list(dataset_dir) + + tmp_dataset_dir = dataset_dir + "_{}.deep".format(self.section) + new_datalist = create_dataset( + datalist=dataset, + keys=["image", "label"], + output_dir=tmp_dataset_dir, + dimension=self.dimension, + pixdim=self.pixdim, + limit=self.limit, + relative_path=False, + ) + + dataset_json = os.path.join(tmp_dataset_dir, "dataset.json") + with open(dataset_json, "w") as fp: + json.dump({self.section: new_datalist}, fp, indent=2) + return new_datalist + + +def _get_transforms(keys, pixdim): + mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST] if len(keys) == 2 else [GridSampleMode.BILINEAR] + transforms = [ + LoadNiftid(keys=keys), + AsChannelFirstd(keys=keys), + Spacingd(keys=keys, pixdim=pixdim, mode=mode), + Orientationd(keys=keys, axcodes="RAS"), + ] + + return Compose(transforms) + + +def _save_data_2d(vol_idx, data, keys, dataset_dir, relative_path): + vol_image = data[keys[0]] + vol_label = data.get(keys[1]) + data_list = [] + + if len(vol_image.shape) == 4: + logging.info("4D-Image, pick only first series; Image: {}; Label: {}".format(vol_image.shape, vol_label.shape)) + vol_image = vol_image[0] + vol_image = np.moveaxis(vol_image, -1, 0) + + image_count = 0 + label_count = 0 + unique_labels_count = 0 + for sid in range(vol_image.shape[0]): + image = vol_image[sid, ...] + label = vol_label[sid, ...] 
if vol_label is not None else None + + if vol_label is not None and np.sum(label) == 0: + continue + + image_file_prefix = "vol_idx_{:0>4d}_slice_{:0>3d}".format(vol_idx, sid) + image_file = os.path.join(dataset_dir, "images", image_file_prefix) + image_file += ".npy" + + os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True) + np.save(image_file, image) + image_count += 1 + + # Test Data + if vol_label is None: + data_list.append( + { + "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file, + } + ) + continue + + # For all Labels + unique_labels = np.unique(label.flatten()) + unique_labels = unique_labels[unique_labels != 0] + unique_labels_count = max(unique_labels_count, len(unique_labels)) + + for idx in unique_labels: + label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx)) + label_file = os.path.join(dataset_dir, "labels", label_file_prefix) + label_file += ".npy" + + os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True) + curr_label = (label == idx).astype(np.float32) + np.save(label_file, curr_label) + + label_count += 1 + data_list.append( + { + "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file, + "label": label_file.replace(dataset_dir + "/", "") if relative_path else label_file, + "region": int(idx), + } + ) + + print( + "{} => Image: {} => {}; Label: {} => {}; Unique Labels: {}".format( + vol_idx, + vol_image.shape, + image_count, + vol_label.shape if vol_label is not None else None, + label_count, + unique_labels_count, + ) + ) + return data_list + + +def _save_data_3d(vol_idx, data, keys, dataset_dir, relative_path): + vol_image = data[keys[0]] + vol_label = data.get(keys[1]) + data_list = [] + + if len(vol_image.shape) == 4: + logging.info("4D-Image, pick only first series; Image: {}; Label: {}".format(vol_image.shape, vol_label.shape)) + vol_image = vol_image[0] + vol_image = np.moveaxis(vol_image, -1, 0) + + image_count = 0 + label_count = 0 + unique_labels_count = 0 + + image_file_prefix = "vol_idx_{:0>4d}".format(vol_idx) + image_file = os.path.join(dataset_dir, "images", image_file_prefix) + image_file += ".npy" + + os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True) + np.save(image_file, vol_image) + image_count += 1 + + # Test Data + if vol_label is None: + data_list.append( + { + "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file, + } + ) + else: + # For all Labels + unique_labels = np.unique(vol_label.flatten()) + unique_labels = unique_labels[unique_labels != 0] + unique_labels_count = max(unique_labels_count, len(unique_labels)) + + for idx in unique_labels: + label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx)) + label_file = os.path.join(dataset_dir, "labels", label_file_prefix) + label_file += ".npy" + + curr_label = (vol_label == idx).astype(np.float32) + os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True) + np.save(label_file, curr_label) + + label_count += 1 + data_list.append( + { + "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file, + "label": label_file.replace(dataset_dir + "/", "") if relative_path else label_file, + "region": int(idx), + } + ) + + print( + "{} => Image: {} => {}; Label: {} => {}; Unique Labels: {}".format( + vol_idx, + vol_image.shape, + image_count, + vol_label.shape if vol_label is not None else None, + label_count, + unique_labels_count, + ) + ) + return data_list + + +def create_dataset( + datalist, output_dir, 
dimension, pixdim, keys=("image", "label"), base_dir=None, limit=0, relative_path=False +) -> List[Dict]: + if not isinstance(keys, list) and not isinstance(keys, tuple): + keys = [keys] + + transforms = _get_transforms(keys, pixdim) + new_datalist = [] + for idx in range(len(datalist)): + if limit and idx >= limit: + break + + image = datalist[idx][keys[0]] + label = datalist[idx].get(keys[1]) if len(keys) > 1 else None + if base_dir: + image = os.path.join(base_dir, image) + label = os.path.join(base_dir, label) if label else None + + print("{} => {}".format(image, label if label else None)) + if dimension == 2: + data = _save_data_2d( + vol_idx=idx, + data=transforms({"image": image, "label": label}), + keys=("image", "label"), + dataset_dir=output_dir, + relative_path=relative_path, + ) + else: + data = _save_data_3d( + vol_idx=idx, + data=transforms({"image": image, "label": label}), + keys=("image", "label"), + dataset_dir=output_dir, + relative_path=relative_path, + ) + new_datalist.extend(data) + return new_datalist diff --git a/monai/apps/deepgrow/handler.py b/monai/apps/deepgrow/handler.py new file mode 100644 index 0000000000..dbdbbf4289 --- /dev/null +++ b/monai/apps/deepgrow/handler.py @@ -0,0 +1,287 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
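Before the handler implementations, a hedged sketch of how they are attached. DeepgrowStatsHandler and SegmentationSaver are defined below in this file; the engine here is a bare ignite stand-in, whereas a real run would use a monai.engines evaluator whose state carries batch, output, and device.

from ignite.engine import Engine

from monai.apps.deepgrow.handler import DeepgrowStatsHandler, SegmentationSaver

# Stand-in engine for illustration; attach() only registers event handlers,
# so this runs without data. A real evaluator populates engine.state.batch,
# engine.state.output and engine.state.device, which both handlers read.
evaluator = Engine(lambda engine, batch: batch)

# Per-region Dice scalars plus image/label/prediction grids in TensorBoard.
DeepgrowStatsHandler(log_dir="./runs", tag_name="val_dice").attach(evaluator)
# Dump raw arrays (and PNG/NIfTI snapshots) of each validation sample.
SegmentationSaver(output_dir="./runs", save_np=True).attach(evaluator)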
+ +import logging +import os +import statistics + +import numpy as np +import torch +import torch.distributed +from torch.utils.tensorboard import SummaryWriter + +from monai.engines.workflow import Engine, Events +from monai.metrics import compute_meandice +from monai.transforms import rescale_array +from monai.utils import optional_import +from monai.visualize import plot_2d_or_3d_image + +nib, _ = optional_import("nibabel") +torchvision, _ = optional_import("torchvision") +make_grid, _ = optional_import("torchvision.utils", name="make_grid") + +# TODO:: Unit Test + + +class MeanDice: + def __init__(self): + self.data = [] + + def reset(self): + self.data = [] + + def update(self, y_pred, y, batched=True): + if not batched: + y_pred = y_pred[None] + y = y[None] + score = compute_meandice(y_pred=y_pred, y=y, include_background=False).mean() + self.data.append(score.item()) + + def mean(self): + return statistics.mean(self.data) + + def stdev(self): + return statistics.stdev(self.data) if len(self.data) > 1 else 0 + + +class DeepgrowStatsHandler(object): + def __init__( + self, + summary_writer=None, + interval=1, + log_dir="./runs", + tag_name="val_dice", + compute_metric=True, + images=True, + image_interval=1, + max_channels=1, + max_frames=64, + add_scalar=True, + add_stdev=False, + merge_scalar=False, + fold_size=0, + ): + self.writer = SummaryWriter(log_dir=log_dir) if summary_writer is None else summary_writer + self.interval = interval + self.tag_name = tag_name + self.compute_metric = compute_metric + self.images = images + self.image_interval = image_interval + self.max_channels = max_channels + self.max_frames = max_frames + self.add_scalar = add_scalar + self.add_stdev = add_stdev + self.merge_scalar = merge_scalar + self.fold_size = fold_size + + if torch.distributed.is_initialized(): + self.tag_name = "{}-r{}".format(self.tag_name, torch.distributed.get_rank()) + + self.plot_data = {} + self.metric_data = {} + + def attach(self, engine: Engine) -> None: + engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self, "iteration") + engine.add_event_handler(Events.EPOCH_COMPLETED(every=1), self, "epoch") + + def write_images(self, epoch): + if not self.plot_data or not len(self.plot_data): + return + + all_imgs = [] + titles = [] + for region in sorted(self.plot_data.keys()): + all_imgs.extend(self.plot_data[region]) + metric = self.metric_data.get(region) + dice = "{:.4f}".format(metric.mean()) if self.compute_metric and metric else "" + stdev = "{:.4f}".format(metric.stdev()) if self.compute_metric and metric else "" + titles.extend( + [ + "x({})".format(region), + "y({})".format(region), + "dice: {} +/- {}".format(dice, stdev) if self.compute_metric else "yh({})".format(region), + ] + ) + + if len(all_imgs[0].shape) == 3: + img_tensor = make_grid( + tensor=torch.from_numpy(np.array(all_imgs)), + nrow=3, + normalize=True, + pad_value=2, + ) + self.writer.add_image(tag=f"Deepgrow Regions ({self.tag_name})", img_tensor=img_tensor, global_step=epoch) + + if len(all_imgs[0].shape) == 4: + for region in sorted(self.plot_data.keys()): + tags = [f"region_{region}_image", f"region_{region}_label", f"region_{region}_output"] + for i in range(3): + img = self.plot_data[region][i] + plot_2d_or_3d_image( + img[np.newaxis], epoch, self.writer, 0, self.max_channels, self.max_frames, tags[i] + ) + + logging.info( + "Saved {} Regions {} into Tensorboard at epoch: {}".format( + len(self.plot_data), sorted([*self.plot_data]), epoch + ) + ) + self.writer.flush() + + def 
write_region_metrics(self, epoch): + metric_sum = 0 + means = {} + stdevs = {} + for region in self.metric_data: + metric = self.metric_data[region].mean() + stdev = self.metric_data[region].stdev() + if self.merge_scalar: + means["{:0>2d}".format(region)] = metric + stdevs["{:0>2d}".format(region)] = stdev + else: + if self.add_stdev: + self.writer.add_scalar("{}_{:0>2d}_mean".format(self.tag_name, region), metric, epoch) + self.writer.add_scalar("{}_{:0>2d}_mean+".format(self.tag_name, region), metric + stdev, epoch) + self.writer.add_scalar("{}_{:0>2d}_mean-".format(self.tag_name, region), metric - stdev, epoch) + else: + self.writer.add_scalar("{}_{:0>2d}".format(self.tag_name, region), metric, epoch) + metric_sum += metric + if self.merge_scalar: + self.writer.add_scalars("{}_region".format(self.tag_name), means, epoch) + + if len(self.metric_data) > 1: + metric_avg = metric_sum / len(self.metric_data) + self.writer.add_scalar("{}_regions_avg".format(self.tag_name), metric_avg, epoch) + self.writer.flush() + + def __call__(self, engine: Engine, action) -> None: + total_steps = engine.state.iteration + if total_steps < engine.state.epoch_length: + total_steps = engine.state.epoch_length * (engine.state.epoch - 1) + total_steps + + if action == "epoch" and not self.fold_size: + epoch = engine.state.epoch + elif self.fold_size and total_steps % self.fold_size == 0: + epoch = int(total_steps / self.fold_size) + else: + epoch = None + + if epoch: + if self.images and epoch % self.image_interval == 0: + self.write_images(epoch) + if self.add_scalar: + self.write_region_metrics(epoch) + + if action == "epoch" or epoch: + self.plot_data = {} + self.metric_data = {} + return + + device = engine.state.device + batch_data = engine.state.batch + output_data = engine.state.output + + for bidx in range(len(batch_data.get("region", []))): + region = batch_data.get("region")[bidx] + region = region.item() if torch.is_tensor(region) else region + + if self.images and self.plot_data.get(region) is None: + self.plot_data[region] = [ + rescale_array(batch_data["image"][bidx][0].detach().cpu().numpy()[np.newaxis], 0, 1), + rescale_array(batch_data["label"][bidx].detach().cpu().numpy(), 0, 1), + rescale_array(output_data["pred"][bidx].detach().cpu().numpy(), 0, 1), + ] + + if self.compute_metric: + if self.metric_data.get(region) is None: + self.metric_data[region] = MeanDice() + self.metric_data[region].update( + y_pred=output_data["pred"][bidx].to(device), y=batch_data["label"][bidx].to(device), batched=False + ) + + +class SegmentationSaver: + def __init__( + self, + output_dir: str = "./runs", + save_np=False, + images=True, + ): + self.output_dir = output_dir + self.save_np = save_np + self.images = images + os.makedirs(self.output_dir, exist_ok=True) + + def attach(self, engine: Engine) -> None: + if not engine.has_event_handler(self, Events.ITERATION_COMPLETED): + engine.add_event_handler(Events.ITERATION_COMPLETED, self) + + def __call__(self, engine: Engine): + batch_data = engine.state.batch + output_data = engine.state.output + device = engine.state.device + tag = "" + if torch.distributed.is_initialized(): + tag = "r{}-".format(torch.distributed.get_rank()) + + for bidx in range(len(batch_data.get("image"))): + step = engine.state.iteration + region = batch_data.get("region")[bidx] + region = region.item() if torch.is_tensor(region) else region + + image = batch_data["image"][bidx][0].detach().cpu().numpy()[np.newaxis] + label = batch_data["label"][bidx].detach().cpu().numpy() + pred = 
output_data["pred"][bidx].detach().cpu().numpy() + dice = compute_meandice( + y_pred=output_data["pred"][bidx][None].to(device), + y=batch_data["label"][bidx][None].to(device), + include_background=False, + ).mean() + + if self.save_np: + np.savez( + os.path.join( + self.output_dir, + "{}img_label_pred_{}_{:0>4d}_{:0>2d}_{:.4f}".format(tag, region, step, bidx, dice), + ), + image, + label, + pred, + ) + + if self.images and len(image.shape) == 3: + img = make_grid(torch.from_numpy(rescale_array(image, 0, 1)[0])) + lab = make_grid(torch.from_numpy(rescale_array(label, 0, 1)[0])) + + pos = rescale_array(output_data["image"][bidx][1].detach().cpu().numpy()[np.newaxis], 0, 1)[0] + neg = rescale_array(output_data["image"][bidx][2].detach().cpu().numpy()[np.newaxis], 0, 1)[0] + pre = make_grid(torch.from_numpy(np.array([rescale_array(pred, 0, 1)[0], pos, neg]))) + + torchvision.utils.save_image( + tensor=[img, lab, pre], + nrow=3, + pad_value=2, + fp=os.path.join( + self.output_dir, + "{}img_label_pred_{}_{:0>4d}_{:0>2d}_{:.4f}.png".format(tag, region, step, bidx, dice), + ), + ) + + if self.images and len(image.shape) == 4: + samples = {"image": image[0], "label": label[0], "pred": pred[0]} + for sample in samples: + img = nib.Nifti1Image(samples[sample], np.eye(4)) + nib.save( + img, + os.path.join( + self.output_dir, "{}{}_{:0>4d}_{:0>2d}_{:.4f}.nii.gz".format(tag, sample, step, bidx, dice) + ), + ) diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py new file mode 100644 index 0000000000..d29399e2da --- /dev/null +++ b/monai/apps/deepgrow/interaction.py @@ -0,0 +1,66 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +import torch +from ignite.engine import Engine, Events +from torch.cuda.amp import autocast + +from monai.engines.utils import CommonKeys + +# TODO:: Unit Test + + +class Interaction: + """ + Deepgrow Training/Evaluation iteration method with interactions (simulation of clicks) support for image and label. + + Args: + transforms: execute additional transformation during every iteration (before train). + Typically, several Tensor based transforms composed by `Compose`. 
+ max_interactions: maximum number of interactions per iteration + train: training or evaluation + key_probability: field name to fill probability for every interaction + """ + + def __init__(self, transforms, max_interactions: int, train: bool, key_probability: str = "probability") -> None: + self.transforms = transforms + self.max_interactions = max_interactions + self.train = train + self.key_probability = key_probability + + def attach(self, engine: Engine) -> None: + engine.add_event_handler(Events.ITERATION_STARTED, self) + + def __call__(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): + if batchdata is None: + raise ValueError("Must provide batch data for current iteration.") + + for j in range(self.max_interactions): + inputs, _ = engine.prepare_batch(batchdata) + inputs = inputs.to(engine.state.device) + + engine.network.eval() + with torch.no_grad(): + if engine.amp: + with autocast(): + predictions = engine.inferer(inputs, engine.network) + else: + predictions = engine.inferer(inputs, engine.network) + + batchdata.update({CommonKeys.PRED: predictions}) + batchdata[self.key_probability] = torch.as_tensor( + ([1.0 - ((1.0 / self.max_interactions) * j)] if self.train else [1.0]) * len(inputs) + ) + batchdata = self.transforms(batchdata) + + return engine._iteration(engine, batchdata) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py new file mode 100644 index 0000000000..f5dcfdd253 --- /dev/null +++ b/monai/apps/deepgrow/transforms.py @@ -0,0 +1,550 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
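To make the click-simulation loop above concrete, a sketch of wiring Interaction to the transforms defined in this file. The transform keys follow the defaults used throughout the patch; the trainer hookup in the final comment is an assumption for illustration, not something this patch prescribes.

from monai.apps.deepgrow.interaction import Interaction
from monai.apps.deepgrow.transforms import (
    AddGuidanceSignald,
    AddRandomGuidanced,
    FindDiscrepancyRegionsd,
)
from monai.transforms import Compose

# Inner loop per simulated click: find where prediction and label disagree,
# sample one corrective click there (positive or negative), then re-encode
# all clicks as guidance channels appended to the image tensor.
click_transforms = Compose(
    [
        FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"),
        AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"),
        AddGuidanceSignald(image="image", guidance="guidance", batched=True),
    ]
)

# Runs up to 5 simulated-click passes per batch, then defers to the engine's
# own _iteration; e.g. SupervisedTrainer(..., iteration_update=interaction).
interaction = Interaction(transforms=click_transforms, max_interactions=5, train=True)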
+""" +A collection of "vanilla" transforms for spatial operations +https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design +""" +import json +from typing import Optional, Union + +import numpy as np + +from monai.config import KeysCollection +from monai.transforms import InterpolateMode, InterpolateModeSequence, Resize, SpatialCrop +from monai.transforms.compose import MapTransform, Randomizable, Transform +from monai.transforms.utils import generate_spatial_bounding_box +from monai.utils import Sequence, ensure_tuple_rep, min_version, optional_import + +measure, _ = optional_import("skimage.measure", "0.14.2", min_version) +distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") +gaussian_filter, _ = optional_import("scipy.ndimage", name="gaussian_filter") + + +class AddInitialSeedPointd(Randomizable, Transform): + def __init__(self, label="label", guidance="guidance", dimensions=2, connected_regions=6): + self.label = label + self.guidance = guidance + self.dimensions = dimensions + self.connected_regions = connected_regions + + def randomize(self, data=None): + pass + + def _apply(self, label): + label = (label > 0.5).astype(np.float32) + + blobs_labels = measure.label(label.astype(int), background=0) if self.dimensions == 2 else label + assert np.max(blobs_labels) > 0, "Not a valid Label" + + default_guidance = [-1] * (self.dimensions + 1) + pos_guidance = [] + for ridx in range(1, 2 if self.dimensions == 3 else self.connected_regions): + if self.dimensions == 2: + label = (blobs_labels == ridx).astype(np.float32) + if np.sum(label) == 0: + pos_guidance.append(default_guidance) + continue + + distance = distance_transform_cdt(label).flatten() + probability = np.exp(distance) - 1.0 + + idx = np.where(label.flatten() > 0)[0] + seed = np.random.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx])) + dst = distance[seed] + + g = np.asarray(np.unravel_index(seed, label.shape)).transpose().tolist()[0] + g[0] = dst[0] + pos_guidance.append(g) + + return np.asarray([pos_guidance, [default_guidance] * len(pos_guidance)]) + + def __call__(self, data): + data[self.guidance] = self._apply(data[self.label]) + return data + + +class AddGuidanceSignald(Transform): + def __init__(self, image="image", guidance="guidance", sigma=2, dimensions=2, number_intensity_ch=1, batched=False): + self.image = image + self.guidance = guidance + self.sigma = sigma + self.dimensions = dimensions + self.number_intensity_ch = number_intensity_ch + self.batched = batched + + def _get_signal(self, image, guidance): + guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance + if self.dimensions == 3: + signal = np.zeros((len(guidance), image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32) + else: + signal = np.zeros((len(guidance), image.shape[-2], image.shape[-1]), dtype=np.float32) + + for i in range(len(guidance)): + for point in guidance[i]: + if np.any(np.asarray(point) < 0): + continue + + if self.dimensions == 3: + signal[i, int(point[-3]), int(point[-2]), int(point[-1])] = 1.0 + else: + signal[i, int(point[-2]), int(point[-1])] = 1.0 + + if np.max(signal[i]) > 0: + signal[i] = gaussian_filter(signal[i], sigma=self.sigma) + signal[i] = (signal[i] - np.min(signal[i])) / (np.max(signal[i]) - np.min(signal[i])) + return signal + + def _apply(self, image, guidance): + if not self.batched: + signal = self._get_signal(image, guidance) + return np.concatenate([image, signal], axis=0) + + images = [] + for i, g in zip(image, 
guidance): + i = i[0 : 0 + self.number_intensity_ch, ...] + signal = self._get_signal(i, g) + images.append(np.concatenate([i, signal], axis=0)) + return images + + def __call__(self, data): + image = data[self.image] + guidance = data[self.guidance] + + data[self.image] = self._apply(image, guidance) + return data + + +class FindDiscrepancyRegionsd(Transform): + def __init__(self, label="label", pred="pred", discrepancy="discrepancy", batched=True): + self.label = label + self.pred = pred + self.discrepancy = discrepancy + self.batched = batched + + @staticmethod + def disparity(label, pred): + label = (label > 0.5).astype(np.float32) + pred = (pred > 0.5).astype(np.float32) + disparity = label - pred + + pos_disparity = (disparity > 0).astype(np.float32) + neg_disparity = (disparity < 0).astype(np.float32) + return [pos_disparity, neg_disparity] + + def _apply(self, label, pred): + if not self.batched: + return self.disparity(label, pred) + + disparity = [] + for la, pr in zip(label, pred): + disparity.append(self.disparity(la, pr)) + return disparity + + def __call__(self, data): + label = data[self.label] + pred = data[self.pred] + + data[self.discrepancy] = self._apply(label, pred) + return data + + +class AddRandomGuidanced(Randomizable, Transform): + def __init__( + self, guidance="guidance", discrepancy="discrepancy", probability="probability", dimensions=2, batched=True + ): + self.guidance = guidance + self.discrepancy = discrepancy + self.probability = probability + self.dimensions = dimensions + self.batched = batched + + def randomize(self, data=None): + pass + + @staticmethod + def find_guidance(discrepancy): + distance = distance_transform_cdt(discrepancy).flatten() + probability = np.exp(distance) - 1.0 + idx = np.where(discrepancy.flatten() > 0)[0] + + if np.sum(discrepancy > 0) > 0: + seed = np.random.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx])) + dst = distance[seed] + + g = np.asarray(np.unravel_index(seed, discrepancy.shape)).transpose().tolist()[0] + g[0] = dst[0] + return g + return None + + @staticmethod + def add_guidance(discrepancy, probability): + will_interact = np.random.choice([True, False], p=[probability, 1.0 - probability]) + if not will_interact: + return None, None + + pos_discr = discrepancy[0] + neg_discr = discrepancy[1] + + can_be_positive = np.sum(pos_discr) > 0 + can_be_negative = np.sum(neg_discr) > 0 + correct_pos = np.sum(pos_discr) >= np.sum(neg_discr) + + if correct_pos and can_be_positive: + return AddRandomGuidanced.find_guidance(pos_discr), None + + if not correct_pos and can_be_negative: + return None, AddRandomGuidanced.find_guidance(neg_discr) + return None, None + + def _apply(self, guidance, discrepancy, probability): + guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance + default_guidance = [-1] * (self.dimensions + 1) + + if not self.batched: + pos, neg = self.add_guidance(discrepancy, probability) + if pos: + guidance[0].append(pos) + guidance[1].append(default_guidance) + if neg: + guidance[0].append(default_guidance) + guidance[1].append(neg) + else: + for g, d, p in zip(guidance, discrepancy, probability): + pos, neg = self.add_guidance(d, p) + if pos: + g[0].append(pos) + g[1].append(default_guidance) + if neg: + g[0].append(default_guidance) + g[1].append(neg) + return np.asarray(guidance) + + def __call__(self, data): + guidance = data[self.guidance] + discrepancy = data[self.discrepancy] + probability = data[self.probability] + + data[self.guidance] = self._apply(guidance, 
discrepancy, probability) + return data + + +class SpatialCropForegroundd(MapTransform): + def __init__( + self, + keys, + source_key: str, + spatial_size, + select_fn=lambda x: x > 0, + channel_indices=None, + margin: int = 0, + meta_key_postfix="meta_dict", + start_coord_key: str = "foreground_start_coord", + end_coord_key: str = "foreground_end_coord", + original_shape_key: str = "foreground_original_shape", + cropped_shape_key: str = "foreground_cropped_shape", + ) -> None: + super().__init__(keys) + + self.source_key = source_key + self.spatial_size = spatial_size + self.select_fn = select_fn + self.channel_indices = channel_indices + self.margin = margin + self.meta_key_postfix = meta_key_postfix + self.start_coord_key = start_coord_key + self.end_coord_key = end_coord_key + self.original_shape_key = original_shape_key + self.cropped_shape_key = cropped_shape_key + + def __call__(self, data): + box_start, box_end = generate_spatial_bounding_box( + data[self.source_key], self.select_fn, self.channel_indices, self.margin + ) + + center = np.mean([box_start, box_end], axis=0).astype(int).tolist() + current_size = np.subtract(box_end, box_start).astype(int).tolist() + + if np.all(np.less(current_size, self.spatial_size)): + cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size) + box_start = cropper.roi_start + box_end = cropper.roi_end + else: + cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) + + for key in self.keys: + meta_key = f"{key}_{self.meta_key_postfix}" + data[meta_key][self.start_coord_key] = box_start + data[meta_key][self.end_coord_key] = box_end + data[meta_key][self.original_shape_key] = data[key].shape + + image = cropper(data[key]) + data[meta_key][self.cropped_shape_key] = image.shape + data[key] = image + return data + + +# Transforms to support Inference +class SpatialCropGuidanced(MapTransform): + def __init__( + self, + keys, + guidance: str, + spatial_size, + spatial_size_key: str = "spatial_size", + meta_key_postfix="meta_dict", + start_coord_key: str = "foreground_start_coord", + end_coord_key: str = "foreground_end_coord", + original_shape_key: str = "foreground_original_shape", + cropped_shape_key: str = "foreground_cropped_shape", + ) -> None: + super().__init__(keys) + + self.guidance = guidance + self.spatial_size = spatial_size + self.spatial_size_key = spatial_size_key + self.meta_key_postfix = meta_key_postfix + self.start_coord_key = start_coord_key + self.end_coord_key = end_coord_key + self.original_shape_key = original_shape_key + self.cropped_shape_key = cropped_shape_key + + def __call__(self, data): + guidance = data[self.guidance] + center = np.mean(guidance[0] + guidance[1], axis=0).astype(int).tolist() + spatial_size = data.get(self.spatial_size_key, self.spatial_size) + + cropper = SpatialCrop(roi_center=center, roi_size=spatial_size) + box_start, box_end = cropper.roi_start, cropper.roi_end + + for key in self.keys: + meta_key = f"{key}_{self.meta_key_postfix}" + data[meta_key][self.start_coord_key] = box_start + data[meta_key][self.end_coord_key] = box_end + data[meta_key][self.original_shape_key] = data[key].shape + + image = cropper(data[key]) + data[meta_key][self.cropped_shape_key] = image.shape + data[key] = image + + pos_clicks, neg_clicks = guidance[0], guidance[1] + pos = np.subtract(pos_clicks, box_start).tolist() if len(pos_clicks) else [] + neg = np.subtract(neg_clicks, box_start).tolist() if len(neg_clicks) else [] + + data[self.guidance] = [pos, neg] + return data + + +class ResizeGuidanced(Transform): 
+ def __init__( + self, + guidance: str, + ref_image, + meta_key_postfix="meta_dict", + cropped_shape_key: str = "foreground_cropped_shape", + ) -> None: + self.guidance = guidance + self.ref_image = ref_image + self.meta_key_postfix = meta_key_postfix + self.cropped_shape_key = cropped_shape_key + + def __call__(self, data): + guidance = data[self.guidance] + meta_dict = data[f"{self.ref_image}_{self.meta_key_postfix}"] + current_shape = data[self.ref_image].shape[1:] + cropped_shape = meta_dict[self.cropped_shape_key][1:] + factor = np.divide(current_shape, cropped_shape) + + pos_clicks, neg_clicks = guidance[0], guidance[1] + pos = np.multiply(pos_clicks, factor).astype(int).tolist() if len(pos_clicks) else [] + neg = np.multiply(neg_clicks, factor).astype(int).tolist() if len(neg_clicks) else [] + + data[self.guidance] = [pos, neg] + return data + + +class RestoreCroppedLabeld(MapTransform): + def __init__( + self, + keys: KeysCollection, + ref_image: str, + slice_only=False, + channel_first=True, + mode: InterpolateModeSequence = InterpolateMode.NEAREST, + align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, + meta_key_postfix: str = "meta_dict", + start_coord_key: str = "foreground_start_coord", + end_coord_key: str = "foreground_end_coord", + original_shape_key: str = "foreground_original_shape", + cropped_shape_key: str = "foreground_cropped_shape", + ) -> None: + super().__init__(keys) + self.ref_image = ref_image + self.slice_only = slice_only + self.channel_first = channel_first + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) + self.meta_key_postfix = meta_key_postfix + self.start_coord_key = start_coord_key + self.end_coord_key = end_coord_key + self.original_shape_key = original_shape_key + self.cropped_shape_key = cropped_shape_key + + def __call__(self, data): + meta_dict = data[f"{self.ref_image}_{self.meta_key_postfix}"] + + for idx, key in enumerate(self.keys): + image = data[key] + + # Undo Resize + current_size = image.shape + cropped_size = meta_dict[self.cropped_shape_key] + if np.any(np.not_equal(current_size, cropped_size)): + resizer = Resize(spatial_size=cropped_size[1:], mode=self.mode[idx]) + image = resizer(image, mode=self.mode[idx], align_corners=self.align_corners[idx]) + + # Undo Crop + original_shape = meta_dict[self.original_shape_key] + result = np.zeros(original_shape, dtype=np.float32) + box_start = meta_dict[self.start_coord_key] + box_end = meta_dict[self.end_coord_key] + + sd = min(len(box_start), len(box_end), len(image.shape[1:])) # spatial dims + slices = [slice(None)] + [slice(s, e) for s, e in zip(box_start[:sd], box_end[:sd])] + slices = tuple(slices) + result[slices] = image + + # Undo Spacing + current_size = result.shape[1:] + spatial_shape = np.roll(meta_dict["spatial_shape"], 1).tolist() + spatial_size = spatial_shape[-len(current_size) :] + + if np.any(np.not_equal(current_size, spatial_size)): + resizer = Resize(spatial_size=spatial_size, mode=self.mode[idx]) + result = resizer(result, mode=self.mode[idx], align_corners=self.align_corners[idx]) + + # Undo Slicing + slice_idx = meta_dict.get("slice_idx") + if slice_idx is None or self.slice_only: + final_result = result if len(result.shape) <= 3 else result[0] + else: + slice_idx = meta_dict["slice_idx"][0] + final_result = np.zeros(spatial_shape) + if self.channel_first: + final_result[slice_idx] = result + else: + final_result[..., slice_idx] = result + data[key] = final_result + + meta = 
data.get(f"{key}_{self.meta_key_postfix}") + if meta is None: + meta = dict() + data[f"{key}_{self.meta_key_postfix}"] = meta + meta["slice_idx"] = slice_idx + meta["affine"] = meta_dict["original_affine"] + return data + + +class AddGuidanceFromPointsd(Randomizable, Transform): + def __init__( + self, + ref_image, + guidance="guidance", + foreground="foreground", + background="background", + axis=0, + channel_first=True, + dimensions=2, + slice_key="slice", + meta_key_postfix: str = "meta_dict", + ): + self.ref_image = ref_image + self.guidance = guidance + self.foreground = foreground + self.background = background + self.axis = axis + self.channel_first = channel_first + self.dimensions = dimensions + self.slice_key = slice_key + self.meta_key_postfix = meta_key_postfix + + def randomize(self, data=None): + pass + + def _apply(self, pos_clicks, neg_clicks, factor, slice_num=None): + points = pos_clicks + points.extend(neg_clicks) + points = np.array(points) + + if self.dimensions == 2: + slices = np.unique(points[:, self.axis]).tolist() + slice_idx = slices[0] if slice_num is None else next(x for x in slices if x == slice_num) + + pos = neg = [] + if len(pos_clicks): + pos_clicks = np.array(pos_clicks) + pos = (pos_clicks[np.where(pos_clicks[:, self.axis] == slice_idx)] * factor)[:, 1:].astype(int).tolist() + if len(neg_clicks): + neg_clicks = np.array(neg_clicks) + neg = (neg_clicks[np.where(neg_clicks[:, self.axis] == slice_idx)] * factor)[:, 1:].astype(int).tolist() + + guidance = [pos, neg, slice_idx, factor] + else: + pos = neg = [] + if len(pos_clicks): + pos = np.multiply(pos_clicks, factor).astype(int).tolist() + if len(neg_clicks): + neg = np.multiply(neg_clicks, factor).astype(int).tolist() + guidance = [pos, neg] + return guidance + + def __call__(self, data): + meta_dict = data[f"{self.ref_image}_{self.meta_key_postfix}"] + original_shape = meta_dict["spatial_shape"] + current_shape = list(data[self.ref_image].shape) + + clicks = [data[self.foreground], data[self.background]] + if self.channel_first: + original_shape = np.roll(original_shape, 1).tolist() + for i in range(len(clicks)): + clicks[i] = json.loads(clicks[i]) if isinstance(clicks[i], str) else clicks[i] + clicks[i] = np.array(clicks[i]).astype(int).tolist() + for j in range(len(clicks[i])): + clicks[i][j] = np.roll(clicks[i][j], 1).tolist() + + factor = np.array(current_shape) / original_shape + + data[self.guidance] = self._apply(clicks[0], clicks[1], factor, data.get(self.slice_key)) + return data + + +class Fetch2DSliced(MapTransform): + def __init__(self, keys, guidance="guidance", axis=0, meta_key_postfix: str = "meta_dict"): + super().__init__(keys) + self.guidance = guidance + self.axis = axis + self.meta_key_postfix = meta_key_postfix + + def _apply(self, image, guidance): + slice_idx = guidance[2] + idx = [] + for i in range(len(image.shape)): + idx.append(slice_idx) if i == self.axis else idx.append(slice(0, image.shape[i])) + + idx = tuple(idx) + return image[idx], idx + + def __call__(self, data): + guidance = data[self.guidance] + for key in self.keys: + img, idx = self._apply(data[key], guidance) + data[key] = img + data[f"{key}_{self.meta_key_postfix}"]["slice_idx"] = idx + return data From 0f626f6ac5fba8a60bae3e2fd5ec6efb7bf4e710 Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Wed, 23 Dec 2020 12:46:35 -0800 Subject: [PATCH 02/20] Fix import dependencies Signed-off-by: Sachidanand Alle --- monai/apps/deepgrow/handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/monai/apps/deepgrow/handler.py b/monai/apps/deepgrow/handler.py index dbdbbf4289..7f45e6ad57 100644 --- a/monai/apps/deepgrow/handler.py +++ b/monai/apps/deepgrow/handler.py @@ -16,9 +16,9 @@ import numpy as np import torch import torch.distributed -from torch.utils.tensorboard import SummaryWriter from monai.engines.workflow import Engine, Events +from monai.handlers.tensorboard_handlers import SummaryWriter from monai.metrics import compute_meandice from monai.transforms import rescale_array from monai.utils import optional_import From d15f04ee67d47c9fd01777a6d260aa98098d4098 Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Wed, 23 Dec 2020 12:55:52 -0800 Subject: [PATCH 03/20] Fix import dependencies Signed-off-by: Sachidanand Alle --- monai/apps/deepgrow/interaction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index d29399e2da..ca3f29af60 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -12,10 +12,10 @@ from typing import Dict import torch -from ignite.engine import Engine, Events from torch.cuda.amp import autocast from monai.engines.utils import CommonKeys +from monai.engines.workflow import Engine, Events # TODO:: Unit Test From 95f11db9ef566ece5ffedfe6d4c2371cbfbfc88b Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Wed, 23 Dec 2020 13:16:06 -0800 Subject: [PATCH 04/20] Fix import dependencies Signed-off-by: Sachidanand Alle --- monai/apps/deepgrow/interaction.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index ca3f29af60..b421f34eca 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -12,7 +12,6 @@ from typing import Dict import torch -from torch.cuda.amp import autocast from monai.engines.utils import CommonKeys from monai.engines.workflow import Engine, Events @@ -52,7 +51,7 @@ def __call__(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): engine.network.eval() with torch.no_grad(): if engine.amp: - with autocast(): + with torch.cuda.amp.autocast(): predictions = engine.inferer(inputs, engine.network) else: predictions = engine.inferer(inputs, engine.network) From 9fd94565f90cb7d8b2376037f2d0dcb75ec41c65 Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Sat, 26 Dec 2020 14:15:40 -0800 Subject: [PATCH 05/20] Fix bbox for inference and stat handler for additional info Signed-off-by: Sachidanand Alle --- monai/apps/deepgrow/dataset.py | 3 ++ monai/apps/deepgrow/handler.py | 49 +++++++++++++------------------ monai/apps/deepgrow/transforms.py | 44 ++++++++++++++++++++++----- 3 files changed, 61 insertions(+), 35 deletions(-) diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py index dc4c1a059d..b714184fcf 100644 --- a/monai/apps/deepgrow/dataset.py +++ b/monai/apps/deepgrow/dataset.py @@ -246,6 +246,9 @@ def create_dataset( image = os.path.join(base_dir, image) label = os.path.join(base_dir, label) if label else None + image = os.path.abspath(image) + label = os.path.abspath(label) if label else None + print("{} => {}".format(image, label if label else None)) if dimension == 2: data = _save_data_2d( diff --git a/monai/apps/deepgrow/handler.py b/monai/apps/deepgrow/handler.py index 7f45e6ad57..401edf7755 100644 --- a/monai/apps/deepgrow/handler.py +++ b/monai/apps/deepgrow/handler.py @@ -27,6 +27,9 @@ nib, _ = optional_import("nibabel") torchvision, _ = 
optional_import("torchvision") make_grid, _ = optional_import("torchvision.utils", name="make_grid") +Image, _ = optional_import("PIL.Image") +ImageDraw, _ = optional_import("PIL.ImageDraw") + # TODO:: Unit Test @@ -65,7 +68,6 @@ def __init__( max_channels=1, max_frames=64, add_scalar=True, - add_stdev=False, merge_scalar=False, fold_size=0, ): @@ -78,7 +80,6 @@ def __init__( self.max_channels = max_channels self.max_frames = max_frames self.add_scalar = add_scalar - self.add_stdev = add_stdev self.merge_scalar = merge_scalar self.fold_size = fold_size @@ -97,27 +98,23 @@ def write_images(self, epoch): return all_imgs = [] - titles = [] for region in sorted(self.plot_data.keys()): - all_imgs.extend(self.plot_data[region]) metric = self.metric_data.get(region) - dice = "{:.4f}".format(metric.mean()) if self.compute_metric and metric else "" - stdev = "{:.4f}".format(metric.stdev()) if self.compute_metric and metric else "" - titles.extend( - [ - "x({})".format(region), - "y({})".format(region), - "dice: {} +/- {}".format(dice, stdev) if self.compute_metric else "yh({})".format(region), - ] - ) + region_data = self.plot_data[region] + if len(region_data[0].shape) == 3: + ti = Image.new("RGB", region_data[0].shape[1:]) + d = ImageDraw.Draw(ti) + t = "region: {}".format(region) + if self.compute_metric: + t = t + "\ndice: {:.4f}".format(metric.mean()) + t = t + "\nstdev: {:.4f}".format(metric.stdev()) + d.multiline_text((10, 10), t, fill=(255, 255, 0)) + ti = rescale_array(np.rollaxis(np.array(ti), 2, 0)[0][np.newaxis]) + all_imgs.append(ti) + all_imgs.extend(region_data) if len(all_imgs[0].shape) == 3: - img_tensor = make_grid( - tensor=torch.from_numpy(np.array(all_imgs)), - nrow=3, - normalize=True, - pad_value=2, - ) + img_tensor = make_grid(tensor=torch.from_numpy(np.array(all_imgs)), nrow=4, normalize=True, pad_value=2) self.writer.add_image(tag=f"Deepgrow Regions ({self.tag_name})", img_tensor=img_tensor, global_step=epoch) if len(all_imgs[0].shape) == 4: @@ -139,20 +136,16 @@ def write_images(self, epoch): def write_region_metrics(self, epoch): metric_sum = 0 means = {} - stdevs = {} for region in self.metric_data: metric = self.metric_data[region].mean() - stdev = self.metric_data[region].stdev() + logging.info( + "Epoch[{}] Metrics -- Region: {:0>2d}, {}: {:.4f}".format(epoch, region, self.tag_name, metric) + ) + if self.merge_scalar: means["{:0>2d}".format(region)] = metric - stdevs["{:0>2d}".format(region)] = stdev else: - if self.add_stdev: - self.writer.add_scalar("{}_{:0>2d}_mean".format(self.tag_name, region), metric, epoch) - self.writer.add_scalar("{}_{:0>2d}_mean+".format(self.tag_name, region), metric + stdev, epoch) - self.writer.add_scalar("{}_{:0>2d}_mean-".format(self.tag_name, region), metric - stdev, epoch) - else: - self.writer.add_scalar("{}_{:0>2d}".format(self.tag_name, region), metric, epoch) + self.writer.add_scalar("{}_{:0>2d}".format(self.tag_name, region), metric, epoch) metric_sum += metric if self.merge_scalar: self.writer.add_scalars("{}_region".format(self.tag_name), means, epoch) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index f5dcfdd253..6ff4971d69 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -87,15 +87,21 @@ def _get_signal(self, image, guidance): else: signal = np.zeros((len(guidance), image.shape[-2], image.shape[-1]), dtype=np.float32) + sshape = signal.shape for i in range(len(guidance)): for point in guidance[i]: if np.any(np.asarray(point) < 0): continue if 
self.dimensions == 3: - signal[i, int(point[-3]), int(point[-2]), int(point[-1])] = 1.0 + p1 = max(0, min(int(point[-3]), sshape[-3] - 1)) + p2 = max(0, min(int(point[-2]), sshape[-2] - 1)) + p3 = max(0, min(int(point[-1]), sshape[-1] - 1)) + signal[i, p1, p2, p3] = 1.0 else: - signal[i, int(point[-2]), int(point[-1])] = 1.0 + p1 = max(0, min(int(point[-2]), sshape[-2] - 1)) + p2 = max(0, min(int(point[-1]), sshape[-1] - 1)) + signal[i, p1, p2] = 1.0 if np.max(signal[i]) > 0: signal[i] = gaussian_filter(signal[i], sigma=self.sigma) @@ -299,6 +305,7 @@ def __init__( guidance: str, spatial_size, spatial_size_key: str = "spatial_size", + margin: int = 20, meta_key_postfix="meta_dict", start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", @@ -310,21 +317,44 @@ def __init__( self.guidance = guidance self.spatial_size = spatial_size self.spatial_size_key = spatial_size_key + self.margin = margin self.meta_key_postfix = meta_key_postfix self.start_coord_key = start_coord_key self.end_coord_key = end_coord_key self.original_shape_key = original_shape_key self.cropped_shape_key = cropped_shape_key + def bounding_box(self, points, img_shape): + ndim = len(img_shape) + margin = ensure_tuple_rep(self.margin, ndim) + for m in margin: + if m < 0: + raise ValueError("margin value should not be negative number.") + + box_start = [0] * ndim + box_end = [0] * ndim + + for di in range(ndim): + dt = points[..., di] + min_d = max(min(dt - margin[di]), 0) + max_d = min(img_shape[di] - 1, max(dt + margin[di])) + box_start[di], box_end[di] = min_d, max_d + return box_start, box_end + def __call__(self, data): guidance = data[self.guidance] - center = np.mean(guidance[0] + guidance[1], axis=0).astype(int).tolist() - spatial_size = data.get(self.spatial_size_key, self.spatial_size) + for key in self.keys: + box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), data[key].shape[1:]) + center = np.mean([box_start, box_end], axis=0).astype(int).tolist() + spatial_size = data.get(self.spatial_size_key, self.spatial_size) - cropper = SpatialCrop(roi_center=center, roi_size=spatial_size) - box_start, box_end = cropper.roi_start, cropper.roi_end + current_size = np.subtract(box_start, box_end).astype(int).tolist() + if np.all(np.less(current_size, self.spatial_size)): + cropper = SpatialCrop(roi_center=center, roi_size=spatial_size) + else: + cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) + box_start, box_end = cropper.roi_start, cropper.roi_end - for key in self.keys: meta_key = f"{key}_{self.meta_key_postfix}" data[meta_key][self.start_coord_key] = box_start data[meta_key][self.end_coord_key] = box_end From 06879fe1ac779c0028fc6681c1919c55726dc0ca Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Sun, 27 Dec 2020 11:58:46 -0800 Subject: [PATCH 06/20] Fix handler and transform init in iteration Signed-off-by: Sachidanand Alle --- monai/apps/deepgrow/handler.py | 11 +++++++---- monai/apps/deepgrow/interaction.py | 19 ++++++++++++++++++- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/monai/apps/deepgrow/handler.py b/monai/apps/deepgrow/handler.py index 401edf7755..2b2d0ec7ff 100644 --- a/monai/apps/deepgrow/handler.py +++ b/monai/apps/deepgrow/handler.py @@ -55,7 +55,7 @@ def stdev(self): return statistics.stdev(self.data) if len(self.data) > 1 else 0 -class DeepgrowStatsHandler(object): +class DeepgrowStatsHandler: def __init__( self, summary_writer=None, @@ -82,6 +82,7 @@ def __init__( self.add_scalar = add_scalar 
self.merge_scalar = merge_scalar self.fold_size = fold_size + self.logger = logging.getLogger(__name__) if torch.distributed.is_initialized(): self.tag_name = "{}-r{}".format(self.tag_name, torch.distributed.get_rank()) @@ -122,11 +123,12 @@ def write_images(self, epoch): tags = [f"region_{region}_image", f"region_{region}_label", f"region_{region}_output"] for i in range(3): img = self.plot_data[region][i] + img = np.moveaxis(img, -3, -1) plot_2d_or_3d_image( img[np.newaxis], epoch, self.writer, 0, self.max_channels, self.max_frames, tags[i] ) - logging.info( + self.logger.info( "Saved {} Regions {} into Tensorboard at epoch: {}".format( len(self.plot_data), sorted([*self.plot_data]), epoch ) @@ -138,7 +140,7 @@ def write_region_metrics(self, epoch): means = {} for region in self.metric_data: metric = self.metric_data[region].mean() - logging.info( + self.logger.info( "Epoch[{}] Metrics -- Region: {:0>2d}, {}: {:.4f}".format(epoch, region, self.tag_name, metric) ) @@ -271,7 +273,8 @@ def __call__(self, engine: Engine): if self.images and len(image.shape) == 4: samples = {"image": image[0], "label": label[0], "pred": pred[0]} for sample in samples: - img = nib.Nifti1Image(samples[sample], np.eye(4)) + img = np.moveaxis(samples[sample], -3, -1) + img = nib.Nifti1Image(img, np.eye(4)) nib.save( img, os.path.join( diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index b421f34eca..32cc9eb949 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -8,13 +8,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import importlib from typing import Dict import torch from monai.engines.utils import CommonKeys from monai.engines.workflow import Engine, Events +from monai.transforms import Compose # TODO:: Unit Test @@ -37,6 +38,22 @@ def __init__(self, transforms, max_interactions: int, train: bool, key_probabili self.train = train self.key_probability = key_probability + if not isinstance(self.transforms, Compose): + transforms = [] + for t in transforms: + transforms.append(self.init_external_class(t)) + self.transforms = Compose(transforms) + + @staticmethod + def init_external_class(config_dict): + class_args = None if config_dict.get("args") is None else dict(config_dict.get("args")) + class_path = config_dict.get("path", config_dict["name"]) + + module_name, class_name = class_path.rsplit(".", 1) + m = importlib.import_module(module_name) + c = getattr(m, class_name) + return c(**class_args) if class_args else c() + def attach(self, engine: Engine) -> None: engine.add_event_handler(Events.ITERATION_STARTED, self) From ec8d0d736b7317b784c9623894971db1231bdddf Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Mon, 4 Jan 2021 11:03:45 -0800 Subject: [PATCH 07/20] fix transforms for training/inference Signed-off-by: Sachidanand Alle --- monai/apps/deepgrow/handler.py | 8 ++- monai/apps/deepgrow/transforms.py | 97 +++++++++++++++++++++---------- 2 files changed, 73 insertions(+), 32 deletions(-) diff --git a/monai/apps/deepgrow/handler.py b/monai/apps/deepgrow/handler.py index 2b2d0ec7ff..185c5b1d3d 100644 --- a/monai/apps/deepgrow/handler.py +++ b/monai/apps/deepgrow/handler.py @@ -121,6 +121,9 @@ def write_images(self, epoch): if len(all_imgs[0].shape) == 4: for region in sorted(self.plot_data.keys()): tags = [f"region_{region}_image", f"region_{region}_label", f"region_{region}_output"] + if 
torch.distributed.is_initialized(): + rank = "r{}-".format(torch.distributed.get_rank()) + tags = [rank + tags[0], rank + tags[1], rank + tags[2]] for i in range(3): img = self.plot_data[region][i] img = np.moveaxis(img, -3, -1) @@ -149,10 +152,11 @@ def write_region_metrics(self, epoch): else: self.writer.add_scalar("{}_{:0>2d}".format(self.tag_name, region), metric, epoch) metric_sum += metric + if self.merge_scalar: + means["avg"] = metric_sum / len(self.metric_data) self.writer.add_scalars("{}_region".format(self.tag_name), means, epoch) - - if len(self.metric_data) > 1: + elif len(self.metric_data) > 1: metric_avg = metric_sum / len(self.metric_data) self.writer.add_scalar("{}_regions_avg".format(self.tag_name), metric_avg, epoch) self.writer.flush() diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 6ff4971d69..92b23d24c1 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -28,26 +28,54 @@ gaussian_filter, _ = optional_import("scipy.ndimage", name="gaussian_filter") +class FindAllValidSlicesd(Transform): + def __init__(self, label="label", sids="sids"): + self.label = label + self.sids = sids + + def _apply(self, label): + if len(label.shape) != 4: # only for 3D + return None + + sids = [] + for sid in range(label.shape[1]): # Assume channel is first + if np.sum(label[0][sid]) == 0: + continue + sids.append(sid) + return np.asarray(sids) + + def __call__(self, data): + data[self.sids] = self._apply(data[self.label]) + return data + + class AddInitialSeedPointd(Randomizable, Transform): - def __init__(self, label="label", guidance="guidance", dimensions=2, connected_regions=6): + def __init__(self, label="label", guidance="guidance", sids="sids", sid="sid", connected_regions=6): self.label = label + self.sids = sids + self.sid = sid self.guidance = guidance - self.dimensions = dimensions self.connected_regions = connected_regions def randomize(self, data=None): pass - def _apply(self, label): - label = (label > 0.5).astype(np.float32) + def _apply(self, label, sid): + dimensions = 3 if len(label.shape) > 3 else 2 + default_guidance = [-1] * (dimensions + 1) + + dims = dimensions + if sid is not None and dimensions == 3: + dims = 2 + label = label[0][sid][np.newaxis] # Assume channel is first - blobs_labels = measure.label(label.astype(int), background=0) if self.dimensions == 2 else label + label = (label > 0.5).astype(np.float32) + blobs_labels = measure.label(label.astype(int), background=0) if dims == 2 else label assert np.max(blobs_labels) > 0, "Not a valid Label" - default_guidance = [-1] * (self.dimensions + 1) pos_guidance = [] - for ridx in range(1, 2 if self.dimensions == 3 else self.connected_regions): - if self.dimensions == 2: + for ridx in range(1, 2 if dims == 3 else self.connected_regions): + if dims == 2: label = (blobs_labels == ridx).astype(np.float32) if np.sum(label) == 0: pos_guidance.append(default_guidance) @@ -57,32 +85,42 @@ def _apply(self, label): probability = np.exp(distance) - 1.0 idx = np.where(label.flatten() > 0)[0] - seed = np.random.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx])) + seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx])) dst = distance[seed] g = np.asarray(np.unravel_index(seed, label.shape)).transpose().tolist()[0] g[0] = dst[0] - pos_guidance.append(g) + if dimensions == 2 or dims == 3: + pos_guidance.append(g) + else: + pos_guidance.append([g[0], sid, g[-2], g[-1]]) return np.asarray([pos_guidance, 
[default_guidance] * len(pos_guidance)]) def __call__(self, data): - data[self.guidance] = self._apply(data[self.label]) + sid = data.get(self.sid) + sids = data.get(self.sids) + if sids is not None: + if sid is None or sid not in sids: + sid = self.R.choice(sids, replace=False) + else: + sid = None + data[self.guidance] = self._apply(data[self.label], sid) return data class AddGuidanceSignald(Transform): - def __init__(self, image="image", guidance="guidance", sigma=2, dimensions=2, number_intensity_ch=1, batched=False): + def __init__(self, image="image", guidance="guidance", sigma=2, number_intensity_ch=1, batched=False): self.image = image self.guidance = guidance self.sigma = sigma - self.dimensions = dimensions self.number_intensity_ch = number_intensity_ch self.batched = batched def _get_signal(self, image, guidance): + dimensions = 3 if len(image.shape) > 3 else 2 guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance - if self.dimensions == 3: + if dimensions == 3: signal = np.zeros((len(guidance), image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32) else: signal = np.zeros((len(guidance), image.shape[-2], image.shape[-1]), dtype=np.float32) @@ -93,7 +131,7 @@ def _get_signal(self, image, guidance): if np.any(np.asarray(point) < 0): continue - if self.dimensions == 3: + if dimensions == 3: p1 = max(0, min(int(point[-3]), sshape[-3] - 1)) p2 = max(0, min(int(point[-2]), sshape[-2] - 1)) p3 = max(0, min(int(point[-1]), sshape[-1] - 1)) @@ -163,13 +201,10 @@ def __call__(self, data): class AddRandomGuidanced(Randomizable, Transform): - def __init__( - self, guidance="guidance", discrepancy="discrepancy", probability="probability", dimensions=2, batched=True - ): + def __init__(self, guidance="guidance", discrepancy="discrepancy", probability="probability", batched=True): self.guidance = guidance self.discrepancy = discrepancy self.probability = probability - self.dimensions = dimensions self.batched = batched def randomize(self, data=None): @@ -212,24 +247,22 @@ def add_guidance(discrepancy, probability): def _apply(self, guidance, discrepancy, probability): guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance - default_guidance = [-1] * (self.dimensions + 1) - if not self.batched: pos, neg = self.add_guidance(discrepancy, probability) if pos: guidance[0].append(pos) - guidance[1].append(default_guidance) + guidance[1].append([-1] * len(pos)) if neg: - guidance[0].append(default_guidance) + guidance[0].append([-1] * len(neg)) guidance[1].append(neg) else: for g, d, p in zip(guidance, discrepancy, probability): pos, neg = self.add_guidance(d, p) if pos: g[0].append(pos) - g[1].append(default_guidance) + g[1].append([-1] * len(pos)) if neg: - g[0].append(default_guidance) + g[0].append([-1] * len(neg)) g[1].append(neg) return np.asarray(guidance) @@ -343,12 +376,17 @@ def bounding_box(self, points, img_shape): def __call__(self, data): guidance = data[self.guidance] + box_start = None for key in self.keys: box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), data[key].shape[1:]) center = np.mean([box_start, box_end], axis=0).astype(int).tolist() spatial_size = data.get(self.spatial_size_key, self.spatial_size) current_size = np.subtract(box_start, box_end).astype(int).tolist() + spatial_size = spatial_size[-len(current_size) :] + if len(spatial_size) < len(current_size): # 3D spatial_size = [256,256] (include all slices in such case) + diff = len(current_size) - len(spatial_size) + spatial_size = 
list(data[key].shape[1 : (1 + diff)]) + spatial_size if np.all(np.less(current_size, self.spatial_size)): cropper = SpatialCrop(roi_center=center, roi_size=spatial_size) else: @@ -491,7 +529,6 @@ def __init__( background="background", axis=0, channel_first=True, - dimensions=2, slice_key="slice", meta_key_postfix: str = "meta_dict", ): @@ -501,19 +538,18 @@ def __init__( self.background = background self.axis = axis self.channel_first = channel_first - self.dimensions = dimensions self.slice_key = slice_key self.meta_key_postfix = meta_key_postfix def randomize(self, data=None): pass - def _apply(self, pos_clicks, neg_clicks, factor, slice_num=None): + def _apply(self, dimensions, pos_clicks, neg_clicks, factor, slice_num=None): points = pos_clicks points.extend(neg_clicks) points = np.array(points) - if self.dimensions == 2: + if dimensions == 2: slices = np.unique(points[:, self.axis]).tolist() slice_idx = slices[0] if slice_num is None else next(x for x in slices if x == slice_num) @@ -539,6 +575,7 @@ def __call__(self, data): meta_dict = data[f"{self.ref_image}_{self.meta_key_postfix}"] original_shape = meta_dict["spatial_shape"] current_shape = list(data[self.ref_image].shape) + dimensions = 3 if len(current_shape) >= 3 else 2 clicks = [data[self.foreground], data[self.background]] if self.channel_first: @@ -551,7 +588,7 @@ def __call__(self, data): factor = np.array(current_shape) / original_shape - data[self.guidance] = self._apply(clicks[0], clicks[1], factor, data.get(self.slice_key)) + data[self.guidance] = self._apply(dimensions, clicks[0], clicks[1], factor, data.get(self.slice_key)) return data From bf4d031365672914674220031398cf286744d397 Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Sat, 9 Jan 2021 12:48:48 -0800 Subject: [PATCH 08/20] fix 2D transform Signed-off-by: Sachidanand Alle --- monai/apps/deepgrow/transforms.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 92b23d24c1..8a104b2d43 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -45,7 +45,9 @@ def _apply(self, label): return np.asarray(sids) def __call__(self, data): - data[self.sids] = self._apply(data[self.label]) + sids = self._apply(data[self.label]) + if sids is not None and len(sids): + data[self.sids] = sids return data @@ -382,12 +384,15 @@ def __call__(self, data): center = np.mean([box_start, box_end], axis=0).astype(int).tolist() spatial_size = data.get(self.spatial_size_key, self.spatial_size) - current_size = np.subtract(box_start, box_end).astype(int).tolist() + current_size = np.absolute(np.subtract(box_start, box_end)).astype(int).tolist() spatial_size = spatial_size[-len(current_size) :] if len(spatial_size) < len(current_size): # 3D spatial_size = [256,256] (include all slices in such case) diff = len(current_size) - len(spatial_size) spatial_size = list(data[key].shape[1 : (1 + diff)]) + spatial_size - if np.all(np.less(current_size, self.spatial_size)): + + if np.all(np.less(current_size, spatial_size)): + if len(center) == 3: + center[0] = center[0] + (spatial_size[0] // 2 - center[0]) cropper = SpatialCrop(roi_center=center, roi_size=spatial_size) else: cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) @@ -529,6 +534,7 @@ def __init__( background="background", axis=0, channel_first=True, + dimensions=2, slice_key="slice", meta_key_postfix: str = "meta_dict", ): @@ -538,18 +544,19 @@ def __init__( self.background = 
background self.axis = axis self.channel_first = channel_first + self.dimensions = dimensions self.slice_key = slice_key self.meta_key_postfix = meta_key_postfix def randomize(self, data=None): pass - def _apply(self, dimensions, pos_clicks, neg_clicks, factor, slice_num=None): + def _apply(self, pos_clicks, neg_clicks, factor, slice_num=None): points = pos_clicks points.extend(neg_clicks) points = np.array(points) - if dimensions == 2: + if self.dimensions == 2: slices = np.unique(points[:, self.axis]).tolist() slice_idx = slices[0] if slice_num is None else next(x for x in slices if x == slice_num) @@ -575,7 +582,6 @@ def __call__(self, data): meta_dict = data[f"{self.ref_image}_{self.meta_key_postfix}"] original_shape = meta_dict["spatial_shape"] current_shape = list(data[self.ref_image].shape) - dimensions = 3 if len(current_shape) >= 3 else 2 clicks = [data[self.foreground], data[self.background]] if self.channel_first: @@ -587,8 +593,7 @@ def __call__(self, data): clicks[i][j] = np.roll(clicks[i][j], 1).tolist() factor = np.array(current_shape) / original_shape - - data[self.guidance] = self._apply(dimensions, clicks[0], clicks[1], factor, data.get(self.slice_key)) + data[self.guidance] = self._apply(clicks[0], clicks[1], factor, data.get(self.slice_key)) return data @@ -600,7 +605,7 @@ def __init__(self, keys, guidance="guidance", axis=0, meta_key_postfix: str = "m self.meta_key_postfix = meta_key_postfix def _apply(self, image, guidance): - slice_idx = guidance[2] + slice_idx = guidance[2] # (pos, neg, slice_idx, factor) idx = [] for i in range(len(image.shape)): idx.append(slice_idx) if i == self.axis else idx.append(slice(0, image.shape[i])) From a2c0c8941b4cbe02ae52b32ac333781b1b839cf7 Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Fri, 15 Jan 2021 10:10:41 -0800 Subject: [PATCH 09/20] Fix ci build Signed-off-by: Sachidanand Alle --- monai/apps/deepgrow/dataset.py | 5 +++-- monai/apps/deepgrow/transforms.py | 11 ++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py index b714184fcf..dc511030b8 100644 --- a/monai/apps/deepgrow/dataset.py +++ b/monai/apps/deepgrow/dataset.py @@ -18,7 +18,8 @@ import numpy as np from monai.apps.datasets import DecathlonDataset -from monai.transforms import AsChannelFirstd, Compose, GridSampleMode, LoadNiftid, Orientationd, Spacingd +from monai.transforms import AsChannelFirstd, Compose, LoadImaged, Orientationd, Spacingd +from monai.utils import GridSampleMode # TODO:: Test basic functionality @@ -80,7 +81,7 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: def _get_transforms(keys, pixdim): mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST] if len(keys) == 2 else [GridSampleMode.BILINEAR] transforms = [ - LoadNiftid(keys=keys), + LoadImaged(keys=keys), AsChannelFirstd(keys=keys), Spacingd(keys=keys, pixdim=pixdim, mode=mode), Orientationd(keys=keys, axcodes="RAS"), diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 8a104b2d43..2bfe63a937 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -13,15 +13,16 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ import json -from typing import Optional, Union +from typing import Optional, Sequence, Union import numpy as np from monai.config import KeysCollection -from monai.transforms import InterpolateMode, InterpolateModeSequence, Resize, SpatialCrop +from monai.transforms import Resize, SpatialCrop from 
monai.transforms.compose import MapTransform, Randomizable, Transform +from monai.transforms.spatial.dictionary import InterpolateModeSequence from monai.transforms.utils import generate_spatial_bounding_box -from monai.utils import Sequence, ensure_tuple_rep, min_version, optional_import +from monai.utils import InterpolateMode, ensure_tuple_rep, min_version, optional_import measure, _ = optional_import("skimage.measure", "0.14.2", min_version) distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") @@ -295,7 +296,7 @@ def __init__( super().__init__(keys) self.source_key = source_key - self.spatial_size = spatial_size + self.spatial_size = list(spatial_size) self.select_fn = select_fn self.channel_indices = channel_indices self.margin = margin @@ -350,7 +351,7 @@ def __init__( super().__init__(keys) self.guidance = guidance - self.spatial_size = spatial_size + self.spatial_size = list(spatial_size) self.spatial_size_key = spatial_size_key self.margin = margin self.meta_key_postfix = meta_key_postfix From cad793ab0458a01813d1e64f828f04d31ce2d83c Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Mon, 8 Feb 2021 13:15:07 -0800 Subject: [PATCH 10/20] fix comments + add unit tests Signed-off-by: Sachidanand Alle --- monai/apps/deepgrow/dataset.py | 206 ++++++++++----------- monai/apps/deepgrow/handler.py | 287 ----------------------------- monai/apps/deepgrow/interaction.py | 9 +- monai/apps/deepgrow/transforms.py | 187 ++++++++++++++++++- tests/test_deepgrow_dataset.py | 53 ++++++ tests/test_deepgrow_interaction.py | 37 ++++ tests/test_deepgrow_transforms.py | 95 ++++++++++ 7 files changed, 468 insertions(+), 406 deletions(-) delete mode 100644 monai/apps/deepgrow/handler.py create mode 100644 tests/test_deepgrow_dataset.py create mode 100644 tests/test_deepgrow_interaction.py create mode 100644 tests/test_deepgrow_transforms.py diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py index dc511030b8..96a88b2a7f 100644 --- a/monai/apps/deepgrow/dataset.py +++ b/monai/apps/deepgrow/dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,76 +9,112 @@ # See the License for the specific language governing permissions and # limitations under the License. 
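The hunk below replaces the DeepgrowDataset subclass with a standalone create_dataset() utility. As a quick orientation, a minimal sketch of the intended call, assuming only the signature introduced in this patch (file paths and the output directory are placeholders)::

    from monai.apps.deepgrow.dataset import create_dataset

    # hypothetical 3D volumes; with dimension=2 each volume is split into
    # one record per slice that contains foreground (see the unit tests below)
    datalist = [{"image": "img1.nii.gz", "label": "label1.nii.gz"}]
    deepgrow_datalist = create_dataset(
        datalist=datalist,
        output_dir="./deepgrow/2d",  # pre-processed slice data is written here
        dimension=2,
        pixdim=(1.0, 1.0),
    )
    # each record references a single slice and carries a "region" label id,
    # e.g. {"image": "...", "label": "...", "region": 1}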
-import json
 import logging
 import os
-import sys
-from typing import Callable, Dict, List, Sequence, Union
+from typing import Dict, List
 
 import numpy as np
 
-from monai.apps.datasets import DecathlonDataset
 from monai.transforms import AsChannelFirstd, Compose, LoadImaged, Orientationd, Spacingd
 from monai.utils import GridSampleMode
 
 
-# TODO:: Test basic functionality
-# TODO:: Unit Test
-class DeepgrowDataset(DecathlonDataset):
-    def __init__(
-        self,
-        dimension: int,
-        pixdim: Sequence[float],
-        root_dir: str,
-        task: str,
-        section: str,
-        transform: Union[Sequence[Callable], Callable] = (),
-        download: bool = False,
-        seed: int = 0,
-        val_frac: float = 0.2,
-        cache_num: int = sys.maxsize,
-        cache_rate: float = 1.0,
-        num_workers: int = 0,
-        limit: int = 0,
-    ) -> None:
-        self.dimension = dimension
-        self.pixdim = pixdim
-        self.limit = limit
-
-        super().__init__(
-            root_dir=root_dir,
-            task=task,
-            section=section,
-            transform=transform,
-            download=download,
-            seed=seed,
-            val_frac=val_frac,
-            cache_num=cache_num,
-            cache_rate=cache_rate,
-            num_workers=num_workers,
+def create_dataset(
+    datalist,
+    output_dir,
+    dimension,
+    pixdim,
+    keys=("image", "label"),
+    base_dir=None,
+    limit=0,
+    relative_path=False,
+    transforms=None,
+) -> List[Dict]:
+    """
+    Utility to pre-process and create dataset list for Deepgrow training over on existing one.
+    The input data list is normally a list of images and labels (3D volume) which needs pre-processing for Deepgrow
+    training pipeline.
+
+    Args:
+        datalist: A generic dataset with a length property which normally contains a list of data dictionary.
+            For example, typical input data can be a list of dictionaries::
+                [{'image': 'img1.nii', 'label': 'label1.nii'}]
+
+        output_dir: target directory to store the training data for Deepgrow Training
+        pixdim: output voxel spacing.
+        dimension: dimension for Deepgrow training. It can be 2 or 3.
+        keys: Image and Label keys in input datalist. Defaults to 'image' and 'label'.
+        base_dir: base directory in case relative paths are used for the keys in datalist. Defaults to None.
+        limit: limit number of inputs for pre-processing. Defaults to 0 (no limit).
+        relative_path: whether output key values should be based on relative paths. Defaults to False.
+        transforms: explicit transforms to execute operations on input data.
+
+    Raises:
+        ValueError: When ``dimension`` is not one of [2, 3]
+        ValueError: When ``datalist`` is empty
+
+    Example::
+
+        datalist = create_dataset(
+            datalist=[{'image': 'img1.nii', 'label': 'label1.nii'}],
+            base_dir=None,
+            output_dir=output_2d,
+            dimension=2,
+            keys=('image', 'label'),
+            pixdim=(1.0, 1.0),
+            limit=0,
+            relative_path=True,
+        )
+
+        print(datalist[0]["image"], datalist[0]["label"])
+    """
+
+    if dimension not in [2, 3]:
+        raise ValueError("Dimension can be only 2 or 3 as Deepgrow supports only 2D/3D Training")
+    if not len(datalist):
+        raise ValueError("Input Datalist is empty")
+
+    if not isinstance(keys, list) and not isinstance(keys, tuple):
+        keys = [keys]
+
+    transforms = _default_transforms(keys, pixdim) if transforms is None else transforms
+    new_datalist = []
+    for idx in range(len(datalist)):
+        if limit and idx >= limit:
+            break
 
-def _get_transforms(keys, pixdim):
+        image = datalist[idx][keys[0]]
+        label = datalist[idx].get(keys[1]) if len(keys) > 1 else None
+        if base_dir:
+            image = os.path.join(base_dir, image)
+            label = os.path.join(base_dir, label) if label else None
+
+        image = os.path.abspath(image)
+        label = os.path.abspath(label) if label else None
+
+        logging.info("Image: {}; Label: {}".format(image, label if label else None))
+        if dimension == 2:
+            data = _save_data_2d(
+                vol_idx=idx,
+                data=transforms({"image": image, "label": label}),
+                keys=("image", "label"),
+                dataset_dir=output_dir,
+                relative_path=relative_path,
+            )
+        else:
+            data = _save_data_3d(
+                vol_idx=idx,
+                data=transforms({"image": image, "label": label}),
+                keys=("image", "label"),
+                dataset_dir=output_dir,
+                relative_path=relative_path,
+            )
+        new_datalist.extend(data)
+    return new_datalist
+
+
+def _default_transforms(keys, pixdim):
     mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST] if len(keys) == 2 else [GridSampleMode.BILINEAR]
     transforms = [
         LoadImaged(keys=keys),
@@ -96,7 +132,11 @@ def _save_data_2d(vol_idx, data, keys, dataset_dir, relative_path):
     data_list = []
 
     if len(vol_image.shape) == 4:
-        logging.info("4D-Image, pick only first series; Image: {}; Label: {}".format(vol_image.shape, vol_label.shape))
+        logging.info(
+            "4D-Image, pick only first series; Image: {}; Label: {}".format(
+                vol_image.shape, vol_label.shape if vol_label is not None else None
+            )
+        )
         vol_image = vol_image[0]
         vol_image = np.moveaxis(vol_image, -1, 0)
 
@@ -150,8 +190,8 @@ def _save_data_2d(vol_idx, data, keys, dataset_dir, relative_path):
             }
         )
 
-    print(
-        "{} => Image: {} => {}; Label: {} => {}; Unique Labels: {}".format(
+    logging.info(
+        "{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format(
             vol_idx,
             vol_image.shape,
             image_count,
@@ -216,8 +256,8 @@ def _save_data_3d(vol_idx, data, keys, dataset_dir, relative_path):
             }
         )
 
-    print(
-        "{} => Image: {} => {}; Label: {} => {}; Unique Labels: {}".format(
+    logging.info(
+        "{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format(
            vol_idx,
            vol_image.shape,
            image_count,
@@ -227,45 +267,3 @@ def _save_data_3d(vol_idx, data, keys, dataset_dir, relative_path): ) ) return data_list - - -def create_dataset( - datalist, output_dir, dimension, pixdim, keys=("image", "label"), base_dir=None, limit=0, relative_path=False -) -> List[Dict]: - if not isinstance(keys, list) and not isinstance(keys, tuple): - keys = [keys] - - transforms = _get_transforms(keys, pixdim) - new_datalist = [] - for idx in range(len(datalist)): - if limit and idx >= limit: - break - - image = datalist[idx][keys[0]] - label = datalist[idx].get(keys[1]) if len(keys) > 1 else None - if base_dir: - image = os.path.join(base_dir, image) - label = os.path.join(base_dir, label) if label else None - - image = os.path.abspath(image) - label = os.path.abspath(label) if label else None - - print("{} => {}".format(image, label if label else None)) - if dimension == 2: - data = _save_data_2d( - vol_idx=idx, - data=transforms({"image": image, "label": label}), - keys=("image", "label"), - dataset_dir=output_dir, - relative_path=relative_path, - ) - else: - data = _save_data_3d( - vol_idx=idx, - data=transforms({"image": image, "label": label}), - keys=("image", "label"), - dataset_dir=output_dir, - relative_path=relative_path, - ) - new_datalist.extend(data) - return new_datalist diff --git a/monai/apps/deepgrow/handler.py b/monai/apps/deepgrow/handler.py deleted file mode 100644 index 185c5b1d3d..0000000000 --- a/monai/apps/deepgrow/handler.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -import os -import statistics - -import numpy as np -import torch -import torch.distributed - -from monai.engines.workflow import Engine, Events -from monai.handlers.tensorboard_handlers import SummaryWriter -from monai.metrics import compute_meandice -from monai.transforms import rescale_array -from monai.utils import optional_import -from monai.visualize import plot_2d_or_3d_image - -nib, _ = optional_import("nibabel") -torchvision, _ = optional_import("torchvision") -make_grid, _ = optional_import("torchvision.utils", name="make_grid") -Image, _ = optional_import("PIL.Image") -ImageDraw, _ = optional_import("PIL.ImageDraw") - - -# TODO:: Unit Test - - -class MeanDice: - def __init__(self): - self.data = [] - - def reset(self): - self.data = [] - - def update(self, y_pred, y, batched=True): - if not batched: - y_pred = y_pred[None] - y = y[None] - score = compute_meandice(y_pred=y_pred, y=y, include_background=False).mean() - self.data.append(score.item()) - - def mean(self): - return statistics.mean(self.data) - - def stdev(self): - return statistics.stdev(self.data) if len(self.data) > 1 else 0 - - -class DeepgrowStatsHandler: - def __init__( - self, - summary_writer=None, - interval=1, - log_dir="./runs", - tag_name="val_dice", - compute_metric=True, - images=True, - image_interval=1, - max_channels=1, - max_frames=64, - add_scalar=True, - merge_scalar=False, - fold_size=0, - ): - self.writer = SummaryWriter(log_dir=log_dir) if summary_writer is None else summary_writer - self.interval = interval - self.tag_name = tag_name - self.compute_metric = compute_metric - self.images = images - self.image_interval = image_interval - self.max_channels = max_channels - self.max_frames = max_frames - self.add_scalar = add_scalar - self.merge_scalar = merge_scalar - self.fold_size = fold_size - self.logger = logging.getLogger(__name__) - - if torch.distributed.is_initialized(): - self.tag_name = "{}-r{}".format(self.tag_name, torch.distributed.get_rank()) - - self.plot_data = {} - self.metric_data = {} - - def attach(self, engine: Engine) -> None: - engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self, "iteration") - engine.add_event_handler(Events.EPOCH_COMPLETED(every=1), self, "epoch") - - def write_images(self, epoch): - if not self.plot_data or not len(self.plot_data): - return - - all_imgs = [] - for region in sorted(self.plot_data.keys()): - metric = self.metric_data.get(region) - region_data = self.plot_data[region] - if len(region_data[0].shape) == 3: - ti = Image.new("RGB", region_data[0].shape[1:]) - d = ImageDraw.Draw(ti) - t = "region: {}".format(region) - if self.compute_metric: - t = t + "\ndice: {:.4f}".format(metric.mean()) - t = t + "\nstdev: {:.4f}".format(metric.stdev()) - d.multiline_text((10, 10), t, fill=(255, 255, 0)) - ti = rescale_array(np.rollaxis(np.array(ti), 2, 0)[0][np.newaxis]) - all_imgs.append(ti) - all_imgs.extend(region_data) - - if len(all_imgs[0].shape) == 3: - img_tensor = make_grid(tensor=torch.from_numpy(np.array(all_imgs)), nrow=4, normalize=True, pad_value=2) - self.writer.add_image(tag=f"Deepgrow Regions ({self.tag_name})", img_tensor=img_tensor, global_step=epoch) - - if len(all_imgs[0].shape) == 4: - for region in sorted(self.plot_data.keys()): - tags = [f"region_{region}_image", f"region_{region}_label", f"region_{region}_output"] - if torch.distributed.is_initialized(): - rank = "r{}-".format(torch.distributed.get_rank()) - tags = [rank + tags[0], rank + tags[1], rank + tags[2]] - for i in range(3): - img = 
self.plot_data[region][i] - img = np.moveaxis(img, -3, -1) - plot_2d_or_3d_image( - img[np.newaxis], epoch, self.writer, 0, self.max_channels, self.max_frames, tags[i] - ) - - self.logger.info( - "Saved {} Regions {} into Tensorboard at epoch: {}".format( - len(self.plot_data), sorted([*self.plot_data]), epoch - ) - ) - self.writer.flush() - - def write_region_metrics(self, epoch): - metric_sum = 0 - means = {} - for region in self.metric_data: - metric = self.metric_data[region].mean() - self.logger.info( - "Epoch[{}] Metrics -- Region: {:0>2d}, {}: {:.4f}".format(epoch, region, self.tag_name, metric) - ) - - if self.merge_scalar: - means["{:0>2d}".format(region)] = metric - else: - self.writer.add_scalar("{}_{:0>2d}".format(self.tag_name, region), metric, epoch) - metric_sum += metric - - if self.merge_scalar: - means["avg"] = metric_sum / len(self.metric_data) - self.writer.add_scalars("{}_region".format(self.tag_name), means, epoch) - elif len(self.metric_data) > 1: - metric_avg = metric_sum / len(self.metric_data) - self.writer.add_scalar("{}_regions_avg".format(self.tag_name), metric_avg, epoch) - self.writer.flush() - - def __call__(self, engine: Engine, action) -> None: - total_steps = engine.state.iteration - if total_steps < engine.state.epoch_length: - total_steps = engine.state.epoch_length * (engine.state.epoch - 1) + total_steps - - if action == "epoch" and not self.fold_size: - epoch = engine.state.epoch - elif self.fold_size and total_steps % self.fold_size == 0: - epoch = int(total_steps / self.fold_size) - else: - epoch = None - - if epoch: - if self.images and epoch % self.image_interval == 0: - self.write_images(epoch) - if self.add_scalar: - self.write_region_metrics(epoch) - - if action == "epoch" or epoch: - self.plot_data = {} - self.metric_data = {} - return - - device = engine.state.device - batch_data = engine.state.batch - output_data = engine.state.output - - for bidx in range(len(batch_data.get("region", []))): - region = batch_data.get("region")[bidx] - region = region.item() if torch.is_tensor(region) else region - - if self.images and self.plot_data.get(region) is None: - self.plot_data[region] = [ - rescale_array(batch_data["image"][bidx][0].detach().cpu().numpy()[np.newaxis], 0, 1), - rescale_array(batch_data["label"][bidx].detach().cpu().numpy(), 0, 1), - rescale_array(output_data["pred"][bidx].detach().cpu().numpy(), 0, 1), - ] - - if self.compute_metric: - if self.metric_data.get(region) is None: - self.metric_data[region] = MeanDice() - self.metric_data[region].update( - y_pred=output_data["pred"][bidx].to(device), y=batch_data["label"][bidx].to(device), batched=False - ) - - -class SegmentationSaver: - def __init__( - self, - output_dir: str = "./runs", - save_np=False, - images=True, - ): - self.output_dir = output_dir - self.save_np = save_np - self.images = images - os.makedirs(self.output_dir, exist_ok=True) - - def attach(self, engine: Engine) -> None: - if not engine.has_event_handler(self, Events.ITERATION_COMPLETED): - engine.add_event_handler(Events.ITERATION_COMPLETED, self) - - def __call__(self, engine: Engine): - batch_data = engine.state.batch - output_data = engine.state.output - device = engine.state.device - tag = "" - if torch.distributed.is_initialized(): - tag = "r{}-".format(torch.distributed.get_rank()) - - for bidx in range(len(batch_data.get("image"))): - step = engine.state.iteration - region = batch_data.get("region")[bidx] - region = region.item() if torch.is_tensor(region) else region - - image = 
batch_data["image"][bidx][0].detach().cpu().numpy()[np.newaxis] - label = batch_data["label"][bidx].detach().cpu().numpy() - pred = output_data["pred"][bidx].detach().cpu().numpy() - dice = compute_meandice( - y_pred=output_data["pred"][bidx][None].to(device), - y=batch_data["label"][bidx][None].to(device), - include_background=False, - ).mean() - - if self.save_np: - np.savez( - os.path.join( - self.output_dir, - "{}img_label_pred_{}_{:0>4d}_{:0>2d}_{:.4f}".format(tag, region, step, bidx, dice), - ), - image, - label, - pred, - ) - - if self.images and len(image.shape) == 3: - img = make_grid(torch.from_numpy(rescale_array(image, 0, 1)[0])) - lab = make_grid(torch.from_numpy(rescale_array(label, 0, 1)[0])) - - pos = rescale_array(output_data["image"][bidx][1].detach().cpu().numpy()[np.newaxis], 0, 1)[0] - neg = rescale_array(output_data["image"][bidx][2].detach().cpu().numpy()[np.newaxis], 0, 1)[0] - pre = make_grid(torch.from_numpy(np.array([rescale_array(pred, 0, 1)[0], pos, neg]))) - - torchvision.utils.save_image( - tensor=[img, lab, pre], - nrow=3, - pad_value=2, - fp=os.path.join( - self.output_dir, - "{}img_label_pred_{}_{:0>4d}_{:0>2d}_{:.4f}.png".format(tag, region, step, bidx, dice), - ), - ) - - if self.images and len(image.shape) == 4: - samples = {"image": image[0], "label": label[0], "pred": pred[0]} - for sample in samples: - img = np.moveaxis(samples[sample], -3, -1) - img = nib.Nifti1Image(img, np.eye(4)) - nib.save( - img, - os.path.join( - self.output_dir, "{}{}_{:0>4d}_{:0>2d}_{:.4f}.nii.gz".format(tag, sample, step, bidx, dice) - ), - ) diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index 32cc9eb949..0b10cb6588 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,8 +17,6 @@ from monai.engines.workflow import Engine, Events from monai.transforms import Compose -# TODO:: Unit Test - class Interaction: """ @@ -40,7 +38,7 @@ def __init__(self, transforms, max_interactions: int, train: bool, key_probabili if not isinstance(self.transforms, Compose): transforms = [] - for t in transforms: + for t in self.transforms: transforms.append(self.init_external_class(t)) self.transforms = Compose(transforms) @@ -55,7 +53,8 @@ def init_external_class(config_dict): return c(**class_args) if class_args else c() def attach(self, engine: Engine) -> None: - engine.add_event_handler(Events.ITERATION_STARTED, self) + if not engine.has_event_handler(self, Events.ITERATION_STARTED): + engine.add_event_handler(Events.ITERATION_STARTED, self) def __call__(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): if batchdata is None: diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 2bfe63a937..a5519ff47a 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -30,7 +30,16 @@ class FindAllValidSlicesd(Transform): + """ + Find/List all valid slices in the label. 
Label is assumed to be a 3D Volume with channel_first axis + """ + def __init__(self, label="label", sids="sids"): + """ + Args: + label: label source. + sids: key to store slices indices having valid label map. + """ self.label = label self.sids = sids @@ -53,7 +62,19 @@ def __call__(self, data): class AddInitialSeedPointd(Randomizable, Transform): + """ + Add Random guidance as initial seed point for a given label. + """ + def __init__(self, label="label", guidance="guidance", sids="sids", sid="sid", connected_regions=6): + """ + Args: + label: label source. + guidance: key to store guidance. + sids: key that represents list of valid slice indices for the given label. + sid: key that represents sid for which initial seed point. If not present, random sid will be chosen. + connected_regions: maximum connected regions to use for adding initial points. + """ self.label = label self.sids = sids self.sid = sid @@ -113,7 +134,19 @@ def __call__(self, data): class AddGuidanceSignald(Transform): + """ + Add Guidance signal for input image. + """ + def __init__(self, image="image", guidance="guidance", sigma=2, number_intensity_ch=1, batched=False): + """ + Args: + image: image source. + guidance: key to store guidance. + sigma: standard deviation for Gaussian kernel. + number_intensity_ch: channel index. + batched: defines if input is batched. + """ self.image = image self.guidance = guidance self.sigma = sigma @@ -170,7 +203,18 @@ def __call__(self, data): class FindDiscrepancyRegionsd(Transform): + """ + Find discrepancy between prediction and actual during click interactions during training. + """ + def __init__(self, label="label", pred="pred", discrepancy="discrepancy", batched=True): + """ + Args: + label: label source. + pred: prediction source. + discrepancy: key to store discrepancies found between label and prediction. + batched: defines if input is batched. + """ self.label = label self.pred = pred self.discrepancy = discrepancy @@ -204,7 +248,18 @@ def __call__(self, data): class AddRandomGuidanced(Randomizable, Transform): + """ + Add random guidance based on discrepancies that were found between label and prediction. + """ + def __init__(self, guidance="guidance", discrepancy="discrepancy", probability="probability", batched=True): + """ + Args: + guidance: guidance source. + discrepancy: key that represents discrepancies found between label and prediction. + probability: key that represents click/interaction probability. + batched: defines if input is batched. 
+        """
         self.guidance = guidance
         self.discrepancy = discrepancy
         self.probability = probability
@@ -213,14 +268,13 @@ def __init__(self, guidance="guidance", discrepancy="discrepancy", probability="
     def randomize(self, data=None):
         pass
 
-    @staticmethod
-    def find_guidance(discrepancy):
+    def find_guidance(self, discrepancy):
         distance = distance_transform_cdt(discrepancy).flatten()
         probability = np.exp(distance) - 1.0
         idx = np.where(discrepancy.flatten() > 0)[0]
 
         if np.sum(discrepancy > 0) > 0:
-            seed = np.random.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx]))
+            seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx]))
             dst = distance[seed]
 
             g = np.asarray(np.unravel_index(seed, discrepancy.shape)).transpose().tolist()[0]
@@ -228,9 +282,8 @@ def find_guidance(discrepancy):
             return g
         return None
 
-    @staticmethod
-    def add_guidance(discrepancy, probability):
-        will_interact = np.random.choice([True, False], p=[probability, 1.0 - probability])
+    def add_guidance(self, discrepancy, probability):
+        will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability])
         if not will_interact:
             return None, None
 
@@ -242,10 +295,10 @@ def add_guidance(discrepancy, probability):
         correct_pos = np.sum(pos_discr) >= np.sum(neg_discr)
 
         if correct_pos and can_be_positive:
-            return AddRandomGuidanced.find_guidance(pos_discr), None
+            return self.find_guidance(pos_discr), None
 
         if not correct_pos and can_be_negative:
-            return None, AddRandomGuidanced.find_guidance(neg_discr)
+            return None, self.find_guidance(neg_discr)
         return None, None
 
     def _apply(self, guidance, discrepancy, probability):
@@ -279,6 +332,10 @@ def __call__(self, data):
 
 
 class SpatialCropForegroundd(MapTransform):
+    """
+    Crop the foreground object of the expected images based on label that fits minimum spatial size.
+    """
+
     def __init__(
         self,
         keys,
@@ -293,6 +350,24 @@ def __init__(
         original_shape_key: str = "foreground_original_shape",
         cropped_shape_key: str = "foreground_cropped_shape",
     ) -> None:
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            source_key: data source to generate the bounding box of foreground, can be image or label, etc.
+            spatial_size: minimal spatial size of the image patch e.g. [128, 128, 128] to fit in.
+            select_fn: function to select expected foreground, default is to select values > 0.
+            channel_indices: if defined, select foreground only on the specified channels
+                of image. If None, select foreground on the whole image.
+            margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.
+            meta_key_postfix: use `key_{postfix}` to fetch the meta data according to the key data,
+                default is `meta_dict`, the meta data is a dictionary object.
+                For example, to handle key `image`, read/write affine matrices from the
+                metadata `image_meta_dict` dictionary's `affine` field.
+            start_coord_key: key to record the start coordinate of spatial bounding box for foreground.
+            end_coord_key: key to record the end coordinate of spatial bounding box for foreground.
+            original_shape_key: key to record original shape for foreground.
+            cropped_shape_key: key to record cropped shape for foreground.
+        """
         super().__init__(keys)
 
         self.source_key = source_key
@@ -333,8 +408,12 @@ def __call__(self, data):
         return data
 
 
-# Transforms to support Inference
+# Transforms to support Inference for deepgrow models
class SpatialCropGuidanced(MapTransform):
+    """
+    Crop image based on user guidance/clicks with minimal spatial size.
+    """
+
     def __init__(
         self,
         keys,
@@ -348,6 +427,21 @@ def __init__(
         original_shape_key: str = "foreground_original_shape",
         cropped_shape_key: str = "foreground_cropped_shape",
     ) -> None:
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            guidance: user input clicks/guidance used to generate the bounding box of foreground.
+            spatial_size: minimal spatial size of the image patch e.g. [128, 128, 128] to fit in.
+            margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.
+            meta_key_postfix: use `key_{postfix}` to fetch the meta data according to the key data,
+                default is `meta_dict`, the meta data is a dictionary object.
+                For example, to handle key `image`, read/write affine matrices from the
+                metadata `image_meta_dict` dictionary's `affine` field.
+            start_coord_key: key to record the start coordinate of spatial bounding box for foreground.
+            end_coord_key: key to record the end coordinate of spatial bounding box for foreground.
+            original_shape_key: key to record original shape for foreground.
+            cropped_shape_key: key to record cropped shape for foreground.
+        """
         super().__init__(keys)
 
         self.guidance = guidance
@@ -417,6 +511,10 @@ def __call__(self, data):
 
 
 class ResizeGuidanced(Transform):
+    """
+    Resize/re-scale user click/guidance based on original vs resized/cropped image.
+    """
+
     def __init__(
         self,
         guidance: str,
@@ -424,6 +522,16 @@ def __init__(
         meta_key_postfix="meta_dict",
         cropped_shape_key: str = "foreground_cropped_shape",
     ) -> None:
+        """
+        Args:
+            guidance: user input clicks/guidance used to generate the bounding box of foreground.
+            ref_image: reference image to fetch current and original image details.
+            meta_key_postfix: use `key_{postfix}` to fetch the meta data according to the key data,
+                default is `meta_dict`, the meta data is a dictionary object.
+                For example, to handle key `image`, read/write affine matrices from the
+                metadata `image_meta_dict` dictionary's `affine` field.
+            cropped_shape_key: key that records cropped shape for foreground.
+        """
         self.guidance = guidance
         self.ref_image = ref_image
         self.meta_key_postfix = meta_key_postfix
@@ -445,6 +553,10 @@ def __call__(self, data):
 
 
 class RestoreCroppedLabeld(MapTransform):
+    """
+    Restore/resize label based on original vs resized/cropped image.
+    """
+
     def __init__(
         self,
         keys: KeysCollection,
@@ -459,6 +571,28 @@ def __init__(
         original_shape_key: str = "foreground_original_shape",
         cropped_shape_key: str = "foreground_cropped_shape",
     ) -> None:
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            ref_image: reference image to fetch current and original image details.
+            slice_only: apply only to the applicable slice, in the case of a 2D model/prediction.
+            channel_first: if the channel is positioned first.
+            mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
+                ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
+                One of the listed string values or a user supplied function for padding. Defaults to ``"constant"``.
+                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
+            align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
+                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
+                It also can be a sequence of bool, each element corresponds to a key in ``keys``.
+            meta_key_postfix: use `key_{postfix}` to fetch the meta data according to the key data,
+                default is `meta_dict`, the meta data is a dictionary object.
+                For example, to handle key `image`, read/write affine matrices from the
+                metadata `image_meta_dict` dictionary's `affine` field.
+            start_coord_key: key to record the start coordinate of spatial bounding box for foreground.
+            end_coord_key: key to record the end coordinate of spatial bounding box for foreground.
+            original_shape_key: key to record original shape for foreground.
+            cropped_shape_key: key to record cropped shape for foreground.
+        """
         super().__init__(keys)
         self.ref_image = ref_image
         self.slice_only = slice_only
@@ -527,6 +661,10 @@ def __call__(self, data):
 
 
 class AddGuidanceFromPointsd(Randomizable, Transform):
+    """
+    Add guidance based on user clicks.
+    """
+
     def __init__(
         self,
         ref_image,
@@ -539,6 +677,21 @@ def __init__(
         slice_key="slice",
         meta_key_postfix: str = "meta_dict",
     ):
+        """
+        Args:
+            ref_image: reference image to fetch current and original image details.
+            guidance: output key to store guidance.
+            foreground: key that represents user foreground (+ve) clicks.
+            background: key that represents user background (-ve) clicks.
+            axis: axis that represents slice in 3D volume.
+            channel_first: if the channel is positioned first.
+            dimensions: dimensions based on model used for deepgrow (2D vs 3D).
+            slice_key: key that represents the applicable slice to add guidance.
+            meta_key_postfix: use `key_{postfix}` to fetch the meta data according to the key data,
+                default is `meta_dict`, the meta data is a dictionary object.
+                For example, to handle key `image`, read/write affine matrices from the
+                metadata `image_meta_dict` dictionary's `affine` field.
+        """
         self.ref_image = ref_image
         self.guidance = guidance
         self.foreground = foreground
@@ -599,7 +752,21 @@ def __call__(self, data):
 
 
 class Fetch2DSliced(MapTransform):
+    """
+    Fetch one slice in the case of a 3D volume.
+    """
+
     def __init__(self, keys, guidance="guidance", axis=0, meta_key_postfix: str = "meta_dict"):
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            guidance: key that represents guidance.
+            axis: axis that represents slice in 3D volume.
+            meta_key_postfix: use `key_{postfix}` to fetch the meta data according to the key data,
+                default is `meta_dict`, the meta data is a dictionary object.
+                For example, to handle key `image`, read/write affine matrices from the
+                metadata `image_meta_dict` dictionary's `affine` field.
+        """
         super().__init__(keys)
         self.guidance = guidance
         self.axis = axis
diff --git a/tests/test_deepgrow_dataset.py b/tests/test_deepgrow_dataset.py
new file mode 100644
index 0000000000..3b6969acd3
--- /dev/null
+++ b/tests/test_deepgrow_dataset.py
@@ -0,0 +1,53 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
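AddGuidanceFromPointsd maps user clicks from the original image grid onto the pre-processed grid via factor = np.array(current_shape) / original_shape, as shown in the earlier hunks. A worked example of that rescaling arithmetic with made-up shapes (how the factor is applied and rounded is the transform's concern; this only illustrates the scale)::

    import numpy as np

    original_shape = np.array([512, 512])  # "spatial_shape" from the meta dict
    current_shape = np.array([256, 256])   # shape after spacing/resize transforms
    factor = current_shape / original_shape  # -> [0.5, 0.5]

    click = np.array([300, 180])  # (row, col) recorded on the original grid
    scaled = click * factor  # -> [150.0, 90.0] on the current grid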
+ +import os +import tempfile +import unittest + +import nibabel as nib +import numpy as np + +from monai.apps.deepgrow.dataset import create_dataset + + +class TestCreateDataset(unittest.TestCase): + def _create_data(self, tempdir): + image = np.random.randint(0, 2, size=(4, 4, 4)) + image_file = os.path.join(tempdir, "image1.nii.gz") + nib.save(nib.Nifti1Image(image, np.eye(4)), image_file) + + label = np.random.randint(0, 1, size=(4, 4, 4)) + label[0][0][2] = 1 + label[0][1][2] = 1 + label[0][1][0] = 1 + label_file = os.path.join(tempdir, "label1.nii.gz") + nib.save(nib.Nifti1Image(label, np.eye(4)), label_file) + + return [{"image": image_file, "label": label_file}] + + def test_create_dataset_2d(self): + with tempfile.TemporaryDirectory() as tempdir: + datalist = self._create_data(tempdir) + output_dir = os.path.join(tempdir, "2d") + deepgrow_datalist = create_dataset(datalist=datalist, output_dir=output_dir, dimension=2, pixdim=(1, 1)) + assert len(deepgrow_datalist) == 2 and deepgrow_datalist[0]["region"] == 1 + + def test_create_dataset_3d(self): + with tempfile.TemporaryDirectory() as tempdir: + datalist = self._create_data(tempdir) + output_dir = os.path.join(tempdir, "3d") + deepgrow_datalist = create_dataset(datalist=datalist, output_dir=output_dir, dimension=3, pixdim=(1, 1, 1)) + assert len(deepgrow_datalist) == 1 and deepgrow_datalist[0]["region"] == 1 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_deepgrow_interaction.py b/tests/test_deepgrow_interaction.py new file mode 100644 index 0000000000..9b865fdfad --- /dev/null +++ b/tests/test_deepgrow_interaction.py @@ -0,0 +1,37 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.apps.deepgrow.interaction import Interaction + + +class TestInteractions(unittest.TestCase): + def test_interaction(self): + transforms = [ + { + "name": "Activationsd", + "path": "monai.transforms.Activationsd", + "args": {"keys": "pred", "sigmoid": True}, + }, + { + "name": "FindDiscrepancyRegionsd", + "path": "monai.apps.deepgrow.transforms.FindDiscrepancyRegionsd", + "args": {"label": "label", "pred": "pred", "discrepancy": "discrepancy", "batched": True}, + }, + ] + + i = Interaction(transforms=transforms, train=True, max_interactions=5) + assert len(i.transforms.transforms) == 2 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py new file mode 100644 index 0000000000..80b0be14a6 --- /dev/null +++ b/tests/test_deepgrow_transforms.py @@ -0,0 +1,95 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+from monai.apps.deepgrow.transforms import (
+    AddGuidanceFromPointsd,
+    AddGuidanceSignald,
+    AddInitialSeedPointd,
+    Fetch2DSliced,
+    FindAllValidSlicesd,
+    FindDiscrepancyRegionsd,
+    ResizeGuidanced,
+    RestoreCroppedLabeld,
+    SpatialCropForegroundd,
+    SpatialCropGuidanced,
+)
+from monai.transforms import AddChanneld
+
+DATA = {
+    "image": np.array([[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]),
+    "label": np.array([[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]),
+    "image_meta_dict": {},
+    "label_meta_dict": {},
+    "foreground": [[2, 2, 0]],
+    "background": [],
+}
+
+
+class TestTransforms(unittest.TestCase):
+    def test_findallvalidslicesd(self):
+        result = AddChanneld(keys=("image", "label"))(DATA.copy())
+        result = FindAllValidSlicesd()(result)
+        assert len(result["sids"]) == 1
+
+    def test_spatialcropforegroundd(self):
+        roi_size = [4, 4, 4]
+        result = AddChanneld(keys=("image", "label"))(DATA.copy())
+        result = SpatialCropForegroundd(keys=("image", "label"), source_key="label", spatial_size=roi_size)(result)
+        assert result["image"].shape == (1, 1, 4, 4)
+
+    def test_addinitialseedpointd_addguidancesignald(self):
+        result = AddChanneld(keys=("image", "label"))(DATA.copy())
+        result = AddInitialSeedPointd(label="label", guidance="guidance", sids="sids")(result)
+        assert len(result["guidance"])
+
+        result = AddGuidanceSignald(image="image", guidance="guidance")(result)
+        assert result["image"].shape == (3, 1, 5, 5)
+
+    def test_finddiscrepancyregionsd(self):
+        result = DATA.copy()
+        result["pred"] = np.array(
+            [[[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]
+        )
+        result = AddChanneld(keys=("image", "label", "pred"))(result)
+        result = FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy")(result)
+        assert np.sum(result["discrepancy"]) > 0
+
+    def test_inference(self):
+        result = DATA.copy()
+        result["image_meta_dict"]["spatial_shape"] = (5, 5, 1)
+        result["image_meta_dict"]["original_affine"] = (0, 0)
+
+        result = AddGuidanceFromPointsd(
+            ref_image="image", guidance="guidance", foreground="foreground", background="background", dimensions=2
+        )(result)
+        assert len(result["guidance"][0][0]) == 2
+
+        result = Fetch2DSliced(keys="image", guidance="guidance")(result)
+        assert result["image"].shape == (5, 5)
+
+        result = AddChanneld(keys="image")(result)
+
+        result = SpatialCropGuidanced(keys="image", guidance="guidance", spatial_size=(4, 4))(result)
+        assert result["image"].shape == (1, 4, 4)
+
+        result = ResizeGuidanced(guidance="guidance", ref_image="image")(result)
+
+        result["pred"] = np.random.randint(0, 2, size=(1, 4, 4))
+        result = RestoreCroppedLabeld(keys="pred", ref_image="image", mode="nearest")(result)
+        assert result["pred"].shape == (1, 5, 5)
+
+
+if __name__ == "__main__":
+    unittest.main()

From e45e4e04823d1cb8d75f693139305418c569c8fa Mon Sep 17 00:00:00 2001
From: Sachidanand Alle
Date: Mon, 8 Feb 2021 15:23:08 -0800
Subject: [PATCH 11/20] add deepgrow tests to exclude list
Signed-off-by: Sachidanand Alle --- tests/min_tests.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index 0fd6985067..b24da7dc21 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -104,6 +104,9 @@ def run_testsuit(): "test_handler_metrics_saver_dist", "test_evenly_divisible_all_gather_dist", "test_handler_classification_saver_dist", + "test_deepgrow_dataset", + "test_deepgrow_interaction", + "test_deepgrow_transforms", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From 7c641b83eb1daa0e830cff5c17af368934dd3aa9 Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Mon, 8 Feb 2021 17:34:13 -0800 Subject: [PATCH 12/20] fix docs Signed-off-by: Sachidanand Alle --- docs/source/apps.rst | 31 +++++++++++++++++++++++++++++++ monai/apps/deepgrow/__init__.py | 16 ++++++++++++++++ monai/apps/deepgrow/dataset.py | 5 +++-- monai/apps/deepgrow/transforms.py | 3 ++- 4 files changed, 52 insertions(+), 3 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index f2afd93836..dcbe1fc9af 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -28,3 +28,34 @@ Applications .. autofunction:: extractall .. autofunction:: download_and_extract + +`Deepgrow` +----------- + +`Data` +^^^^^^^^^^^^^^^^^^ +.. automodule:: monai.apps.deepgrow +.. autofunction:: create_dataset + +`Transforms` +^^^^^^^^^^^^^^ +.. autoclass:: AddInitialSeedPointd + :members: +.. autoclass:: AddGuidanceSignald + :members: +.. autoclass:: AddRandomGuidanced + :members: +.. autoclass:: AddGuidanceFromPointsd + :members: +.. autoclass:: SpatialCropForegroundd + :members: +.. autoclass:: SpatialCropGuidanced + :members: +.. autoclass:: RestoreCroppedLabeld + :members: +.. autoclass:: FindDiscrepancyRegionsd + :members: +.. autoclass:: FindAllValidSlicesd + :members: +.. autoclass:: Fetch2DSliced + :members: diff --git a/monai/apps/deepgrow/__init__.py b/monai/apps/deepgrow/__init__.py index d0044e3563..0b3b9200de 100644 --- a/monai/apps/deepgrow/__init__.py +++ b/monai/apps/deepgrow/__init__.py @@ -8,3 +8,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .dataset import create_dataset +from .interaction import Interaction +from .transforms import ( + AddGuidanceFromPointsd, + AddGuidanceSignald, + AddInitialSeedPointd, + AddRandomGuidanced, + Fetch2DSliced, + FindAllValidSlicesd, + FindDiscrepancyRegionsd, + ResizeGuidanced, + RestoreCroppedLabeld, + SpatialCropForegroundd, + SpatialCropGuidanced, +) diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py index 96a88b2a7f..849b44bed3 100644 --- a/monai/apps/deepgrow/dataset.py +++ b/monai/apps/deepgrow/dataset.py @@ -32,12 +32,13 @@ def create_dataset( ) -> List[Dict]: """ Utility to pre-process and create dataset list for Deepgrow training over on existing one. - The input data list is normally a list of images and labels (3D volume) which needs pre-processing for Deepgrow - training pipeline. + The input data list is normally a list of images and labels (3D volume) that needs pre-processing + for Deepgrow training pipeline. Args: datalist: A generic dataset with a length property which normally contains a list of data dictionary. 
For example, typical input data can be a list of dictionaries:: + [{'image': 'img1.nii', 'label': 'label1.nii'}] output_dir: target directory to store the training data for Deepgrow Training diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index a5519ff47a..2084831eb0 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -29,6 +29,7 @@ gaussian_filter, _ = optional_import("scipy.ndimage", name="gaussian_filter") +# Transforms to support Training for Deepgrow models class FindAllValidSlicesd(Transform): """ Find/List all valid slices in the label. Label is assumed to be a 3D Volume with channel_first axis @@ -408,7 +409,7 @@ def __call__(self, data): return data -# Transforms to support Inference for deepgrow models +# Transforms to support Inference for Deepgrow models class SpatialCropGuidanced(MapTransform): """ Crop image based on user guidance/clicks with minimal spatial size. From 8787d3ee1f6f9081bd1d877909552b5025a949c3 Mon Sep 17 00:00:00 2001 From: Yuan-Ting Hsieh Date: Tue, 9 Feb 2021 00:13:30 -0800 Subject: [PATCH 13/20] Add more unittests and docstrings Signed-off-by: Yuan-Ting Hsieh --- monai/apps/deepgrow/transforms.py | 51 ++++++++++------ tests/test_crop_foregroundd.py | 14 ++--- tests/test_deepgrow_find_all_valid_slicesd.py | 46 +++++++++++++++ .../test_deepgrow_spatial_crop_foregroundd.py | 58 +++++++++++++++++++ tests/test_deepgrow_transforms.py | 12 ---- 5 files changed, 145 insertions(+), 36 deletions(-) create mode 100644 tests/test_deepgrow_find_all_valid_slicesd.py create mode 100644 tests/test_deepgrow_spatial_crop_foregroundd.py diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 2084831eb0..35a8ef8d92 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -32,7 +32,8 @@ # Transforms to support Training for Deepgrow models class FindAllValidSlicesd(Transform): """ - Find/List all valid slices in the label. Label is assumed to be a 3D Volume with channel_first axis + Find/List all valid slices in the label. + Label is assumed to be a 4D Volume with shape CDHW, where C=1. """ def __init__(self, label="label", sids="sids"): @@ -45,9 +46,6 @@ def __init__(self, label="label", sids="sids"): self.sids = sids def _apply(self, label): - if len(label.shape) != 4: # only for 3D - return None - sids = [] for sid in range(label.shape[1]): # Assume channel is first if np.sum(label[0][sid]) == 0: @@ -56,10 +54,18 @@ def _apply(self, label): return np.asarray(sids) def __call__(self, data): - sids = self._apply(data[self.label]) + d = dict(data) + label = d[self.label] + if label.shape[0] != 1: + raise ValueError("Only supports single channel labels!") + + if len(label.shape) != 4: # only for 3D + raise ValueError("Only supports label with shape CDHW!") + + sids = self._apply(label) if sids is not None and len(sids): - data[self.sids] = sids - return data + d[self.sids] = sids + return d class AddInitialSeedPointd(Randomizable, Transform): @@ -334,7 +340,17 @@ def __call__(self, data): class SpatialCropForegroundd(MapTransform): """ - Crop the foreground object of the expected images based on label that fits minimum spatial size. + Crop only the foreground object of the expected images. + Note that if the bounding box is smaller than spatial size in all dimensions then we will crop the + object using box's center and spatial_size. + + The typical usage is to help training and evaluation if the valid part is small in the whole medical image. 
+ The valid part can be determined by any field in the data with `source_key`, for example: + - Select values > 0 in image field as the foreground and crop on all fields specified by `keys`. + - Select label = 3 in label field as the foreground to crop on all fields specified by `keys`. + - Select label > 0 in the third channel of a One-Hot label field as the foreground to crop all `keys` fields. + Users can define an arbitrary function to select the expected foreground from the whole source image or specified + channels. It can also add a margin to every dimension of the bounding box of the foreground object. """ def __init__( @@ -383,8 +399,9 @@ def __init__( self.cropped_shape_key = cropped_shape_key def __call__(self, data): + d = dict(data) box_start, box_end = generate_spatial_bounding_box( - data[self.source_key], self.select_fn, self.channel_indices, self.margin + d[self.source_key], self.select_fn, self.channel_indices, self.margin ) center = np.mean([box_start, box_end], axis=0).astype(int).tolist() @@ -399,14 +416,14 @@ def __call__(self, data): for key in self.keys: meta_key = f"{key}_{self.meta_key_postfix}" - data[meta_key][self.start_coord_key] = box_start - data[meta_key][self.end_coord_key] = box_end - data[meta_key][self.original_shape_key] = data[key].shape - - image = cropper(data[key]) - data[meta_key][self.cropped_shape_key] = image.shape - data[key] = image - return data + d[meta_key][self.start_coord_key] = box_start + d[meta_key][self.end_coord_key] = box_end + d[meta_key][self.original_shape_key] = d[key].shape + + image = cropper(d[key]) + d[meta_key][self.cropped_shape_key] = image.shape + d[key] = image + return d # Transforms to support Inference for Deepgrow models diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index cacf990763..42e02d56ff 100644 --- a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -58,19 +58,19 @@ class TestCropForegroundd(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_value(self, argments, image, expected_data): - result = CropForegroundd(**argments)(image) + def test_value(self, arguments, image, expected_data): + result = CropForegroundd(**arguments)(image) np.testing.assert_allclose(result["img"], expected_data) @parameterized.expand([TEST_CASE_1]) - def test_foreground_position(self, argments, image, _): - result = CropForegroundd(**argments)(image) + def test_foreground_position(self, arguments, image, _): + result = CropForegroundd(**arguments)(image) np.testing.assert_allclose(result["foreground_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["foreground_end_coord"], np.array([4, 4])) - argments["start_coord_key"] = "test_start_coord" - argments["end_coord_key"] = "test_end_coord" - result = CropForegroundd(**argments)(image) + arguments["start_coord_key"] = "test_start_coord" + arguments["end_coord_key"] = "test_end_coord" + result = CropForegroundd(**arguments)(image) np.testing.assert_allclose(result["test_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["test_end_coord"], np.array([4, 4])) diff --git a/tests/test_deepgrow_find_all_valid_slicesd.py b/tests/test_deepgrow_find_all_valid_slicesd.py new file mode 100644 index 0000000000..7c2fc2da5c --- /dev/null +++ b/tests/test_deepgrow_find_all_valid_slicesd.py @@ -0,0 +1,46 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance
with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.deepgrow.transforms import FindAllValidSlicesd + +TEST_CASE_1 = [ + { + "image": np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]), + "label": np.array( + [[np.zeros((5, 5)), [[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]] + ), + }, + [1], +] + +TEST_CASE_2 = [ + { + "image": np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]), + "label": np.zeros((1, 3, 5, 5)), + }, + None, +] + + +class TestFindAllValidSlicesd(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_correct_results(self, input_data, expected_result): + result = FindAllValidSlicesd()(input_data) + assert result.get("sids", None) == expected_result + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_deepgrow_spatial_crop_foregroundd.py b/tests/test_deepgrow_spatial_crop_foregroundd.py new file mode 100644 index 0000000000..47005bca4d --- /dev/null +++ b/tests/test_deepgrow_spatial_crop_foregroundd.py @@ -0,0 +1,58 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
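The two FindAllValidSlicesd cases above pin down its contract; a minimal standalone sketch of the same behaviour (the array values are illustrative)::

    import numpy as np

    from monai.apps.deepgrow.transforms import FindAllValidSlicesd

    label = np.zeros((1, 3, 5, 5))  # CDHW with C=1: three depth slices, all empty
    label[0, 1, 2, 2] = 1  # give slice index 1 some foreground
    result = FindAllValidSlicesd(label="label", sids="sids")({"label": label})
    print(result["sids"])  # [1]; an all-zero label leaves "sids" unset
    # a label with C != 1 or ndim != 4 now raises ValueError, per the checks added above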
+ +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.deepgrow.transforms import SpatialCropForegroundd + +TEST_CASE_1 = [ + { + "keys": ["img", "label"], + "source_key": "label", + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 0, + "spatial_size": [4, 4, 4], + }, + { + "img": np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]), + "label": np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]), + "img_meta_dict": {}, + "label_meta_dict": {}, + }, + np.array([[[[1, 0, 2, 0], [0, 1, 2, 1], [2, 2, 3, 2], [0, 1, 2, 1]]]]), +] + + +class TestSpatialCropForegroundd(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_value(self, arguments, input_data, expected_data): + result = SpatialCropForegroundd(**arguments)(input_data) + np.testing.assert_allclose(result["img"], expected_data) + + @parameterized.expand([TEST_CASE_1]) + def test_foreground_position(self, arguments, input_data, _): + result = SpatialCropForegroundd(**arguments)(input_data) + np.testing.assert_allclose(result["img_meta_dict"]["foreground_start_coord"], np.array([0, 0, 0])) + np.testing.assert_allclose(result["img_meta_dict"]["foreground_end_coord"], np.array([1, 4, 4])) + + arguments["start_coord_key"] = "test_start_coord" + arguments["end_coord_key"] = "test_end_coord" + result = SpatialCropForegroundd(**arguments)(input_data) + np.testing.assert_allclose(result["img_meta_dict"]["test_start_coord"], np.array([0, 0, 0])) + np.testing.assert_allclose(result["img_meta_dict"]["test_end_coord"], np.array([1, 4, 4])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py index 80b0be14a6..31c9ec4c22 100644 --- a/tests/test_deepgrow_transforms.py +++ b/tests/test_deepgrow_transforms.py @@ -18,7 +18,6 @@ AddGuidanceSignald, AddInitialSeedPointd, Fetch2DSliced, - FindAllValidSlicesd, FindDiscrepancyRegionsd, ResizeGuidanced, RestoreCroppedLabeld, @@ -38,17 +37,6 @@ class TestTransforms(unittest.TestCase): - def test_findallvalidslicesd(self): - result = AddChanneld(keys=("image", "label"))(DATA.copy()) - result = FindAllValidSlicesd()(result) - assert len(result["sids"]) == 1 - - def test_spatialcropforegroundd(self): - roi_size = [4, 4, 4] - result = AddChanneld(keys=("image", "label"))(DATA.copy()) - result = SpatialCropForegroundd(keys=("image", "label"), source_key="label", spatial_size=roi_size)(result) - assert result["image"].shape == (1, 1, 4, 4) - def test_addinitialseedpointd_addguidancesignald(self): result = AddChanneld(keys=("image", "label"))(DATA.copy()) result = AddInitialSeedPointd(label="label", guidance="guidance", sids="sids")(result) From 28191ddb96fe25ada35bd2de181b8b9b78afe764 Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Tue, 9 Feb 2021 04:07:17 -0800 Subject: [PATCH 14/20] fix build Signed-off-by: Sachidanand Alle --- tests/test_deepgrow_transforms.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py index 31c9ec4c22..0b58788b1d 100644 --- a/tests/test_deepgrow_transforms.py +++ b/tests/test_deepgrow_transforms.py @@ -21,7 +21,6 @@ FindDiscrepancyRegionsd, ResizeGuidanced, RestoreCroppedLabeld, - SpatialCropForegroundd, SpatialCropGuidanced, ) from monai.transforms import AddChanneld From f92c319f84951ccc1c2473544be55e8b60f1370f Mon Sep 17 00:00:00 2001 From: Sachidanand Alle 
Date: Tue, 9 Feb 2021 05:21:55 -0800 Subject: [PATCH 15/20] fix test Signed-off-by: Sachidanand Alle --- tests/min_tests.py | 1 + tests/test_deepgrow_spatial_crop_foregroundd.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/min_tests.py b/tests/min_tests.py index b24da7dc21..adff41d407 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -107,6 +107,7 @@ def run_testsuit(): "test_deepgrow_dataset", "test_deepgrow_interaction", "test_deepgrow_transforms", + "test_deepgrow_spatial_crop_foregroundd", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_deepgrow_spatial_crop_foregroundd.py b/tests/test_deepgrow_spatial_crop_foregroundd.py index 47005bca4d..c3b4de314a 100644 --- a/tests/test_deepgrow_spatial_crop_foregroundd.py +++ b/tests/test_deepgrow_spatial_crop_foregroundd.py @@ -23,7 +23,7 @@ "select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, - "spatial_size": [4, 4, 4], + "spatial_size": [1, 4, 4], }, { "img": np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]), @@ -31,7 +31,7 @@ "img_meta_dict": {}, "label_meta_dict": {}, }, - np.array([[[[1, 0, 2, 0], [0, 1, 2, 1], [2, 2, 3, 2], [0, 1, 2, 1]]]]), + np.array([[[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]]), ] @@ -44,13 +44,13 @@ def test_value(self, arguments, input_data, expected_data): @parameterized.expand([TEST_CASE_1]) def test_foreground_position(self, arguments, input_data, _): result = SpatialCropForegroundd(**arguments)(input_data) - np.testing.assert_allclose(result["img_meta_dict"]["foreground_start_coord"], np.array([0, 0, 0])) + np.testing.assert_allclose(result["img_meta_dict"]["foreground_start_coord"], np.array([0, 1, 1])) np.testing.assert_allclose(result["img_meta_dict"]["foreground_end_coord"], np.array([1, 4, 4])) arguments["start_coord_key"] = "test_start_coord" arguments["end_coord_key"] = "test_end_coord" result = SpatialCropForegroundd(**arguments)(input_data) - np.testing.assert_allclose(result["img_meta_dict"]["test_start_coord"], np.array([0, 0, 0])) + np.testing.assert_allclose(result["img_meta_dict"]["test_start_coord"], np.array([0, 1, 1])) np.testing.assert_allclose(result["img_meta_dict"]["test_end_coord"], np.array([1, 4, 4])) From 284fa1329b40e8bac3f3d948cea63a667fe3ddf2 Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Tue, 9 Feb 2021 05:30:22 -0800 Subject: [PATCH 16/20] fix test Signed-off-by: Sachidanand Alle --- tests/min_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index adff41d407..6b2633b030 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -108,6 +108,7 @@ def run_testsuit(): "test_deepgrow_interaction", "test_deepgrow_transforms", "test_deepgrow_spatial_crop_foregroundd", + "test_deepgrow_find_all_valid_slicesd", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From cfdc769d5298afac58369de2b3e16042c71562fe Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Tue, 9 Feb 2021 06:03:04 -0800 Subject: [PATCH 17/20] fix build + docs Signed-off-by: Sachidanand Alle --- docs/source/apps.rst | 11 ++++++----- monai/apps/deepgrow/__init__.py | 16 ---------------- 2 files changed, 6 insertions(+), 21 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index dcbe1fc9af..e5acbd226d 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -32,13 +32,14 @@ Applications `Deepgrow` ----------- -`Data` 
-^^^^^^^^^^^^^^^^^^ -.. automodule:: monai.apps.deepgrow +.. automodule:: monai.apps.deepgrow.dataset .. autofunction:: create_dataset -`Transforms` -^^^^^^^^^^^^^^ +.. automodule:: monai.apps.deepgrow.interaction +.. autoclass:: Interaction + :members: + +.. automodule:: monai.apps.deepgrow.transforms .. autoclass:: AddInitialSeedPointd :members: .. autoclass:: AddGuidanceSignald diff --git a/monai/apps/deepgrow/__init__.py b/monai/apps/deepgrow/__init__.py index 0b3b9200de..d0044e3563 100644 --- a/monai/apps/deepgrow/__init__.py +++ b/monai/apps/deepgrow/__init__.py @@ -8,19 +8,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from .dataset import create_dataset -from .interaction import Interaction -from .transforms import ( - AddGuidanceFromPointsd, - AddGuidanceSignald, - AddInitialSeedPointd, - AddRandomGuidanced, - Fetch2DSliced, - FindAllValidSlicesd, - FindDiscrepancyRegionsd, - ResizeGuidanced, - RestoreCroppedLabeld, - SpatialCropForegroundd, - SpatialCropGuidanced, -) From cc8175936c26365d44a08c032b24abb9c376ba5c Mon Sep 17 00:00:00 2001 From: Yuan-Ting Hsieh Date: Wed, 10 Feb 2021 18:04:58 -0800 Subject: [PATCH 18/20] Fix connected_regions / refactor docstring / add unit tests Signed-off-by: Yuan-Ting Hsieh --- monai/apps/deepgrow/transforms.py | 62 +++--- tests/test_deepgrow_find_all_valid_slicesd.py | 46 ----- .../test_deepgrow_spatial_crop_foregroundd.py | 58 ------ tests/test_deepgrow_transforms.py | 185 ++++++++++++++++-- 4 files changed, 196 insertions(+), 155 deletions(-) delete mode 100644 tests/test_deepgrow_find_all_valid_slicesd.py delete mode 100644 tests/test_deepgrow_spatial_crop_foregroundd.py diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 35a8ef8d92..e034f2e686 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -34,14 +34,13 @@ class FindAllValidSlicesd(Transform): """ Find/List all valid slices in the label. Label is assumed to be a 4D Volume with shape CDHW, where C=1. + + Args: + label: key to the label source. + sids: key to store slice indices having valid label map. """ def __init__(self, label="label", sids="sids"): - """ - Args: - label: label source. - sids: key to store slices indices having valid label map. - """ self.label = label self.sids = sids @@ -70,18 +69,17 @@ def __call__(self, data): class AddInitialSeedPointd(Randomizable, Transform): """ - Add Random guidance as initial seed point for a given label. + Add random guidance as an initial seed point for a given label. + + Args: + label: label source. + guidance: key to store guidance. + sids: key that represents list of valid slice indices for the given label. + sid: key that represents the slice at which to add the initial seed point. If not present, a random sid will be chosen. + connected_regions: maximum connected regions to use for adding initial points. """ - def __init__(self, label="label", guidance="guidance", sids="sids", sid="sid", connected_regions=6): - """ - Args: - label: label source. - guidance: key to store guidance. - sids: key that represents list of valid slice indices for the given label. - sid: key that represents sid for which initial seed point. If not present, random sid will be chosen. - connected_regions: maximum connected regions to use for adding initial points.
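The behavioural fix of this patch lands in the hunk that continues below: with a cap of N connected regions, the old loop `range(1, 2 if dims == 3 else self.connected_regions)` visited only N - 1 region ids, which is also why the default moves from 6 to 5. A standalone sketch of the corrected bound, calling skimage's `measure.label` directly (assuming skimage is installed; the toy array is illustrative)::

    import numpy as np
    from skimage import measure

    label_2d = np.array([[1, 0, 1], [0, 0, 0], [1, 0, 1]])
    blobs = measure.label(label_2d, background=0)  # four separate regions, ids 1..4
    connected_regions = 5
    for ridx in range(1, connected_regions + 1):  # visits ids 1..5 inclusive
        region = (blobs == ridx).astype(np.float32)
        if np.sum(region) == 0:  # fewer regions than the cap: just skip
            continue
        # ...sample a seed point inside `region`, as AddInitialSeedPointd does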
- """ + def __init__(self, label="label", guidance="guidance", sids="sids", sid="sid", connected_regions=5): self.label = label self.sids = sids self.sid = sid @@ -105,7 +103,7 @@ def _apply(self, label, sid): assert np.max(blobs_labels) > 0, "Not a valid Label" pos_guidance = [] - for ridx in range(1, 2 if dims == 3 else self.connected_regions): + for ridx in range(1, 2 if dims == 3 else self.connected_regions + 1): if dims == 2: label = (blobs_labels == ridx).astype(np.float32) if np.sum(label) == 0: @@ -120,7 +118,7 @@ def _apply(self, label, sid): dst = distance[seed] g = np.asarray(np.unravel_index(seed, label.shape)).transpose().tolist()[0] - g[0] = dst[0] + g[0] = dst[0] # for debug if dimensions == 2 or dims == 3: pos_guidance.append(g) else: @@ -143,17 +141,18 @@ def __call__(self, data): class AddGuidanceSignald(Transform): """ Add Guidance signal for input image. + + Based on the "guidance" points, apply gaussian to them and add them as new channel for input image. + + Args: + image: key to the image source. + guidance: key to store guidance. + sigma: standard deviation for Gaussian kernel. + number_intensity_ch: channel index. + batched: whether input is batched or not. """ def __init__(self, image="image", guidance="guidance", sigma=2, number_intensity_ch=1, batched=False): - """ - Args: - image: image source. - guidance: key to store guidance. - sigma: standard deviation for Gaussian kernel. - number_intensity_ch: channel index. - batched: defines if input is batched. - """ self.image = image self.guidance = guidance self.sigma = sigma @@ -212,16 +211,15 @@ def __call__(self, data): class FindDiscrepancyRegionsd(Transform): """ Find discrepancy between prediction and actual during click interactions during training. + + Args: + label: key to label source. + pred: key to prediction source. + discrepancy: key to store discrepancies found between label and prediction. + batched: whether input is batched or not. """ def __init__(self, label="label", pred="pred", discrepancy="discrepancy", batched=True): - """ - Args: - label: label source. - pred: prediction source. - discrepancy: key to store discrepancies found between label and prediction. - batched: defines if input is batched. - """ self.label = label self.pred = pred self.discrepancy = discrepancy diff --git a/tests/test_deepgrow_find_all_valid_slicesd.py b/tests/test_deepgrow_find_all_valid_slicesd.py deleted file mode 100644 index 7c2fc2da5c..0000000000 --- a/tests/test_deepgrow_find_all_valid_slicesd.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest - -import numpy as np -from parameterized import parameterized - -from monai.apps.deepgrow.transforms import FindAllValidSlicesd - -TEST_CASE_1 = [ - { - "image": np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]), - "label": np.array( - [[np.zeros((5, 5)), [[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]] - ), - }, - [1], -] - -TEST_CASE_2 = [ - { - "image": np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]), - "label": np.zeros((1, 3, 5, 5)), - }, - None, -] - - -class TestFindAllValidSlicesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_correct_results(self, input_data, expected_result): - result = FindAllValidSlicesd()(input_data) - assert result.get("sids", None) == expected_result - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_deepgrow_spatial_crop_foregroundd.py b/tests/test_deepgrow_spatial_crop_foregroundd.py deleted file mode 100644 index c3b4de314a..0000000000 --- a/tests/test_deepgrow_spatial_crop_foregroundd.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from parameterized import parameterized - -from monai.apps.deepgrow.transforms import SpatialCropForegroundd - -TEST_CASE_1 = [ - { - "keys": ["img", "label"], - "source_key": "label", - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 0, - "spatial_size": [1, 4, 4], - }, - { - "img": np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]), - "label": np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]), - "img_meta_dict": {}, - "label_meta_dict": {}, - }, - np.array([[[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]]), -] - - -class TestSpatialCropForegroundd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) - def test_value(self, arguments, input_data, expected_data): - result = SpatialCropForegroundd(**arguments)(input_data) - np.testing.assert_allclose(result["img"], expected_data) - - @parameterized.expand([TEST_CASE_1]) - def test_foreground_position(self, arguments, input_data, _): - result = SpatialCropForegroundd(**arguments)(input_data) - np.testing.assert_allclose(result["img_meta_dict"]["foreground_start_coord"], np.array([0, 1, 1])) - np.testing.assert_allclose(result["img_meta_dict"]["foreground_end_coord"], np.array([1, 4, 4])) - - arguments["start_coord_key"] = "test_start_coord" - arguments["end_coord_key"] = "test_end_coord" - result = SpatialCropForegroundd(**arguments)(input_data) - np.testing.assert_allclose(result["img_meta_dict"]["test_start_coord"], np.array([0, 1, 1])) - np.testing.assert_allclose(result["img_meta_dict"]["test_end_coord"], np.array([1, 4, 4])) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_deepgrow_transforms.py 
b/tests/test_deepgrow_transforms.py index 0b58788b1d..fbcc93ccee 100644 --- a/tests/test_deepgrow_transforms.py +++ b/tests/test_deepgrow_transforms.py @@ -12,20 +12,56 @@ import unittest import numpy as np +from parameterized import parameterized from monai.apps.deepgrow.transforms import ( AddGuidanceFromPointsd, AddGuidanceSignald, AddInitialSeedPointd, Fetch2DSliced, + FindAllValidSlicesd, FindDiscrepancyRegionsd, ResizeGuidanced, RestoreCroppedLabeld, + SpatialCropForegroundd, SpatialCropGuidanced, ) from monai.transforms import AddChanneld -DATA = { +DATA_1 = { + "image": np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]), + "label": np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]), + "image_meta_dict": {}, + "label_meta_dict": {}, +} + +DATA_2 = { + "image": np.array( + [ + [ + [[1, 2, 3, 2, 1], [1, 1, 3, 2, 1], [0, 0, 0, 0, 0], [1, 1, 1, 2, 1], [0, 2, 2, 2, 1]], + [[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]], + ] + ] + ), + "label": np.array( + [ + [ + [[0, 0, 1, 0, 0], [0, 1, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 0, 1, 0, 0]], + [[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]], + ] + ] + ), + "guidance": [[[1, 0, 2, 2], [1, 1, 2, 2]], [[-1, -1, -1, -1], [-1, -1, -1, -1]]], +} + +DATA_3 = { + "image": np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]), + "label": np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]), + "pred": np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 1, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]), +} + +DATA_INFER = { "image": np.array([[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]), "label": np.array([[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]), "image_meta_dict": {}, @@ -34,27 +70,138 @@ "background": [], } +FIND_SLICE_TEST_CASE_1 = [ + {"label": "label", "sids": "sids"}, + DATA_1, + [0], +] + +FIND_SLICE_TEST_CASE_2 = [ + {"label": "label", "sids": "sids"}, + DATA_2, + [0, 1], +] + +CROP_TEST_CASE_1 = [ + { + "keys": ["image", "label"], + "source_key": "label", + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 0, + "spatial_size": [1, 4, 4], + }, + DATA_1, + np.array([[[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]]), +] + +ADD_INITIAL_POINT_TEST_CASE_1 = [ + {"label": "label", "guidance": "guidance", "sids": "sids"}, + DATA_1, + [[[1, 0, 2, 2]], [[-1, -1, -1, -1]]], +] + +ADD_GUIDANCE_TEST_CASE_1 = [ + {"image": "image", "guidance": "guidance"}, + DATA_2, + np.array( + [ + [ + [[1, 2, 3, 2, 1], [1, 1, 3, 2, 1], [0, 0, 0, 0, 0], [1, 1, 1, 2, 1], [0, 2, 2, 2, 1]], + [[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]], + ], + [ + [ + [0.0, 0.28531367, 0.46186933, 0.28531367, 0.0], + [0.28531399, 0.59972864, 0.79429233, 0.59972864, 0.28531399], + [0.46186963, 0.79429233, 1.0, 0.79429233, 0.46186963], + [0.28531399, 0.59972864, 0.79429233, 0.59972864, 0.28531399], + [0.0, 0.28531367, 0.46186933, 0.28531367, 0.0], + ], + [ + [0.0, 0.28531367, 0.46186933, 0.28531367, 0.0], + [0.28531399, 0.59972864, 0.79429233, 0.59972864, 0.28531399], + [0.46186963, 0.79429233, 1.0, 0.79429233, 0.46186963], + [0.28531399, 0.59972864, 0.79429233, 0.59972864, 0.28531399], + [0.0, 0.28531367, 0.46186933, 0.28531367, 0.0], + ], + ], + [ + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], 
[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + ], + ] + ), +] + +FIND_DISCREPANCY_TEST_CASE_1 = [ + {"label": "label", "pred": "pred", "discrepancy": "discrepancy"}, + DATA_3, + np.array( + [ + [ + [[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], + [[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], + ] + ] + ), +] + + +class TestFindAllValidSlicesd(unittest.TestCase): + @parameterized.expand([FIND_SLICE_TEST_CASE_1, FIND_SLICE_TEST_CASE_2]) + def test_correct_results(self, arguments, input_data, expected_result): + result = FindAllValidSlicesd(**arguments)(input_data) + np.testing.assert_allclose(result[arguments["sids"]], expected_result) + + +class TestSpatialCropForegroundd(unittest.TestCase): + @parameterized.expand([CROP_TEST_CASE_1]) + def test_correct_results(self, arguments, input_data, expected_result): + result = SpatialCropForegroundd(**arguments)(input_data) + np.testing.assert_allclose(result["image"], expected_result) + + @parameterized.expand([CROP_TEST_CASE_1]) + def test_foreground_position(self, arguments, input_data, _): + result = SpatialCropForegroundd(**arguments)(input_data) + np.testing.assert_allclose(result["image_meta_dict"]["foreground_start_coord"], np.array([0, 1, 1])) + np.testing.assert_allclose(result["image_meta_dict"]["foreground_end_coord"], np.array([1, 4, 4])) + + arguments["start_coord_key"] = "test_start_coord" + arguments["end_coord_key"] = "test_end_coord" + result = SpatialCropForegroundd(**arguments)(input_data) + np.testing.assert_allclose(result["image_meta_dict"]["test_start_coord"], np.array([0, 1, 1])) + np.testing.assert_allclose(result["image_meta_dict"]["test_end_coord"], np.array([1, 4, 4])) + + +class TestAddInitialSeedPointd(unittest.TestCase): + @parameterized.expand([ADD_INITIAL_POINT_TEST_CASE_1]) + def test_correct_results(self, arguments, input_data, expected_result): + seed = 0 + add_fn = AddInitialSeedPointd(**arguments) + add_fn.set_random_state(seed) + result = add_fn(input_data) + np.testing.assert_allclose(result[arguments["guidance"]], expected_result) + + +class TestAddGuidanceSignald(unittest.TestCase): + @parameterized.expand([ADD_GUIDANCE_TEST_CASE_1]) + def test_correct_results(self, arguments, input_data, expected_result): + add_fn = AddGuidanceSignald(**arguments) + result = add_fn(input_data) + np.testing.assert_allclose(result["image"], expected_result, rtol=1e-5) + + +class TestFindDiscrepancyRegionsd(unittest.TestCase): + @parameterized.expand([FIND_DISCREPANCY_TEST_CASE_1]) + def test_correct_results(self, arguments, input_data, expected_result): + result = FindDiscrepancyRegionsd(**arguments)(input_data) + np.testing.assert_allclose(result[arguments["discrepancy"]], expected_result) -class TestTransforms(unittest.TestCase): - def test_addinitialseedpointd_addguidancesignald(self): - result = AddChanneld(keys=("image", "label"))(DATA.copy()) - result = AddInitialSeedPointd(label="label", guidance="guidance", sids="sids")(result) - assert len(result["guidance"]) - - result = AddGuidanceSignald(image="image", guidance="guidance")(result) - assert result["image"].shape == (3, 1, 5, 5) - - def test_finddiscrepancyregionsd(self): - result = DATA.copy() - result["pred"] = np.array( - [[[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]] - ) - result = AddChanneld(keys=("image", "label", "pred"))(result) - result = 
FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy")(result) - assert np.sum(result["discrepancy"]) > 0 +class TestTransforms(unittest.TestCase): def test_inference(self): - result = DATA.copy() + result = DATA_INFER.copy() result["image_meta_dict"]["spatial_shape"] = (5, 5, 1) result["image_meta_dict"]["original_affine"] = (0, 0) From d5a6d6aab118b4c1a1de18fd5fd48dbda8463185 Mon Sep 17 00:00:00 2001 From: Yuan-Ting Hsieh Date: Wed, 10 Feb 2021 18:09:44 -0800 Subject: [PATCH 19/20] remove tests in min_tests.py Signed-off-by: Yuan-Ting Hsieh --- tests/min_tests.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/min_tests.py b/tests/min_tests.py index 6b2633b030..b24da7dc21 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -107,8 +107,6 @@ def run_testsuit(): "test_deepgrow_dataset", "test_deepgrow_interaction", "test_deepgrow_transforms", - "test_deepgrow_spatial_crop_foregroundd", - "test_deepgrow_find_all_valid_slicesd", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From 3893a62497f6299fe6033a3d51ea4232b0be978f Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Thu, 11 Feb 2021 02:31:24 -0800 Subject: [PATCH 20/20] fix review comments Signed-off-by: Sachidanand Alle --- docs/source/apps.rst | 2 +- monai/apps/deepgrow/dataset.py | 6 ++---- monai/apps/deepgrow/interaction.py | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index e5acbd226d..4f79559e72 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -30,7 +30,7 @@ Applications .. autofunction:: download_and_extract `Deepgrow` ------------ +---------- .. automodule:: monai.apps.deepgrow.dataset .. autofunction:: create_dataset diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py index 849b44bed3..66796f211e 100644 --- a/monai/apps/deepgrow/dataset.py +++ b/monai/apps/deepgrow/dataset.py @@ -117,14 +117,12 @@ def create_dataset( def _default_transforms(keys, pixdim): mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST] if len(keys) == 2 else [GridSampleMode.BILINEAR] - transforms = [ + return Compose([ LoadImaged(keys=keys), AsChannelFirstd(keys=keys), Spacingd(keys=keys, pixdim=pixdim, mode=mode), Orientationd(keys=keys, axcodes="RAS"), - ] - - return Compose(transforms) + ]) def _save_data_2d(vol_idx, data, keys, dataset_dir, relative_path): diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index 0b10cb6588..9a77473d6f 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -20,7 +20,7 @@ class Interaction: """ - Deepgrow Training/Evaluation iteration method with interactions (simulation of clicks) support for image and label. + Ignite handler used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation. Args: transforms: execute additional transformation during every iteration (before train).
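To round out the consolidated transform tests, a quick sketch of the channel layout AddGuidanceSignald produces; the image and the guidance (one positive click, no negative clicks) mirror DATA_1 and ADD_INITIAL_POINT_TEST_CASE_1 above::

    import numpy as np

    from monai.apps.deepgrow.transforms import AddGuidanceSignald

    data = {
        "image": np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]),
        "guidance": [[[1, 0, 2, 2]], [[-1, -1, -1, -1]]],
    }
    result = AddGuidanceSignald(image="image", guidance="guidance")(data)
    print(result["image"].shape)  # (3, 1, 5, 5): one intensity channel plus positive/negative signal channels

Likewise, the preprocessing chain returned by `_default_transforms` can be exercised on a datalist entry of the form shown for `create_dataset`; a minimal sketch (the .nii paths are the docstring's placeholder names, not real files, and pixdim is illustrative)::

    from monai.transforms import AsChannelFirstd, Compose, GridSampleMode, LoadImaged, Orientationd, Spacingd

    keys, pixdim = ["image", "label"], (1.0, 1.0, 1.0)
    mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST] if len(keys) == 2 else [GridSampleMode.BILINEAR]
    preprocess = Compose(
        [
            LoadImaged(keys=keys),
            AsChannelFirstd(keys=keys),
            Spacingd(keys=keys, pixdim=pixdim, mode=mode),
            Orientationd(keys=keys, axcodes="RAS"),
        ]
    )
    # data = preprocess({"image": "img1.nii", "label": "label1.nii"})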