diff --git a/docs/source/apps.rst b/docs/source/apps.rst index f2afd93836..4f79559e72 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -28,3 +28,35 @@ Applications .. autofunction:: extractall .. autofunction:: download_and_extract + +`Deepgrow` +---------- + +.. automodule:: monai.apps.deepgrow.dataset +.. autofunction:: create_dataset + +.. automodule:: monai.apps.deepgrow.interaction +.. autoclass:: Interaction + :members: + +.. automodule:: monai.apps.deepgrow.transforms +.. autoclass:: AddInitialSeedPointd + :members: +.. autoclass:: AddGuidanceSignald + :members: +.. autoclass:: AddRandomGuidanced + :members: +.. autoclass:: AddGuidanceFromPointsd + :members: +.. autoclass:: SpatialCropForegroundd + :members: +.. autoclass:: SpatialCropGuidanced + :members: +.. autoclass:: RestoreCroppedLabeld + :members: +.. autoclass:: FindDiscrepancyRegionsd + :members: +.. autoclass:: FindAllValidSlicesd + :members: +.. autoclass:: Fetch2DSliced + :members: diff --git a/monai/apps/deepgrow/__init__.py b/monai/apps/deepgrow/__init__.py new file mode 100644 index 0000000000..d0044e3563 --- /dev/null +++ b/monai/apps/deepgrow/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py new file mode 100644 index 0000000000..66796f211e --- /dev/null +++ b/monai/apps/deepgrow/dataset.py @@ -0,0 +1,268 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Dict, List + +import numpy as np + +from monai.transforms import AsChannelFirstd, Compose, LoadImaged, Orientationd, Spacingd +from monai.utils import GridSampleMode + + +def create_dataset( + datalist, + output_dir, + dimension, + pixdim, + keys=("image", "label"), + base_dir=None, + limit=0, + relative_path=False, + transforms=None, +) -> List[Dict]: + """ + Utility to pre-process and create dataset list for Deepgrow training over on existing one. + The input data list is normally a list of images and labels (3D volume) that needs pre-processing + for Deepgrow training pipeline. + + Args: + datalist: A generic dataset with a length property which normally contains a list of data dictionary. + For example, typical input data can be a list of dictionaries:: + + [{'image': 'img1.nii', 'label': 'label1.nii'}] + + output_dir: target directory to store the training data for Deepgrow Training + pixdim: output voxel spacing. 
+ dimension: dimension for Deepgrow training. It can be 2 or 3. + keys: Image and Label keys in input datalist. Defaults to 'image' and 'label' + base_dir: base directory in case related path is used for the keys in datalist. Defaults to None. + limit: limit number of inputs for pre-processing. Defaults to 0 (no limit). + relative_path: output keys values should be based on relative path. Defaults to False. + transforms: explicit transforms to execute operations on input data. + + Raises: + ValueError: When ``dimension`` is not one of [2, 3] + ValueError: When ``datalist`` is Empty + + Example:: + + datalist = create_dataset( + datalist=[{'image': 'img1.nii', 'label': 'label1.nii'}], + base_dir=None, + output_dir=output_2d, + dimension=2, + keys=('image', 'label') + pixdim=(1.0, 1.0), + limit=0, + relative_path=True + ) + + print(datalist[0]["image"], datalist[0]["label"]) + """ + + if dimension not in [2, 3]: + raise ValueError("Dimension can be only 2 or 3 as Deepgrow supports only 2D/3D Training") + + if not len(datalist): + raise ValueError("Input Datalist is empty") + + if not isinstance(keys, list) and not isinstance(keys, tuple): + keys = [keys] + + transforms = _default_transforms(keys, pixdim) if transforms is None else transforms + new_datalist = [] + for idx in range(len(datalist)): + if limit and idx >= limit: + break + + image = datalist[idx][keys[0]] + label = datalist[idx].get(keys[1]) if len(keys) > 1 else None + if base_dir: + image = os.path.join(base_dir, image) + label = os.path.join(base_dir, label) if label else None + + image = os.path.abspath(image) + label = os.path.abspath(label) if label else None + + logging.info("Image: {}; Label: {}".format(image, label if label else None)) + if dimension == 2: + data = _save_data_2d( + vol_idx=idx, + data=transforms({"image": image, "label": label}), + keys=("image", "label"), + dataset_dir=output_dir, + relative_path=relative_path, + ) + else: + data = _save_data_3d( + vol_idx=idx, + data=transforms({"image": image, "label": label}), + keys=("image", "label"), + dataset_dir=output_dir, + relative_path=relative_path, + ) + new_datalist.extend(data) + return new_datalist + + +def _default_transforms(keys, pixdim): + mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST] if len(keys) == 2 else [GridSampleMode.BILINEAR] + return Compose([ + LoadImaged(keys=keys), + AsChannelFirstd(keys=keys), + Spacingd(keys=keys, pixdim=pixdim, mode=mode), + Orientationd(keys=keys, axcodes="RAS"), + ]) + + +def _save_data_2d(vol_idx, data, keys, dataset_dir, relative_path): + vol_image = data[keys[0]] + vol_label = data.get(keys[1]) + data_list = [] + + if len(vol_image.shape) == 4: + logging.info( + "4D-Image, pick only first series; Image: {}; Label: {}".format( + vol_image.shape, vol_label.shape if vol_label else None + ) + ) + vol_image = vol_image[0] + vol_image = np.moveaxis(vol_image, -1, 0) + + image_count = 0 + label_count = 0 + unique_labels_count = 0 + for sid in range(vol_image.shape[0]): + image = vol_image[sid, ...] + label = vol_label[sid, ...] 
if vol_label is not None else None + + if vol_label is not None and np.sum(label) == 0: + continue + + image_file_prefix = "vol_idx_{:0>4d}_slice_{:0>3d}".format(vol_idx, sid) + image_file = os.path.join(dataset_dir, "images", image_file_prefix) + image_file += ".npy" + + os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True) + np.save(image_file, image) + image_count += 1 + + # Test Data + if vol_label is None: + data_list.append( + { + "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file, + } + ) + continue + + # For all Labels + unique_labels = np.unique(label.flatten()) + unique_labels = unique_labels[unique_labels != 0] + unique_labels_count = max(unique_labels_count, len(unique_labels)) + + for idx in unique_labels: + label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx)) + label_file = os.path.join(dataset_dir, "labels", label_file_prefix) + label_file += ".npy" + + os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True) + curr_label = (label == idx).astype(np.float32) + np.save(label_file, curr_label) + + label_count += 1 + data_list.append( + { + "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file, + "label": label_file.replace(dataset_dir + "/", "") if relative_path else label_file, + "region": int(idx), + } + ) + + logging.info( + "{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format( + vol_idx, + vol_image.shape, + image_count, + vol_label.shape if vol_label is not None else None, + label_count, + unique_labels_count, + ) + ) + return data_list + + +def _save_data_3d(vol_idx, data, keys, dataset_dir, relative_path): + vol_image = data[keys[0]] + vol_label = data.get(keys[1]) + data_list = [] + + if len(vol_image.shape) == 4: + logging.info("4D-Image, pick only first series; Image: {}; Label: {}".format(vol_image.shape, vol_label.shape)) + vol_image = vol_image[0] + vol_image = np.moveaxis(vol_image, -1, 0) + + image_count = 0 + label_count = 0 + unique_labels_count = 0 + + image_file_prefix = "vol_idx_{:0>4d}".format(vol_idx) + image_file = os.path.join(dataset_dir, "images", image_file_prefix) + image_file += ".npy" + + os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True) + np.save(image_file, vol_image) + image_count += 1 + + # Test Data + if vol_label is None: + data_list.append( + { + "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file, + } + ) + else: + # For all Labels + unique_labels = np.unique(vol_label.flatten()) + unique_labels = unique_labels[unique_labels != 0] + unique_labels_count = max(unique_labels_count, len(unique_labels)) + + for idx in unique_labels: + label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx)) + label_file = os.path.join(dataset_dir, "labels", label_file_prefix) + label_file += ".npy" + + curr_label = (vol_label == idx).astype(np.float32) + os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True) + np.save(label_file, curr_label) + + label_count += 1 + data_list.append( + { + "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file, + "label": label_file.replace(dataset_dir + "/", "") if relative_path else label_file, + "region": int(idx), + } + ) + + logging.info( + "{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format( + vol_idx, + vol_image.shape, + image_count, + vol_label.shape if vol_label is not None else None, + label_count, + unique_labels_count, + ) + ) + return data_list diff --git 
a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py new file mode 100644 index 0000000000..9a77473d6f --- /dev/null +++ b/monai/apps/deepgrow/interaction.py @@ -0,0 +1,81 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +from typing import Dict + +import torch + +from monai.engines.utils import CommonKeys +from monai.engines.workflow import Engine, Events +from monai.transforms import Compose + + +class Interaction: + """ + Ignite handler used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation. + + Args: + transforms: execute additional transformation during every iteration (before train). + Typically, several Tensor based transforms composed by `Compose`. + max_interactions: maximum number of interactions per iteration + train: training or evaluation + key_probability: field name to fill probability for every interaction + """ + + def __init__(self, transforms, max_interactions: int, train: bool, key_probability: str = "probability") -> None: + self.transforms = transforms + self.max_interactions = max_interactions + self.train = train + self.key_probability = key_probability + + if not isinstance(self.transforms, Compose): + transforms = [] + for t in self.transforms: + transforms.append(self.init_external_class(t)) + self.transforms = Compose(transforms) + + @staticmethod + def init_external_class(config_dict): + class_args = None if config_dict.get("args") is None else dict(config_dict.get("args")) + class_path = config_dict.get("path", config_dict["name"]) + + module_name, class_name = class_path.rsplit(".", 1) + m = importlib.import_module(module_name) + c = getattr(m, class_name) + return c(**class_args) if class_args else c() + + def attach(self, engine: Engine) -> None: + if not engine.has_event_handler(self, Events.ITERATION_STARTED): + engine.add_event_handler(Events.ITERATION_STARTED, self) + + def __call__(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): + if batchdata is None: + raise ValueError("Must provide batch data for current iteration.") + + for j in range(self.max_interactions): + inputs, _ = engine.prepare_batch(batchdata) + inputs = inputs.to(engine.state.device) + + engine.network.eval() + with torch.no_grad(): + if engine.amp: + with torch.cuda.amp.autocast(): + predictions = engine.inferer(inputs, engine.network) + else: + predictions = engine.inferer(inputs, engine.network) + + batchdata.update({CommonKeys.PRED: predictions}) + batchdata[self.key_probability] = torch.as_tensor( + ([1.0 - ((1.0 / self.max_interactions) * j)] if self.train else [1.0]) * len(inputs) + ) + batchdata = self.transforms(batchdata) + + return engine._iteration(engine, batchdata) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py new file mode 100644 index 0000000000..e034f2e686 --- /dev/null +++ b/monai/apps/deepgrow/transforms.py @@ -0,0 +1,806 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A collection of "vanilla" transforms for spatial operations +https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design +""" +import json +from typing import Optional, Sequence, Union + +import numpy as np + +from monai.config import KeysCollection +from monai.transforms import Resize, SpatialCrop +from monai.transforms.compose import MapTransform, Randomizable, Transform +from monai.transforms.spatial.dictionary import InterpolateModeSequence +from monai.transforms.utils import generate_spatial_bounding_box +from monai.utils import InterpolateMode, ensure_tuple_rep, min_version, optional_import + +measure, _ = optional_import("skimage.measure", "0.14.2", min_version) +distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") +gaussian_filter, _ = optional_import("scipy.ndimage", name="gaussian_filter") + + +# Transforms to support Training for Deepgrow models +class FindAllValidSlicesd(Transform): + """ + Find/List all valid slices in the label. + Label is assumed to be a 4D Volume with shape CDHW, where C=1. + + Args: + label: key to the label source. + sids: key to store slices indices having valid label map. + """ + + def __init__(self, label="label", sids="sids"): + self.label = label + self.sids = sids + + def _apply(self, label): + sids = [] + for sid in range(label.shape[1]): # Assume channel is first + if np.sum(label[0][sid]) == 0: + continue + sids.append(sid) + return np.asarray(sids) + + def __call__(self, data): + d = dict(data) + label = d[self.label] + if label.shape[0] != 1: + raise ValueError("Only supports single channel labels!") + + if len(label.shape) != 4: # only for 3D + raise ValueError("Only supports label with shape CDHW!") + + sids = self._apply(label) + if sids is not None and len(sids): + d[self.sids] = sids + return d + + +class AddInitialSeedPointd(Randomizable, Transform): + """ + Add random guidance as initial seed point for a given label. + + Args: + label: label source. + guidance: key to store guidance. + sids: key that represents list of valid slice indices for the given label. + sid: key that represents the slice to add initial seed point. If not present, random sid will be chosen. + connected_regions: maximum connected regions to use for adding initial points. 
+ """ + + def __init__(self, label="label", guidance="guidance", sids="sids", sid="sid", connected_regions=5): + self.label = label + self.sids = sids + self.sid = sid + self.guidance = guidance + self.connected_regions = connected_regions + + def randomize(self, data=None): + pass + + def _apply(self, label, sid): + dimensions = 3 if len(label.shape) > 3 else 2 + default_guidance = [-1] * (dimensions + 1) + + dims = dimensions + if sid is not None and dimensions == 3: + dims = 2 + label = label[0][sid][np.newaxis] # Assume channel is first + + label = (label > 0.5).astype(np.float32) + blobs_labels = measure.label(label.astype(int), background=0) if dims == 2 else label + assert np.max(blobs_labels) > 0, "Not a valid Label" + + pos_guidance = [] + for ridx in range(1, 2 if dims == 3 else self.connected_regions + 1): + if dims == 2: + label = (blobs_labels == ridx).astype(np.float32) + if np.sum(label) == 0: + pos_guidance.append(default_guidance) + continue + + distance = distance_transform_cdt(label).flatten() + probability = np.exp(distance) - 1.0 + + idx = np.where(label.flatten() > 0)[0] + seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx])) + dst = distance[seed] + + g = np.asarray(np.unravel_index(seed, label.shape)).transpose().tolist()[0] + g[0] = dst[0] # for debug + if dimensions == 2 or dims == 3: + pos_guidance.append(g) + else: + pos_guidance.append([g[0], sid, g[-2], g[-1]]) + + return np.asarray([pos_guidance, [default_guidance] * len(pos_guidance)]) + + def __call__(self, data): + sid = data.get(self.sid) + sids = data.get(self.sids) + if sids is not None: + if sid is None or sid not in sids: + sid = self.R.choice(sids, replace=False) + else: + sid = None + data[self.guidance] = self._apply(data[self.label], sid) + return data + + +class AddGuidanceSignald(Transform): + """ + Add Guidance signal for input image. + + Based on the "guidance" points, apply gaussian to them and add them as new channel for input image. + + Args: + image: key to the image source. + guidance: key to store guidance. + sigma: standard deviation for Gaussian kernel. + number_intensity_ch: channel index. + batched: whether input is batched or not. 
+ """ + + def __init__(self, image="image", guidance="guidance", sigma=2, number_intensity_ch=1, batched=False): + self.image = image + self.guidance = guidance + self.sigma = sigma + self.number_intensity_ch = number_intensity_ch + self.batched = batched + + def _get_signal(self, image, guidance): + dimensions = 3 if len(image.shape) > 3 else 2 + guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance + if dimensions == 3: + signal = np.zeros((len(guidance), image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32) + else: + signal = np.zeros((len(guidance), image.shape[-2], image.shape[-1]), dtype=np.float32) + + sshape = signal.shape + for i in range(len(guidance)): + for point in guidance[i]: + if np.any(np.asarray(point) < 0): + continue + + if dimensions == 3: + p1 = max(0, min(int(point[-3]), sshape[-3] - 1)) + p2 = max(0, min(int(point[-2]), sshape[-2] - 1)) + p3 = max(0, min(int(point[-1]), sshape[-1] - 1)) + signal[i, p1, p2, p3] = 1.0 + else: + p1 = max(0, min(int(point[-2]), sshape[-2] - 1)) + p2 = max(0, min(int(point[-1]), sshape[-1] - 1)) + signal[i, p1, p2] = 1.0 + + if np.max(signal[i]) > 0: + signal[i] = gaussian_filter(signal[i], sigma=self.sigma) + signal[i] = (signal[i] - np.min(signal[i])) / (np.max(signal[i]) - np.min(signal[i])) + return signal + + def _apply(self, image, guidance): + if not self.batched: + signal = self._get_signal(image, guidance) + return np.concatenate([image, signal], axis=0) + + images = [] + for i, g in zip(image, guidance): + i = i[0 : 0 + self.number_intensity_ch, ...] + signal = self._get_signal(i, g) + images.append(np.concatenate([i, signal], axis=0)) + return images + + def __call__(self, data): + image = data[self.image] + guidance = data[self.guidance] + + data[self.image] = self._apply(image, guidance) + return data + + +class FindDiscrepancyRegionsd(Transform): + """ + Find discrepancy between prediction and actual during click interactions during training. + + Args: + label: key to label source. + pred: key to prediction source. + discrepancy: key to store discrepancies found between label and prediction. + batched: whether input is batched or not. + """ + + def __init__(self, label="label", pred="pred", discrepancy="discrepancy", batched=True): + self.label = label + self.pred = pred + self.discrepancy = discrepancy + self.batched = batched + + @staticmethod + def disparity(label, pred): + label = (label > 0.5).astype(np.float32) + pred = (pred > 0.5).astype(np.float32) + disparity = label - pred + + pos_disparity = (disparity > 0).astype(np.float32) + neg_disparity = (disparity < 0).astype(np.float32) + return [pos_disparity, neg_disparity] + + def _apply(self, label, pred): + if not self.batched: + return self.disparity(label, pred) + + disparity = [] + for la, pr in zip(label, pred): + disparity.append(self.disparity(la, pr)) + return disparity + + def __call__(self, data): + label = data[self.label] + pred = data[self.pred] + + data[self.discrepancy] = self._apply(label, pred) + return data + + +class AddRandomGuidanced(Randomizable, Transform): + """ + Add random guidance based on discrepancies that were found between label and prediction. + """ + + def __init__(self, guidance="guidance", discrepancy="discrepancy", probability="probability", batched=True): + """ + Args: + guidance: guidance source. + discrepancy: key that represents discrepancies found between label and prediction. + probability: key that represents click/interaction probability. + batched: defines if input is batched. 
+ """ + self.guidance = guidance + self.discrepancy = discrepancy + self.probability = probability + self.batched = batched + + def randomize(self, data=None): + pass + + def find_guidance(self, discrepancy): + distance = distance_transform_cdt(discrepancy).flatten() + probability = np.exp(distance) - 1.0 + idx = np.where(discrepancy.flatten() > 0)[0] + + if np.sum(discrepancy > 0) > 0: + seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx])) + dst = distance[seed] + + g = np.asarray(np.unravel_index(seed, discrepancy.shape)).transpose().tolist()[0] + g[0] = dst[0] + return g + return None + + def add_guidance(self, discrepancy, probability): + will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability]) + if not will_interact: + return None, None + + pos_discr = discrepancy[0] + neg_discr = discrepancy[1] + + can_be_positive = np.sum(pos_discr) > 0 + can_be_negative = np.sum(neg_discr) > 0 + correct_pos = np.sum(pos_discr) >= np.sum(neg_discr) + + if correct_pos and can_be_positive: + return self.find_guidance(pos_discr), None + + if not correct_pos and can_be_negative: + return None, self.find_guidance(neg_discr) + return None, None + + def _apply(self, guidance, discrepancy, probability): + guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance + if not self.batched: + pos, neg = self.add_guidance(discrepancy, probability) + if pos: + guidance[0].append(pos) + guidance[1].append([-1] * len(pos)) + if neg: + guidance[0].append([-1] * len(neg)) + guidance[1].append(neg) + else: + for g, d, p in zip(guidance, discrepancy, probability): + pos, neg = self.add_guidance(d, p) + if pos: + g[0].append(pos) + g[1].append([-1] * len(pos)) + if neg: + g[0].append([-1] * len(neg)) + g[1].append(neg) + return np.asarray(guidance) + + def __call__(self, data): + guidance = data[self.guidance] + discrepancy = data[self.discrepancy] + probability = data[self.probability] + + data[self.guidance] = self._apply(guidance, discrepancy, probability) + return data + + +class SpatialCropForegroundd(MapTransform): + """ + Crop only the foreground object of the expected images. + Note that if the bounding box is smaller than spatial size in all dimensions then we will crop the + object using box's center and spatial_size. + + The typical usage is to help training and evaluation if the valid part is small in the whole medical image. + The valid part can be determined by any field in the data with `source_key`, for example: + - Select values > 0 in image field as the foreground and crop on all fields specified by `keys`. + - Select label = 3 in label field as the foreground to crop on all fields specified by `keys`. + - Select label > 0 in the third channel of a One-Hot label field as the foreground to crop all `keys` fields. + Users can define arbitrary function to select expected foreground from the whole source image or specified + channels. And it can also add margin to every dim of the bounding box of foreground object. + """ + + def __init__( + self, + keys, + source_key: str, + spatial_size, + select_fn=lambda x: x > 0, + channel_indices=None, + margin: int = 0, + meta_key_postfix="meta_dict", + start_coord_key: str = "foreground_start_coord", + end_coord_key: str = "foreground_end_coord", + original_shape_key: str = "foreground_original_shape", + cropped_shape_key: str = "foreground_cropped_shape", + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. 
+ source_key: data source to generate the bounding box of foreground, can be image or label, etc. + spatial_size: minimal spatial size of the image patch e.g. [128, 128, 128] to fit in. + select_fn: function to select expected foreground, default is to select values > 0. + channel_indices: if defined, select foreground only on the specified channels + of image. if None, select foreground on the whole image. + margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. + meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data, + default is `meta_dict`, the meta data is a dictionary object. + For example, to handle key `image`, read/write affine matrices from the + metadata `image_meta_dict` dictionary's `affine` field. + start_coord_key: key to record the start coordinate of spatial bounding box for foreground. + end_coord_key: key to record the end coordinate of spatial bounding box for foreground. + original_shape_key: key to record original shape for foreground. + cropped_shape_key: key to record cropped shape for foreground. + """ + super().__init__(keys) + + self.source_key = source_key + self.spatial_size = list(spatial_size) + self.select_fn = select_fn + self.channel_indices = channel_indices + self.margin = margin + self.meta_key_postfix = meta_key_postfix + self.start_coord_key = start_coord_key + self.end_coord_key = end_coord_key + self.original_shape_key = original_shape_key + self.cropped_shape_key = cropped_shape_key + + def __call__(self, data): + d = dict(data) + box_start, box_end = generate_spatial_bounding_box( + d[self.source_key], self.select_fn, self.channel_indices, self.margin + ) + + center = np.mean([box_start, box_end], axis=0).astype(int).tolist() + current_size = np.subtract(box_end, box_start).astype(int).tolist() + + if np.all(np.less(current_size, self.spatial_size)): + cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size) + box_start = cropper.roi_start + box_end = cropper.roi_end + else: + cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) + + for key in self.keys: + meta_key = f"{key}_{self.meta_key_postfix}" + d[meta_key][self.start_coord_key] = box_start + d[meta_key][self.end_coord_key] = box_end + d[meta_key][self.original_shape_key] = d[key].shape + + image = cropper(d[key]) + d[meta_key][self.cropped_shape_key] = image.shape + d[key] = image + return d + + +# Transforms to support Inference for Deepgrow models +class SpatialCropGuidanced(MapTransform): + """ + Crop image based on user guidance/clicks with minimal spatial size. + """ + + def __init__( + self, + keys, + guidance: str, + spatial_size, + spatial_size_key: str = "spatial_size", + margin: int = 20, + meta_key_postfix="meta_dict", + start_coord_key: str = "foreground_start_coord", + end_coord_key: str = "foreground_end_coord", + original_shape_key: str = "foreground_original_shape", + cropped_shape_key: str = "foreground_cropped_shape", + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + guidance: user input clicks/guidance used to generate the bounding box of foreground + spatial_size: minimal spatial size of the image patch e.g. [128, 128, 128] to fit in. + margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. + meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data, + default is `meta_dict`, the meta data is a dictionary object. 
+ For example, to handle key `image`, read/write affine matrices from the + metadata `image_meta_dict` dictionary's `affine` field. + start_coord_key: key to record the start coordinate of spatial bounding box for foreground. + end_coord_key: key to record the end coordinate of spatial bounding box for foreground. + original_shape_key: key to record original shape for foreground. + cropped_shape_key: key to record cropped shape for foreground. + """ + super().__init__(keys) + + self.guidance = guidance + self.spatial_size = list(spatial_size) + self.spatial_size_key = spatial_size_key + self.margin = margin + self.meta_key_postfix = meta_key_postfix + self.start_coord_key = start_coord_key + self.end_coord_key = end_coord_key + self.original_shape_key = original_shape_key + self.cropped_shape_key = cropped_shape_key + + def bounding_box(self, points, img_shape): + ndim = len(img_shape) + margin = ensure_tuple_rep(self.margin, ndim) + for m in margin: + if m < 0: + raise ValueError("margin value should not be negative number.") + + box_start = [0] * ndim + box_end = [0] * ndim + + for di in range(ndim): + dt = points[..., di] + min_d = max(min(dt - margin[di]), 0) + max_d = min(img_shape[di] - 1, max(dt + margin[di])) + box_start[di], box_end[di] = min_d, max_d + return box_start, box_end + + def __call__(self, data): + guidance = data[self.guidance] + box_start = None + for key in self.keys: + box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), data[key].shape[1:]) + center = np.mean([box_start, box_end], axis=0).astype(int).tolist() + spatial_size = data.get(self.spatial_size_key, self.spatial_size) + + current_size = np.absolute(np.subtract(box_start, box_end)).astype(int).tolist() + spatial_size = spatial_size[-len(current_size) :] + if len(spatial_size) < len(current_size): # 3D spatial_size = [256,256] (include all slices in such case) + diff = len(current_size) - len(spatial_size) + spatial_size = list(data[key].shape[1 : (1 + diff)]) + spatial_size + + if np.all(np.less(current_size, spatial_size)): + if len(center) == 3: + center[0] = center[0] + (spatial_size[0] // 2 - center[0]) + cropper = SpatialCrop(roi_center=center, roi_size=spatial_size) + else: + cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) + box_start, box_end = cropper.roi_start, cropper.roi_end + + meta_key = f"{key}_{self.meta_key_postfix}" + data[meta_key][self.start_coord_key] = box_start + data[meta_key][self.end_coord_key] = box_end + data[meta_key][self.original_shape_key] = data[key].shape + + image = cropper(data[key]) + data[meta_key][self.cropped_shape_key] = image.shape + data[key] = image + + pos_clicks, neg_clicks = guidance[0], guidance[1] + pos = np.subtract(pos_clicks, box_start).tolist() if len(pos_clicks) else [] + neg = np.subtract(neg_clicks, box_start).tolist() if len(neg_clicks) else [] + + data[self.guidance] = [pos, neg] + return data + + +class ResizeGuidanced(Transform): + """ + Resize/re-scale user click/guidance based on original vs resized/cropped image. + """ + + def __init__( + self, + guidance: str, + ref_image, + meta_key_postfix="meta_dict", + cropped_shape_key: str = "foreground_cropped_shape", + ) -> None: + """ + Args: + guidance: user input clicks/guidance used to generate the bounding box of foreground + ref_image: reference image to fetch current and original image details + meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data, + default is `meta_dict`, the meta data is a dictionary object. 
+ For example, to handle key `image`, read/write affine matrices from the + metadata `image_meta_dict` dictionary's `affine` field. + cropped_shape_key: key that records cropped shape for foreground. + """ + self.guidance = guidance + self.ref_image = ref_image + self.meta_key_postfix = meta_key_postfix + self.cropped_shape_key = cropped_shape_key + + def __call__(self, data): + guidance = data[self.guidance] + meta_dict = data[f"{self.ref_image}_{self.meta_key_postfix}"] + current_shape = data[self.ref_image].shape[1:] + cropped_shape = meta_dict[self.cropped_shape_key][1:] + factor = np.divide(current_shape, cropped_shape) + + pos_clicks, neg_clicks = guidance[0], guidance[1] + pos = np.multiply(pos_clicks, factor).astype(int).tolist() if len(pos_clicks) else [] + neg = np.multiply(neg_clicks, factor).astype(int).tolist() if len(neg_clicks) else [] + + data[self.guidance] = [pos, neg] + return data + + +class RestoreCroppedLabeld(MapTransform): + """ + Restore/resize label based on original vs resized/cropped image. + """ + + def __init__( + self, + keys: KeysCollection, + ref_image: str, + slice_only=False, + channel_first=True, + mode: InterpolateModeSequence = InterpolateMode.NEAREST, + align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, + meta_key_postfix: str = "meta_dict", + start_coord_key: str = "foreground_start_coord", + end_coord_key: str = "foreground_end_coord", + original_shape_key: str = "foreground_original_shape", + cropped_shape_key: str = "foreground_cropped_shape", + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + ref_image: reference image to fetch current and original image details + slice_only: apply only to an applicable slice, in case of 2D model/prediction + channel_first: if channel is positioned at first + mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, + ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + One of the listed string values or a user supplied function for padding. Defaults to ``"constant"``. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + align_corners: Geometrically, we consider the pixels of the input as squares rather than points. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of bool, each element corresponds to a key in ``keys``. + meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data, + default is `meta_dict`, the meta data is a dictionary object. + For example, to handle key `image`, read/write affine matrices from the + metadata `image_meta_dict` dictionary's `affine` field. + start_coord_key: key to record the start coordinate of spatial bounding box for foreground. + end_coord_key: key to record the end coordinate of spatial bounding box for foreground. + original_shape_key: key to record original shape for foreground. + cropped_shape_key: key to record cropped shape for foreground. 
+ """ + super().__init__(keys) + self.ref_image = ref_image + self.slice_only = slice_only + self.channel_first = channel_first + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) + self.meta_key_postfix = meta_key_postfix + self.start_coord_key = start_coord_key + self.end_coord_key = end_coord_key + self.original_shape_key = original_shape_key + self.cropped_shape_key = cropped_shape_key + + def __call__(self, data): + meta_dict = data[f"{self.ref_image}_{self.meta_key_postfix}"] + + for idx, key in enumerate(self.keys): + image = data[key] + + # Undo Resize + current_size = image.shape + cropped_size = meta_dict[self.cropped_shape_key] + if np.any(np.not_equal(current_size, cropped_size)): + resizer = Resize(spatial_size=cropped_size[1:], mode=self.mode[idx]) + image = resizer(image, mode=self.mode[idx], align_corners=self.align_corners[idx]) + + # Undo Crop + original_shape = meta_dict[self.original_shape_key] + result = np.zeros(original_shape, dtype=np.float32) + box_start = meta_dict[self.start_coord_key] + box_end = meta_dict[self.end_coord_key] + + sd = min(len(box_start), len(box_end), len(image.shape[1:])) # spatial dims + slices = [slice(None)] + [slice(s, e) for s, e in zip(box_start[:sd], box_end[:sd])] + slices = tuple(slices) + result[slices] = image + + # Undo Spacing + current_size = result.shape[1:] + spatial_shape = np.roll(meta_dict["spatial_shape"], 1).tolist() + spatial_size = spatial_shape[-len(current_size) :] + + if np.any(np.not_equal(current_size, spatial_size)): + resizer = Resize(spatial_size=spatial_size, mode=self.mode[idx]) + result = resizer(result, mode=self.mode[idx], align_corners=self.align_corners[idx]) + + # Undo Slicing + slice_idx = meta_dict.get("slice_idx") + if slice_idx is None or self.slice_only: + final_result = result if len(result.shape) <= 3 else result[0] + else: + slice_idx = meta_dict["slice_idx"][0] + final_result = np.zeros(spatial_shape) + if self.channel_first: + final_result[slice_idx] = result + else: + final_result[..., slice_idx] = result + data[key] = final_result + + meta = data.get(f"{key}_{self.meta_key_postfix}") + if meta is None: + meta = dict() + data[f"{key}_{self.meta_key_postfix}"] = meta + meta["slice_idx"] = slice_idx + meta["affine"] = meta_dict["original_affine"] + return data + + +class AddGuidanceFromPointsd(Randomizable, Transform): + """ + Add guidance based on user clicks. + """ + + def __init__( + self, + ref_image, + guidance="guidance", + foreground="foreground", + background="background", + axis=0, + channel_first=True, + dimensions=2, + slice_key="slice", + meta_key_postfix: str = "meta_dict", + ): + """ + Args: + ref_image: reference image to fetch current and original image details. + guidance: output key to store guidance. + foreground: key that represents user foreground (+ve) clicks. + background: key that represents user background (-ve) clicks. + axis: axis that represents slice in 3D volume. + channel_first: if channel is positioned at first. + dimensions: dimensions based on model used for deepgrow (2D vs 3D). + slice_key: key that represents applicable slice to add guidance. + meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data, + default is `meta_dict`, the meta data is a dictionary object. + For example, to handle key `image`, read/write affine matrices from the + metadata `image_meta_dict` dictionary's `affine` field. 
+ """ + self.ref_image = ref_image + self.guidance = guidance + self.foreground = foreground + self.background = background + self.axis = axis + self.channel_first = channel_first + self.dimensions = dimensions + self.slice_key = slice_key + self.meta_key_postfix = meta_key_postfix + + def randomize(self, data=None): + pass + + def _apply(self, pos_clicks, neg_clicks, factor, slice_num=None): + points = pos_clicks + points.extend(neg_clicks) + points = np.array(points) + + if self.dimensions == 2: + slices = np.unique(points[:, self.axis]).tolist() + slice_idx = slices[0] if slice_num is None else next(x for x in slices if x == slice_num) + + pos = neg = [] + if len(pos_clicks): + pos_clicks = np.array(pos_clicks) + pos = (pos_clicks[np.where(pos_clicks[:, self.axis] == slice_idx)] * factor)[:, 1:].astype(int).tolist() + if len(neg_clicks): + neg_clicks = np.array(neg_clicks) + neg = (neg_clicks[np.where(neg_clicks[:, self.axis] == slice_idx)] * factor)[:, 1:].astype(int).tolist() + + guidance = [pos, neg, slice_idx, factor] + else: + pos = neg = [] + if len(pos_clicks): + pos = np.multiply(pos_clicks, factor).astype(int).tolist() + if len(neg_clicks): + neg = np.multiply(neg_clicks, factor).astype(int).tolist() + guidance = [pos, neg] + return guidance + + def __call__(self, data): + meta_dict = data[f"{self.ref_image}_{self.meta_key_postfix}"] + original_shape = meta_dict["spatial_shape"] + current_shape = list(data[self.ref_image].shape) + + clicks = [data[self.foreground], data[self.background]] + if self.channel_first: + original_shape = np.roll(original_shape, 1).tolist() + for i in range(len(clicks)): + clicks[i] = json.loads(clicks[i]) if isinstance(clicks[i], str) else clicks[i] + clicks[i] = np.array(clicks[i]).astype(int).tolist() + for j in range(len(clicks[i])): + clicks[i][j] = np.roll(clicks[i][j], 1).tolist() + + factor = np.array(current_shape) / original_shape + data[self.guidance] = self._apply(clicks[0], clicks[1], factor, data.get(self.slice_key)) + return data + + +class Fetch2DSliced(MapTransform): + """ + Fetch once slice in case of a 3D volume. + """ + + def __init__(self, keys, guidance="guidance", axis=0, meta_key_postfix: str = "meta_dict"): + """ + Args: + keys: keys of the corresponding items to be transformed. + guidance: key that represents guidance. + axis: axis that represents slice in 3D volume. + meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data, + default is `meta_dict`, the meta data is a dictionary object. + For example, to handle key `image`, read/write affine matrices from the + metadata `image_meta_dict` dictionary's `affine` field. 
+ """ + super().__init__(keys) + self.guidance = guidance + self.axis = axis + self.meta_key_postfix = meta_key_postfix + + def _apply(self, image, guidance): + slice_idx = guidance[2] # (pos, neg, slice_idx, factor) + idx = [] + for i in range(len(image.shape)): + idx.append(slice_idx) if i == self.axis else idx.append(slice(0, image.shape[i])) + + idx = tuple(idx) + return image[idx], idx + + def __call__(self, data): + guidance = data[self.guidance] + for key in self.keys: + img, idx = self._apply(data[key], guidance) + data[key] = img + data[f"{key}_{self.meta_key_postfix}"]["slice_idx"] = idx + return data diff --git a/tests/min_tests.py b/tests/min_tests.py index 0fd6985067..b24da7dc21 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -104,6 +104,9 @@ def run_testsuit(): "test_handler_metrics_saver_dist", "test_evenly_divisible_all_gather_dist", "test_handler_classification_saver_dist", + "test_deepgrow_dataset", + "test_deepgrow_interaction", + "test_deepgrow_transforms", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index cacf990763..42e02d56ff 100644 --- a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -58,19 +58,19 @@ class TestCropForegroundd(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_value(self, argments, image, expected_data): - result = CropForegroundd(**argments)(image) + def test_value(self, arguments, image, expected_data): + result = CropForegroundd(**arguments)(image) np.testing.assert_allclose(result["img"], expected_data) @parameterized.expand([TEST_CASE_1]) - def test_foreground_position(self, argments, image, _): - result = CropForegroundd(**argments)(image) + def test_foreground_position(self, arguments, image, _): + result = CropForegroundd(**arguments)(image) np.testing.assert_allclose(result["foreground_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["foreground_end_coord"], np.array([4, 4])) - argments["start_coord_key"] = "test_start_coord" - argments["end_coord_key"] = "test_end_coord" - result = CropForegroundd(**argments)(image) + arguments["start_coord_key"] = "test_start_coord" + arguments["end_coord_key"] = "test_end_coord" + result = CropForegroundd(**arguments)(image) np.testing.assert_allclose(result["test_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["test_end_coord"], np.array([4, 4])) diff --git a/tests/test_deepgrow_dataset.py b/tests/test_deepgrow_dataset.py new file mode 100644 index 0000000000..3b6969acd3 --- /dev/null +++ b/tests/test_deepgrow_dataset.py @@ -0,0 +1,53 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import tempfile +import unittest + +import nibabel as nib +import numpy as np + +from monai.apps.deepgrow.dataset import create_dataset + + +class TestCreateDataset(unittest.TestCase): + def _create_data(self, tempdir): + image = np.random.randint(0, 2, size=(4, 4, 4)) + image_file = os.path.join(tempdir, "image1.nii.gz") + nib.save(nib.Nifti1Image(image, np.eye(4)), image_file) + + label = np.random.randint(0, 1, size=(4, 4, 4)) + label[0][0][2] = 1 + label[0][1][2] = 1 + label[0][1][0] = 1 + label_file = os.path.join(tempdir, "label1.nii.gz") + nib.save(nib.Nifti1Image(label, np.eye(4)), label_file) + + return [{"image": image_file, "label": label_file}] + + def test_create_dataset_2d(self): + with tempfile.TemporaryDirectory() as tempdir: + datalist = self._create_data(tempdir) + output_dir = os.path.join(tempdir, "2d") + deepgrow_datalist = create_dataset(datalist=datalist, output_dir=output_dir, dimension=2, pixdim=(1, 1)) + assert len(deepgrow_datalist) == 2 and deepgrow_datalist[0]["region"] == 1 + + def test_create_dataset_3d(self): + with tempfile.TemporaryDirectory() as tempdir: + datalist = self._create_data(tempdir) + output_dir = os.path.join(tempdir, "3d") + deepgrow_datalist = create_dataset(datalist=datalist, output_dir=output_dir, dimension=3, pixdim=(1, 1, 1)) + assert len(deepgrow_datalist) == 1 and deepgrow_datalist[0]["region"] == 1 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_deepgrow_interaction.py b/tests/test_deepgrow_interaction.py new file mode 100644 index 0000000000..9b865fdfad --- /dev/null +++ b/tests/test_deepgrow_interaction.py @@ -0,0 +1,37 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.apps.deepgrow.interaction import Interaction + + +class TestInteractions(unittest.TestCase): + def test_interaction(self): + transforms = [ + { + "name": "Activationsd", + "path": "monai.transforms.Activationsd", + "args": {"keys": "pred", "sigmoid": True}, + }, + { + "name": "FindDiscrepancyRegionsd", + "path": "monai.apps.deepgrow.transforms.FindDiscrepancyRegionsd", + "args": {"label": "label", "pred": "pred", "discrepancy": "discrepancy", "batched": True}, + }, + ] + + i = Interaction(transforms=transforms, train=True, max_interactions=5) + assert len(i.transforms.transforms) == 2 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py new file mode 100644 index 0000000000..fbcc93ccee --- /dev/null +++ b/tests/test_deepgrow_transforms.py @@ -0,0 +1,229 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.deepgrow.transforms import ( + AddGuidanceFromPointsd, + AddGuidanceSignald, + AddInitialSeedPointd, + Fetch2DSliced, + FindAllValidSlicesd, + FindDiscrepancyRegionsd, + ResizeGuidanced, + RestoreCroppedLabeld, + SpatialCropForegroundd, + SpatialCropGuidanced, +) +from monai.transforms import AddChanneld + +DATA_1 = { + "image": np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]), + "label": np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]), + "image_meta_dict": {}, + "label_meta_dict": {}, +} + +DATA_2 = { + "image": np.array( + [ + [ + [[1, 2, 3, 2, 1], [1, 1, 3, 2, 1], [0, 0, 0, 0, 0], [1, 1, 1, 2, 1], [0, 2, 2, 2, 1]], + [[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]], + ] + ] + ), + "label": np.array( + [ + [ + [[0, 0, 1, 0, 0], [0, 1, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 0, 1, 0, 0]], + [[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]], + ] + ] + ), + "guidance": [[[1, 0, 2, 2], [1, 1, 2, 2]], [[-1, -1, -1, -1], [-1, -1, -1, -1]]], +} + +DATA_3 = { + "image": np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]), + "label": np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]), + "pred": np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 1, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]), +} + +DATA_INFER = { + "image": np.array([[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]), + "label": np.array([[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]), + "image_meta_dict": {}, + "label_meta_dict": {}, + "foreground": [[2, 2, 0]], + "background": [], +} + +FIND_SLICE_TEST_CASE_1 = [ + {"label": "label", "sids": "sids"}, + DATA_1, + [0], +] + +FIND_SLICE_TEST_CASE_2 = [ + {"label": "label", "sids": "sids"}, + DATA_2, + [0, 1], +] + +CROP_TEST_CASE_1 = [ + { + "keys": ["image", "label"], + "source_key": "label", + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 0, + "spatial_size": [1, 4, 4], + }, + DATA_1, + np.array([[[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]]), +] + +ADD_INITIAL_POINT_TEST_CASE_1 = [ + {"label": "label", "guidance": "guidance", "sids": "sids"}, + DATA_1, + [[[1, 0, 2, 2]], [[-1, -1, -1, -1]]], +] + +ADD_GUIDANCE_TEST_CASE_1 = [ + {"image": "image", "guidance": "guidance"}, + DATA_2, + np.array( + [ + [ + [[1, 2, 3, 2, 1], [1, 1, 3, 2, 1], [0, 0, 0, 0, 0], [1, 1, 1, 2, 1], [0, 2, 2, 2, 1]], + [[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]], + ], + [ + [ + [0.0, 0.28531367, 0.46186933, 0.28531367, 0.0], + [0.28531399, 0.59972864, 0.79429233, 0.59972864, 0.28531399], + [0.46186963, 0.79429233, 1.0, 0.79429233, 0.46186963], + [0.28531399, 0.59972864, 0.79429233, 0.59972864, 0.28531399], + [0.0, 0.28531367, 0.46186933, 0.28531367, 0.0], + ], + [ + [0.0, 0.28531367, 0.46186933, 
0.28531367, 0.0], + [0.28531399, 0.59972864, 0.79429233, 0.59972864, 0.28531399], + [0.46186963, 0.79429233, 1.0, 0.79429233, 0.46186963], + [0.28531399, 0.59972864, 0.79429233, 0.59972864, 0.28531399], + [0.0, 0.28531367, 0.46186933, 0.28531367, 0.0], + ], + ], + [ + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + ], + ] + ), +] + +FIND_DISCREPANCY_TEST_CASE_1 = [ + {"label": "label", "pred": "pred", "discrepancy": "discrepancy"}, + DATA_3, + np.array( + [ + [ + [[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], + [[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], + ] + ] + ), +] + + +class TestFindAllValidSlicesd(unittest.TestCase): + @parameterized.expand([FIND_SLICE_TEST_CASE_1, FIND_SLICE_TEST_CASE_2]) + def test_correct_results(self, arguments, input_data, expected_result): + result = FindAllValidSlicesd(**arguments)(input_data) + np.testing.assert_allclose(result[arguments["sids"]], expected_result) + + +class TestSpatialCropForegroundd(unittest.TestCase): + @parameterized.expand([CROP_TEST_CASE_1]) + def test_correct_results(self, arguments, input_data, expected_result): + result = SpatialCropForegroundd(**arguments)(input_data) + np.testing.assert_allclose(result["image"], expected_result) + + @parameterized.expand([CROP_TEST_CASE_1]) + def test_foreground_position(self, arguments, input_data, _): + result = SpatialCropForegroundd(**arguments)(input_data) + np.testing.assert_allclose(result["image_meta_dict"]["foreground_start_coord"], np.array([0, 1, 1])) + np.testing.assert_allclose(result["image_meta_dict"]["foreground_end_coord"], np.array([1, 4, 4])) + + arguments["start_coord_key"] = "test_start_coord" + arguments["end_coord_key"] = "test_end_coord" + result = SpatialCropForegroundd(**arguments)(input_data) + np.testing.assert_allclose(result["image_meta_dict"]["test_start_coord"], np.array([0, 1, 1])) + np.testing.assert_allclose(result["image_meta_dict"]["test_end_coord"], np.array([1, 4, 4])) + + +class TestAddInitialSeedPointd(unittest.TestCase): + @parameterized.expand([ADD_INITIAL_POINT_TEST_CASE_1]) + def test_correct_results(self, arguments, input_data, expected_result): + seed = 0 + add_fn = AddInitialSeedPointd(**arguments) + add_fn.set_random_state(seed) + result = add_fn(input_data) + np.testing.assert_allclose(result[arguments["guidance"]], expected_result) + + +class TestAddGuidanceSignald(unittest.TestCase): + @parameterized.expand([ADD_GUIDANCE_TEST_CASE_1]) + def test_correct_results(self, arguments, input_data, expected_result): + add_fn = AddGuidanceSignald(**arguments) + result = add_fn(input_data) + np.testing.assert_allclose(result["image"], expected_result, rtol=1e-5) + + +class TestFindDiscrepancyRegionsd(unittest.TestCase): + @parameterized.expand([FIND_DISCREPANCY_TEST_CASE_1]) + def test_correct_results(self, arguments, input_data, expected_result): + result = FindDiscrepancyRegionsd(**arguments)(input_data) + np.testing.assert_allclose(result[arguments["discrepancy"]], expected_result) + + +class TestTransforms(unittest.TestCase): + def test_inference(self): + result = DATA_INFER.copy() + result["image_meta_dict"]["spatial_shape"] = (5, 5, 1) + result["image_meta_dict"]["original_affine"] = (0, 0) + + result = AddGuidanceFromPointsd( + ref_image="image", guidance="guidance", foreground="foreground", background="background", dimensions=2 + 
)(result) + assert len(result["guidance"][0][0]) == 2 + + result = Fetch2DSliced(keys="image", guidance="guidance")(result) + assert result["image"].shape == (5, 5) + + result = AddChanneld(keys="image")(result) + + result = SpatialCropGuidanced(keys="image", guidance="guidance", spatial_size=(4, 4))(result) + assert result["image"].shape == (1, 4, 4) + + result = ResizeGuidanced(guidance="guidance", ref_image="image")(result) + + result["pred"] = np.random.randint(0, 2, size=(1, 4, 4)) + result = RestoreCroppedLabeld(keys="pred", ref_image="image", mode="nearest")(result) + assert result["pred"].shape == (1, 5, 5) + + +if __name__ == "__main__": + unittest.main()
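Note (not part of the patch): below is a minimal sketch of how the Deepgrow training-time transforms added above chain together on a toy CDHW volume. The key names ("image", "label", "sids", "guidance") follow the transform defaults defined in monai/apps/deepgrow/transforms.py; the toy arrays are illustrative only, and the snippet assumes SciPy and scikit-image are installed so the optional imports (distance_transform_cdt, gaussian_filter, skimage.measure) resolve::

    import numpy as np

    from monai.apps.deepgrow.transforms import (
        AddGuidanceSignald,
        AddInitialSeedPointd,
        FindAllValidSlicesd,
    )

    # toy single-channel 3D image/label pair with shape (C, D, H, W)
    data = {
        "image": np.random.rand(1, 1, 5, 5).astype(np.float32),
        "label": np.zeros((1, 1, 5, 5), dtype=np.float32),
    }
    data["label"][0, 0, 1:4, 1:4] = 1.0  # a small foreground square on slice 0

    # 1. collect the slice indices that actually contain label -> data["sids"]
    data = FindAllValidSlicesd(label="label", sids="sids")(data)

    # 2. sample an initial positive seed point on one of those slices -> data["guidance"]
    data = AddInitialSeedPointd(label="label", guidance="guidance", sids="sids")(data)

    # 3. turn the guidance points into gaussian-smoothed channels appended to the image
    data = AddGuidanceSignald(image="image", guidance="guidance", sigma=1)(data)

    print(data["image"].shape)  # (3, 1, 5, 5): image + positive/negative guidance channels

During actual training the Interaction handler introduced in interaction.py repeats a similar loop per iteration, additionally running inference and applying FindDiscrepancyRegionsd and AddRandomGuidanced to simulate corrective clicks, as exercised by tests/test_deepgrow_interaction.py above.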