From 52d1338b6e9fe5f944916aac3d9013bd26249fe2 Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Thu, 25 Feb 2021 18:57:06 -0800 Subject: [PATCH 1/3] Add inference transforms Signed-off-by: YuanTingHsieh --- docs/source/apps.rst | 10 + monai/apps/deepgrow/transforms.py | 461 +++++++++++++++++++++++++++++- tests/test_deepgrow_transforms.py | 226 +++++++++++++++ 3 files changed, 689 insertions(+), 8 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index b8c8b4d341..1c4f4c3dfb 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -46,9 +46,19 @@ Applications :members: .. autoclass:: AddRandomGuidanced :members: +.. autoclass:: AddGuidanceFromPointsd + :members: .. autoclass:: SpatialCropForegroundd :members: +.. autoclass:: SpatialCropGuidanced + :members: +.. autoclass:: RestoreLabeld + :members: +.. autoclass:: ResizeGuidanced + :members: .. autoclass:: FindDiscrepancyRegionsd :members: .. autoclass:: FindAllValidSlicesd :members: +.. autoclass:: Fetch2DSliced + :members: diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index f178360031..4e985bdf16 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -8,18 +8,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Callable, Optional, Sequence, Union +import json +from typing import Callable, Dict, Hashable, Mapping, Optional, Sequence, Union import numpy as np import torch from monai.config import IndexSelection, KeysCollection from monai.networks.layers import GaussianFilter -from monai.transforms import SpatialCrop +from monai.transforms import Resize, SpatialCrop from monai.transforms.compose import MapTransform, Randomizable, Transform from monai.transforms.utils import generate_spatial_bounding_box -from monai.utils import min_version, optional_import +from monai.utils import InterpolateMode, ensure_tuple_rep, min_version, optional_import measure, _ = optional_import("skimage.measure", "0.14.2", min_version) distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") @@ -67,7 +67,8 @@ class AddInitialSeedPointd(Randomizable, Transform): Add random guidance as initial seed point for a given label. Note that the label is of size (C, D, H, W) or (C, H, W) - The guidance is of size (2, N, # of dims) where N is number of guidance added + + The guidance is of size (2, N, # of dims) where N is number of guidance added. # of dims = 4 when C, D, H, W; # of dims = 3 when (C, H, W) Args: @@ -232,6 +233,7 @@ class FindDiscrepancyRegionsd(Transform): Find discrepancy between prediction and actual during click interactions during training. If batched is true: + label is in shape (B, C, D, H, W) or (B, C, H, W) pred has same shape as label discrepancy will have shape (B, 2, C, D, H, W) or (B, 2, C, H, W) @@ -283,7 +285,7 @@ class AddRandomGuidanced(Randomizable, Transform): """ Add random guidance based on discrepancies that were found between label and prediction. - If batched is True: + If batched is True, input shape is as below: Guidance is of shape (B, 2, N, # of dim) where B is batch size, 2 means positive and negative, N means how many guidance points, # of dim is the total number of dimensions of the image @@ -291,7 +293,15 @@ class AddRandomGuidanced(Randomizable, Transform): Discrepancy is of shape (B, 2, C, D, H, W) or (B, 2, C, H, W) - Probability is of shape (B,) + Probability is of shape (B, 1) + + else: + + Guidance is of shape (2, N, # of dim) + + Discrepancy is of shape (2, C, D, H, W) or (2, C, H, W) + + Probability is of shape (1) Args: guidance: key to guidance source. @@ -389,7 +399,7 @@ class SpatialCropForegroundd(MapTransform): """ Crop only the foreground object of the expected images. - Difference VS CropForegroundd: + Difference VS :py:class:`monai.transforms.CropForegroundd`: 1. If the bounding box is smaller than spatial size in all dimensions then this transform will crop the object using box's center and spatial_size. @@ -399,9 +409,11 @@ class SpatialCropForegroundd(MapTransform): The typical usage is to help training and evaluation if the valid part is small in the whole medical image. The valid part can be determined by any field in the data with `source_key`, for example: + - Select values > 0 in image field as the foreground and crop on all fields specified by `keys`. - Select label = 3 in label field as the foreground to crop on all fields specified by `keys`. - Select label > 0 in the third channel of a One-Hot label field as the foreground to crop all `keys` fields. + Users can define arbitrary function to select expected foreground from the whole source image or specified channels. And it can also add margin to every dim of the bounding box of foreground object. @@ -477,3 +489,436 @@ def __call__(self, data): d[meta_key][self.cropped_shape_key] = image.shape d[key] = image return d + + +# Transforms to support Inference for Deepgrow models +class AddGuidanceFromPointsd(Transform): + """ + Add guidance based on user clicks. + + We assume the input is loaded by LoadImaged and has the shape of (H, W, D) originally. + Clicks always specify the coordinates in (H, W, D) + + If depth_first is True: + + Input is now of shape (D, H, W), will return guidance that specifies the coordinates in (D, H, W) + + else: + + Input is now of shape (H, W, D), will return guidance that specifies the coordinates in (H, W, D) + + Args: + ref_image: key to reference image to fetch current and original image details. + guidance: output key to store guidance. + foreground: key that represents user foreground (+ve) clicks. + background: key that represents user background (-ve) clicks. + axis: axis that represents slices in 3D volume. (axis to Depth) + depth_first: if depth (slices) is positioned at first dimension. + dimensions: dimensions based on model used for deepgrow (2D vs 3D). + slice: key that represents applicable slice to add guidance. + meta_key_postfix: use `{ref_image}_{postfix}` to to fetch the meta data according to the key data, + default is `meta_dict`, the meta data is a dictionary object. + For example, to handle key `image`, read/write affine matrices from the + metadata `image_meta_dict` dictionary's `affine` field. + """ + + def __init__( + self, + ref_image, + guidance: str = "guidance", + foreground: str = "foreground", + background: str = "background", + axis: int = 0, + depth_first: bool = True, + dimensions: int = 2, + slice: str = "slice", + meta_key_postfix: str = "meta_dict", + ): + self.ref_image = ref_image + self.guidance = guidance + self.foreground = foreground + self.background = background + self.axis = axis + self.depth_first = depth_first + self.dimensions = dimensions + self.slice = slice + self.meta_key_postfix = meta_key_postfix + + def _apply(self, pos_clicks, neg_clicks, factor, slice_num): + pos = neg = [] + + if self.dimensions == 2: + points = list(pos_clicks) + points.extend(neg_clicks) + points = np.array(points) + + slices = np.unique(points[:, self.axis]).tolist() + slice_idx = slices[0] if slice_num is None else next(x for x in slices if x == slice_num) + + if len(pos_clicks): + pos_clicks = np.array(pos_clicks) + pos = (pos_clicks[np.where(pos_clicks[:, self.axis] == slice_idx)] * factor)[:, 1:].astype(int).tolist() + if len(neg_clicks): + neg_clicks = np.array(neg_clicks) + neg = (neg_clicks[np.where(neg_clicks[:, self.axis] == slice_idx)] * factor)[:, 1:].astype(int).tolist() + + guidance = [pos, neg, slice_idx] + else: + if len(pos_clicks): + pos = np.multiply(pos_clicks, factor).astype(int).tolist() + if len(neg_clicks): + neg = np.multiply(neg_clicks, factor).astype(int).tolist() + guidance = [pos, neg] + return guidance + + def __call__(self, data): + d = dict(data) + meta_dict_key = f"{self.ref_image}_{self.meta_key_postfix}" + if meta_dict_key not in d: + raise RuntimeError(f"Missing meta_dict {meta_dict_key} in data!") + if "spatial_shape" not in d[meta_dict_key]: + raise RuntimeError('Missing "spatial_shape" in meta_dict!') + original_shape = d[meta_dict_key]["spatial_shape"] + current_shape = list(d[self.ref_image].shape) + + if self.depth_first: + if self.axis != 0: + raise RuntimeError("Depth first means the depth axis should be 0.") + # in here we assume the depth dimension was in the last dimension of "original_shape" + original_shape = np.roll(original_shape, 1) + + factor = np.array(current_shape) / original_shape + + fg_bg_clicks = [] + for key in [self.foreground, self.background]: + clicks = json.loads(d[key]) if isinstance(d[key], str) else d[key] + clicks = np.array(clicks).astype(int).tolist() + if self.depth_first: + for i in range(len(clicks)): + clicks[i] = np.roll(clicks[i], 1).tolist() + fg_bg_clicks.append(clicks) + d[self.guidance] = self._apply(fg_bg_clicks[0], fg_bg_clicks[1], factor, d.get(self.slice, None)) + return d + + +class SpatialCropGuidanced(MapTransform): + """ + Crop image based on guidance with minimal spatial size. + + - If the bounding box is smaller than spatial size in all dimensions then this transform will crop the + object using box's center and spatial_size. + + - This transform will set "start_coord_key", "end_coord_key", "original_shape_key" and "cropped_shape_key" + in data[{key}_{meta_key_postfix}] + + Input data is of shape (C, spatial_1, [spatial_2, ...]) + + Args: + keys: keys of the corresponding items to be transformed. + guidance: key to the guidance. It is used to generate the bounding box of foreground + spatial_size: minimal spatial size of the image patch e.g. [128, 128, 128] to fit in. + margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. + meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data, + default is `meta_dict`, the meta data is a dictionary object. + For example, to handle key `image`, read/write affine matrices from the + metadata `image_meta_dict` dictionary's `affine` field. + start_coord_key: key to record the start coordinate of spatial bounding box for foreground. + end_coord_key: key to record the end coordinate of spatial bounding box for foreground. + original_shape_key: key to record original shape for foreground. + cropped_shape_key: key to record cropped shape for foreground. + """ + + def __init__( + self, + keys: KeysCollection, + guidance: str, + spatial_size, + margin=20, + meta_key_postfix="meta_dict", + start_coord_key: str = "foreground_start_coord", + end_coord_key: str = "foreground_end_coord", + original_shape_key: str = "foreground_original_shape", + cropped_shape_key: str = "foreground_cropped_shape", + ) -> None: + super().__init__(keys) + + self.guidance = guidance + self.spatial_size = list(spatial_size) + self.margin = margin + self.meta_key_postfix = meta_key_postfix + self.start_coord_key = start_coord_key + self.end_coord_key = end_coord_key + self.original_shape_key = original_shape_key + self.cropped_shape_key = cropped_shape_key + + def bounding_box(self, points, img_shape): + ndim = len(img_shape) + margin = ensure_tuple_rep(self.margin, ndim) + for m in margin: + if m < 0: + raise ValueError("margin value should not be negative number.") + + box_start = [0] * ndim + box_end = [0] * ndim + + for di in range(ndim): + dt = points[..., di] + min_d = max(min(dt - margin[di]), 0) + max_d = min(img_shape[di], max(dt + margin[di] + 1)) + box_start[di], box_end[di] = min_d, max_d + return box_start, box_end + + def __call__(self, data): + d = dict(data) + guidance = d[self.guidance] + original_spatial_shape = d[self.keys[0]].shape[1:] + box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), original_spatial_shape) + center = np.mean([box_start, box_end], axis=0).astype(int).tolist() + spatial_size = self.spatial_size + + box_size = np.subtract(box_end, box_start).astype(int).tolist() + spatial_size = spatial_size[-len(box_size) :] + + if len(spatial_size) < len(box_size): + # If the data is in 3D and spatial_size is specified as 2D [256,256] + # Then we will get all slices in such case + diff = len(box_size) - len(spatial_size) + spatial_size = list(original_spatial_shape[1 : (1 + diff)]) + spatial_size + + if np.all(np.less(box_size, spatial_size)): + if len(center) == 3: + # 3D Deepgrow: set center to be middle of the depth dimension (D) + center[0] = spatial_size[0] // 2 + cropper = SpatialCrop(roi_center=center, roi_size=spatial_size) + else: + cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) + box_start, box_end = cropper.roi_start, cropper.roi_end + + for key in self.keys: + if not np.array_equal(d[key].shape[1:], original_spatial_shape): + raise RuntimeError("All the image specified in keys should have same spatial shape") + meta_key = f"{key}_{self.meta_key_postfix}" + d[meta_key][self.start_coord_key] = box_start + d[meta_key][self.end_coord_key] = box_end + d[meta_key][self.original_shape_key] = d[key].shape + + image = cropper(d[key]) + d[meta_key][self.cropped_shape_key] = image.shape + d[key] = image + + pos_clicks, neg_clicks = guidance[0], guidance[1] + pos = np.subtract(pos_clicks, box_start).tolist() if len(pos_clicks) else [] + neg = np.subtract(neg_clicks, box_start).tolist() if len(neg_clicks) else [] + + d[self.guidance] = [pos, neg] + return d + + +class ResizeGuidanced(Transform): + """ + Resize the guidance based on cropped vs resized image. + + This transform assumes that the images have been cropped and resized. And the shape after cropped is store inside + the meta dict of ref image. + + Args: + guidance: key to guidance + ref_image: key to reference image to fetch current and original image details + meta_key_postfix: use `{ref_image}_{postfix}` to to fetch the meta data according to the key data, + default is `meta_dict`, the meta data is a dictionary object. + For example, to handle key `image`, read/write affine matrices from the + metadata `image_meta_dict` dictionary's `affine` field. + cropped_shape_key: key that records cropped shape for foreground. + """ + + def __init__( + self, + guidance: str, + ref_image: str, + meta_key_postfix="meta_dict", + cropped_shape_key: str = "foreground_cropped_shape", + ) -> None: + self.guidance = guidance + self.ref_image = ref_image + self.meta_key_postfix = meta_key_postfix + self.cropped_shape_key = cropped_shape_key + + def __call__(self, data): + d = dict(data) + guidance = d[self.guidance] + meta_dict = d[f"{self.ref_image}_{self.meta_key_postfix}"] + current_shape = d[self.ref_image].shape[1:] + cropped_shape = meta_dict[self.cropped_shape_key][1:] + factor = np.divide(current_shape, cropped_shape) + + pos_clicks, neg_clicks = guidance[0], guidance[1] + pos = np.multiply(pos_clicks, factor).astype(int).tolist() if len(pos_clicks) else [] + neg = np.multiply(neg_clicks, factor).astype(int).tolist() if len(neg_clicks) else [] + + d[self.guidance] = [pos, neg] + return d + + +class RestoreLabeld(MapTransform): + """ + Restores label based on the ref image. + + The ref_image is assumed that it went through the following transforms: + + 1. Fetch2DSliced (If 2D) + 2. Spacingd + 3. SpatialCropGuidanced + 4. Resized + + And its shape is assumed to be (C, D, H, W) + + This transform tries to undo these operation so that the result label can be overlapped with original volume. + It does the following operation: + + 1. Undo Resized + 2. Undo SpatialCropGuidanced + 3. Undo Spacingd + 4. Undo Fetch2DSliced + + The resulting label is of shape (D, H, W) + + Args: + keys: keys of the corresponding items to be transformed. + ref_image: reference image to fetch current and original image details + slice_only: apply only to an applicable slice, in case of 2D model/prediction + mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, + ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + One of the listed string values or a user supplied function for padding. Defaults to ``"constant"``. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + align_corners: Geometrically, we consider the pixels of the input as squares rather than points. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of bool, each element corresponds to a key in ``keys``. + meta_key_postfix: use `{ref_image}_{meta_key_postfix}` to to fetch the meta data according to the key data, + default is `meta_dict`, the meta data is a dictionary object. + For example, to handle key `image`, read/write affine matrices from the + metadata `image_meta_dict` dictionary's `affine` field. + start_coord_key: key that records the start coordinate of spatial bounding box for foreground. + end_coord_key: key that records the end coordinate of spatial bounding box for foreground. + original_shape_key: key that records original shape for foreground. + cropped_shape_key: key that records cropped shape for foreground. + """ + + def __init__( + self, + keys: KeysCollection, + ref_image: str, + slice_only: bool = False, + mode: Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str] = InterpolateMode.NEAREST, + align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, + meta_key_postfix: str = "meta_dict", + start_coord_key: str = "foreground_start_coord", + end_coord_key: str = "foreground_end_coord", + original_shape_key: str = "foreground_original_shape", + cropped_shape_key: str = "foreground_cropped_shape", + ) -> None: + super().__init__(keys) + self.ref_image = ref_image + self.slice_only = slice_only + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) + self.meta_key_postfix = meta_key_postfix + self.start_coord_key = start_coord_key + self.end_coord_key = end_coord_key + self.original_shape_key = original_shape_key + self.cropped_shape_key = cropped_shape_key + + def __call__(self, data): + d = dict(data) + meta_dict = d[f"{self.ref_image}_{self.meta_key_postfix}"] + + for idx, key in enumerate(self.keys): + image = d[key] + + # Undo Resize + current_shape = image.shape + cropped_shape = meta_dict[self.cropped_shape_key] + if np.any(np.not_equal(current_shape, cropped_shape)): + resizer = Resize(spatial_size=cropped_shape[1:], mode=self.mode[idx]) + image = resizer(image, mode=self.mode[idx], align_corners=self.align_corners[idx]) + + # Undo Crop + original_shape = meta_dict[self.original_shape_key] + result = np.zeros(original_shape, dtype=np.float32) + box_start = meta_dict[self.start_coord_key] + box_end = meta_dict[self.end_coord_key] + + spatial_dims = min(len(box_start), len(image.shape[1:])) + slices = [slice(None)] + [slice(s, e) for s, e in zip(box_start[:spatial_dims], box_end[:spatial_dims])] + slices = tuple(slices) + result[slices] = image + + # Undo Spacing + current_size = result.shape[1:] + # change spatial_shape from HWD to DHW + spatial_shape = np.roll(meta_dict["spatial_shape"], 1).tolist() + spatial_size = spatial_shape[-len(current_size) :] + + if np.any(np.not_equal(current_size, spatial_size)): + resizer = Resize(spatial_size=spatial_size, mode=self.mode[idx]) + result = resizer(result, mode=self.mode[idx], align_corners=self.align_corners[idx]) + + # Undo Slicing + slice_idx = meta_dict.get("slice_idx") + if slice_idx is None or self.slice_only: + final_result = result if len(result.shape) <= 3 else result[0] + else: + slice_idx = meta_dict["slice_idx"][0] + final_result = np.zeros(spatial_shape) + final_result[slice_idx] = result + d[key] = final_result + + meta = d.get(f"{key}_{self.meta_key_postfix}") + if meta is None: + meta = dict() + d[f"{key}_{self.meta_key_postfix}"] = meta + meta["slice_idx"] = slice_idx + meta["affine"] = meta_dict["original_affine"] + return d + + +class Fetch2DSliced(MapTransform): + """ + Fetch one slice in case of a 3D volume. + + The volume only contains spatial coordinates. + + Args: + keys: keys of the corresponding items to be transformed. + guidance: key that represents guidance. + axis: axis that represents slice in 3D volume. + meta_key_postfix: use `key_{meta_key_postfix}` to to fetch the meta data according to the key data, + default is `meta_dict`, the meta data is a dictionary object. + For example, to handle key `image`, read/write affine matrices from the + metadata `image_meta_dict` dictionary's `affine` field. + """ + + def __init__(self, keys, guidance="guidance", axis: int = 0, meta_key_postfix: str = "meta_dict"): + super().__init__(keys) + self.guidance = guidance + self.axis = axis + self.meta_key_postfix = meta_key_postfix + + def _apply(self, image, guidance): + slice_idx = guidance[2] # (pos, neg, slice_idx) + idx = [] + for i in range(len(image.shape)): + idx.append(slice_idx) if i == self.axis else idx.append(slice(0, image.shape[i])) + + idx = tuple(idx) + return image[idx], idx + + def __call__(self, data): + d = dict(data) + guidance = d[self.guidance] + if len(guidance) < 3: + raise RuntimeError("Guidance does not container slice_idx!") + for key in self.keys: + img_slice, idx = self._apply(d[key], guidance) + d[key] = img_slice + d[f"{key}_{self.meta_key_postfix}"]["slice_idx"] = idx + return d diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py index f534813832..1f7e8a2488 100644 --- a/tests/test_deepgrow_transforms.py +++ b/tests/test_deepgrow_transforms.py @@ -15,12 +15,17 @@ from parameterized import parameterized from monai.apps.deepgrow.transforms import ( + AddGuidanceFromPointsd, AddGuidanceSignald, AddInitialSeedPointd, AddRandomGuidanced, + Fetch2DSliced, FindAllValidSlicesd, FindDiscrepancyRegionsd, + ResizeGuidanced, + RestoreLabeld, SpatialCropForegroundd, + SpatialCropGuidanced, ) IMAGE = np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]) @@ -76,6 +81,76 @@ "probability": [1.0], } +DATA_5 = { + "image": np.arange(25).reshape((1, 5, 5)), + "image_meta_dict": {"spatial_shape": [5, 5, 1]}, + "foreground": [[2, 2, 0]], + "background": [], +} + +DATA_6 = { + "image": np.arange(25).reshape((1, 5, 5)), + "image_meta_dict": {"spatial_shape": [5, 2, 1]}, + "foreground": [[2, 1, 0]], + "background": [[1, 0, 0]], +} + +DATA_7 = { + "image": np.arange(500).reshape((5, 10, 10)), + "image_meta_dict": {"spatial_shape": [20, 20, 10]}, + "foreground": [[10, 14, 6], [10, 14, 8]], + "background": [[10, 16, 8]], + "slice": 6, +} + +DATA_8 = { + "image": np.arange(500).reshape((1, 5, 10, 10)), + "image_meta_dict": {"spatial_shape": [20, 20, 10]}, + "guidance": [[[3, 5, 7], [4, 5, 7]], [[4, 5, 8]]], +} + +DATA_9 = { + "image": np.arange(1000).reshape((1, 5, 10, 20)), + "image_meta_dict": {"foreground_cropped_shape": (1, 10, 20, 40)}, + "guidance": [[[6, 10, 14], [8, 10, 14]], [[8, 10, 16]]], +} + +DATA_10 = { + "image": np.arange(9).reshape((1, 1, 3, 3)), + "image_meta_dict": { + "spatial_shape": [3, 3, 1], + "foreground_start_coord": np.array([0, 0, 0]), + "foreground_end_coord": np.array([1, 3, 3]), + "foreground_original_shape": (1, 1, 3, 3), + "foreground_cropped_shape": (1, 1, 3, 3), + "original_affine": np.array( + [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]] + ), + }, + "pred": np.array([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]), +} + +DATA_11 = { + "image": np.arange(500).reshape((1, 5, 10, 10)), + "image_meta_dict": { + "spatial_shape": [20, 20, 10], + "foreground_start_coord": np.array([2, 2, 2]), + "foreground_end_coord": np.array([4, 4, 4]), + "foreground_original_shape": (1, 5, 10, 10), + "foreground_cropped_shape": (1, 2, 2, 2), + "original_affine": np.array( + [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]] + ), + }, + "pred": np.array([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]), +} + +DATA_12 = { + "image": np.arange(27).reshape(3, 3, 3), + "image_meta_dict": {}, + "guidance": [[0, 0, 0], [0, 1, 1], 1], +} + FIND_SLICE_TEST_CASE_1 = [ {"label": "label", "sids": "sids"}, DATA_1, @@ -159,6 +234,111 @@ np.array([[[[1, 0, 2, 2], [1, 0, 1, 3]], [[-1, -1, -1, -1], [-1, -1, -1, -1]]]]), ] +ADD_GUIDANCE_FROM_POINTS_TEST_CASE_1 = [ + {"ref_image": "image", "dimensions": 3, "guidance": "guidance", "depth_first": True}, + DATA_5, + [[0, 2, 2]], + [], +] + +ADD_GUIDANCE_FROM_POINTS_TEST_CASE_2 = [ + {"ref_image": "image", "dimensions": 3, "guidance": "guidance", "depth_first": True}, + DATA_6, + [[0, 2, 2]], + [[0, 1, 0]], +] + +ADD_GUIDANCE_FROM_POINTS_TEST_CASE_3 = [ + {"ref_image": "image", "dimensions": 3, "guidance": "guidance", "depth_first": True}, + DATA_7, + [[3, 5, 7], [4, 5, 7]], + [[4, 5, 8]], +] + +ADD_GUIDANCE_FROM_POINTS_TEST_CASE_4 = [ + {"ref_image": "image", "dimensions": 2, "guidance": "guidance", "depth_first": True}, + DATA_6, + [[2, 2]], + [[1, 0]], +] + +ADD_GUIDANCE_FROM_POINTS_TEST_CASE_5 = [ + {"ref_image": "image", "dimensions": 2, "guidance": "guidance", "depth_first": True, "slice": "slice"}, + DATA_7, + [[5, 7]], + [], +] + +SPATIAL_CROP_GUIDANCE_TEST_CASE_1 = [ + {"keys": ["image"], "guidance": "guidance", "spatial_size": [1, 4, 4], "margin": 0}, + DATA_8, + np.array([[[[357, 358]], [[457, 458]]]]), +] + +SPATIAL_CROP_GUIDANCE_TEST_CASE_2 = [ + {"keys": ["image"], "guidance": "guidance", "spatial_size": [2, 2], "margin": 1}, + DATA_8, + np.array( + [ + [ + [[246, 247, 248, 249], [256, 257, 258, 259], [266, 267, 268, 269]], + [[346, 347, 348, 349], [356, 357, 358, 359], [366, 367, 368, 369]], + [[446, 447, 448, 449], [456, 457, 458, 459], [466, 467, 468, 469]], + ] + ] + ), +] + +SPATIAL_CROP_GUIDANCE_TEST_CASE_3 = [ + {"keys": ["image"], "guidance": "guidance", "spatial_size": [3, 3], "margin": 0}, + DATA_8, + np.array( + [ + [ + [[47, 48, 49], [57, 58, 59], [67, 68, 69]], + [[147, 148, 149], [157, 158, 159], [167, 168, 169]], + [[247, 248, 249], [257, 258, 259], [267, 268, 269]], + [[347, 348, 349], [357, 358, 359], [367, 368, 369]], + [[447, 448, 449], [457, 458, 459], [467, 468, 469]], + ] + ] + ), +] + +RESIZE_GUIDANCE_TEST_CASE_1 = [ + {"ref_image": "image", "guidance": "guidance"}, + DATA_9, + [[[3, 5, 7], [4, 5, 7]], [[4, 5, 8]]], +] + +RESTORE_LABEL_TEST_CASE_1 = [ + {"keys": ["pred"], "ref_image": "image", "mode": "nearest"}, + DATA_10, + np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]), +] + +RESULT = np.zeros((10, 20, 20)) +RESULT[4:8, 4:8, 4:8] = np.array( + [ + [[1.0, 1.0, 2.0, 2.0], [1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0], [3.0, 3.0, 4.0, 4.0]], + [[1.0, 1.0, 2.0, 2.0], [1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0], [3.0, 3.0, 4.0, 4.0]], + [[5.0, 5.0, 6.0, 6.0], [5.0, 5.0, 6.0, 6.0], [7.0, 7.0, 8.0, 8.0], [7.0, 7.0, 8.0, 8.0]], + [[5.0, 5.0, 6.0, 6.0], [5.0, 5.0, 6.0, 6.0], [7.0, 7.0, 8.0, 8.0], [7.0, 7.0, 8.0, 8.0]], + ], +) + +RESTORE_LABEL_TEST_CASE_2 = [ + {"keys": ["pred"], "ref_image": "image", "mode": "nearest"}, + DATA_11, + RESULT, +] + +FETCH_2D_SLICE_TEST_CASE_1 = [ + {"keys": ["image"], "guidance": "guidance"}, + DATA_12, + np.array([[9, 10, 11], [12, 13, 14], [15, 16, 17]]), +] + class TestFindAllValidSlicesd(unittest.TestCase): @parameterized.expand([FIND_SLICE_TEST_CASE_1, FIND_SLICE_TEST_CASE_2]) @@ -220,5 +400,51 @@ def test_correct_results(self, arguments, input_data, expected_result): np.testing.assert_allclose(result[arguments["guidance"]], expected_result, rtol=1e-5) +class TestAddGuidanceFromPointsd(unittest.TestCase): + @parameterized.expand( + [ + ADD_GUIDANCE_FROM_POINTS_TEST_CASE_1, + ADD_GUIDANCE_FROM_POINTS_TEST_CASE_2, + ADD_GUIDANCE_FROM_POINTS_TEST_CASE_3, + ADD_GUIDANCE_FROM_POINTS_TEST_CASE_4, + ADD_GUIDANCE_FROM_POINTS_TEST_CASE_5, + ] + ) + def test_correct_results(self, arguments, input_data, expected_pos, expected_neg): + result = AddGuidanceFromPointsd(**arguments)(input_data) + self.assertEqual(result[arguments["guidance"]][0], expected_pos) + self.assertEqual(result[arguments["guidance"]][1], expected_neg) + + +class TestSpatialCropGuidanced(unittest.TestCase): + @parameterized.expand( + [SPATIAL_CROP_GUIDANCE_TEST_CASE_1, SPATIAL_CROP_GUIDANCE_TEST_CASE_2, SPATIAL_CROP_GUIDANCE_TEST_CASE_3] + ) + def test_correct_results(self, arguments, input_data, expected_result): + result = SpatialCropGuidanced(**arguments)(input_data) + np.testing.assert_allclose(result["image"], expected_result) + + +class TestResizeGuidanced(unittest.TestCase): + @parameterized.expand([RESIZE_GUIDANCE_TEST_CASE_1]) + def test_correct_results(self, arguments, input_data, expected_result): + result = ResizeGuidanced(**arguments)(input_data) + self.assertEqual(result[arguments["guidance"]], expected_result) + + +class TestRestoreLabeld(unittest.TestCase): + @parameterized.expand([RESTORE_LABEL_TEST_CASE_1, RESTORE_LABEL_TEST_CASE_2]) + def test_correct_results(self, arguments, input_data, expected_result): + result = RestoreLabeld(**arguments)(input_data) + np.testing.assert_allclose(result["pred"], expected_result) + + +class TestFetch2DSliced(unittest.TestCase): + @parameterized.expand([FETCH_2D_SLICE_TEST_CASE_1]) + def test_correct_results(self, arguments, input_data, expected_result): + result = Fetch2DSliced(**arguments)(input_data) + np.testing.assert_allclose(result["image"], expected_result) + + if __name__ == "__main__": unittest.main() From c7c91785ab1376169a36908bf769dea12f6dba74 Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Thu, 25 Feb 2021 19:08:54 -0800 Subject: [PATCH 2/3] Remove unused import Signed-off-by: YuanTingHsieh --- monai/apps/deepgrow/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 6d63a857e9..2eb3496f41 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from typing import Callable, Dict, Hashable, Mapping, Optional, Sequence, Union +from typing import Callable, Optional, Sequence, Union import numpy as np import torch From 168eb198e2a7f13d67ec6f152d9a5775e77d519c Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Fri, 26 Feb 2021 15:16:38 -0800 Subject: [PATCH 3/3] Fix review comments Signed-off-by: YuanTingHsieh --- monai/apps/deepgrow/transforms.py | 65 ++++++++++++++++--------------- tests/test_deepgrow_transforms.py | 10 ++++- 2 files changed, 43 insertions(+), 32 deletions(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 4d81b6cf1e..cc01a717ad 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -8,8 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Dict, Optional, Sequence, Union import numpy as np import torch @@ -48,7 +47,7 @@ def _apply(self, label): return np.asarray(sids) def __call__(self, data): - d = dict(data) + d: Dict = dict(data) label = d[self.label] if label.shape[0] != 1: raise ValueError("Only supports single channel labels!") @@ -87,14 +86,23 @@ def __init__( sid: str = "sid", connected_regions: int = 5, ): + super().__init__(prob=1.0, do_transform=True) self.label = label - self.sids = sids - self.sid = sid + self.sids_key = sids + self.sid_key = sid + self.sid = None self.guidance = guidance self.connected_regions = connected_regions - def randomize(self, data=None): - pass + def randomize(self, data): + sid = data.get(self.sid_key, None) + sids = data.get(self.sids_key, None) + if sids is not None: + if sid is None or sid not in sids: + sid = self.R.choice(sids, replace=False) + else: + sid = None + self.sid = sid def _apply(self, label, sid): dimensions = 3 if len(label.shape) > 3 else 2 @@ -135,14 +143,8 @@ def _apply(self, label, sid): def __call__(self, data): d = dict(data) - sid = d.get(self.sid, None) - sids = d.get(self.sids, None) - if sids is not None: - if sid is None or sid not in sids: - sid = self.R.choice(sids, replace=False) - else: - sid = None - d[self.guidance] = self._apply(d[self.label], sid) + self.randomize(data) + d[self.guidance] = self._apply(d[self.label], self.sid) return d @@ -317,6 +319,7 @@ def __init__( probability: str = "probability", batched: bool = True, ): + super().__init__(prob=1.0, do_transform=True) self.guidance = guidance self.discrepancy = discrepancy self.probability = probability @@ -469,8 +472,8 @@ def __call__(self, data): d[self.source_key], self.select_fn, self.channel_indices, self.margin ) - center = np.mean([box_start, box_end], axis=0).astype(int).tolist() - current_size = np.subtract(box_end, box_start).astype(int).tolist() + center = list(np.mean([box_start, box_end], axis=0).astype(int)) + current_size = list(np.subtract(box_end, box_start).astype(int)) if np.all(np.less(current_size, self.spatial_size)): cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size) @@ -515,7 +518,7 @@ class AddGuidanceFromPointsd(Transform): axis: axis that represents slices in 3D volume. (axis to Depth) depth_first: if depth (slices) is positioned at first dimension. dimensions: dimensions based on model used for deepgrow (2D vs 3D). - slice: key that represents applicable slice to add guidance. + slice_key: key that represents applicable slice to add guidance. meta_key_postfix: use `{ref_image}_{postfix}` to to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the @@ -531,7 +534,7 @@ def __init__( axis: int = 0, depth_first: bool = True, dimensions: int = 2, - slice: str = "slice", + slice_key: str = "slice", meta_key_postfix: str = "meta_dict", ): self.ref_image = ref_image @@ -541,7 +544,7 @@ def __init__( self.axis = axis self.depth_first = depth_first self.dimensions = dimensions - self.slice = slice + self.slice = slice_key self.meta_key_postfix = meta_key_postfix def _apply(self, pos_clicks, neg_clicks, factor, slice_num): @@ -552,7 +555,7 @@ def _apply(self, pos_clicks, neg_clicks, factor, slice_num): points.extend(neg_clicks) points = np.array(points) - slices = np.unique(points[:, self.axis]).tolist() + slices = list(np.unique(points[:, self.axis])) slice_idx = slices[0] if slice_num is None else next(x for x in slices if x == slice_num) if len(pos_clicks): @@ -591,11 +594,11 @@ def __call__(self, data): fg_bg_clicks = [] for key in [self.foreground, self.background]: - clicks = json.loads(d[key]) if isinstance(d[key], str) else d[key] - clicks = np.array(clicks).astype(int).tolist() + clicks = d[key] + clicks = list(np.array(clicks).astype(int)) if self.depth_first: for i in range(len(clicks)): - clicks[i] = np.roll(clicks[i], 1).tolist() + clicks[i] = list(np.roll(clicks[i], 1)) fg_bg_clicks.append(clicks) d[self.guidance] = self._apply(fg_bg_clicks[0], fg_bg_clicks[1], factor, d.get(self.slice, None)) return d @@ -669,14 +672,14 @@ def bounding_box(self, points, img_shape): return box_start, box_end def __call__(self, data): - d = dict(data) + d: Dict = dict(data) guidance = d[self.guidance] original_spatial_shape = d[self.keys[0]].shape[1:] box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), original_spatial_shape) - center = np.mean([box_start, box_end], axis=0).astype(int).tolist() + center = list(np.mean([box_start, box_end], axis=0).astype(int)) spatial_size = self.spatial_size - box_size = np.subtract(box_end, box_start).astype(int).tolist() + box_size = list(np.subtract(box_end, box_start).astype(int)) spatial_size = spatial_size[-len(box_size) :] if len(spatial_size) < len(box_size): @@ -746,7 +749,7 @@ def __init__( def __call__(self, data): d = dict(data) guidance = d[self.guidance] - meta_dict = d[f"{self.ref_image}_{self.meta_key_postfix}"] + meta_dict: Dict = d[f"{self.ref_image}_{self.meta_key_postfix}"] current_shape = d[self.ref_image].shape[1:] cropped_shape = meta_dict[self.cropped_shape_key][1:] factor = np.divide(current_shape, cropped_shape) @@ -829,7 +832,7 @@ def __init__( def __call__(self, data): d = dict(data) - meta_dict = d[f"{self.ref_image}_{self.meta_key_postfix}"] + meta_dict: Dict = d[f"{self.ref_image}_{self.meta_key_postfix}"] for idx, key in enumerate(self.keys): image = d[key] @@ -855,7 +858,7 @@ def __call__(self, data): # Undo Spacing current_size = result.shape[1:] # change spatial_shape from HWD to DHW - spatial_shape = np.roll(meta_dict["spatial_shape"], 1).tolist() + spatial_shape = list(np.roll(meta_dict["spatial_shape"], 1)) spatial_size = spatial_shape[-len(current_size) :] if np.any(np.not_equal(current_size, spatial_size)): @@ -868,7 +871,7 @@ def __call__(self, data): final_result = result if len(result.shape) <= 3 else result[0] else: slice_idx = meta_dict["slice_idx"][0] - final_result = np.zeros(spatial_shape) + final_result = np.zeros(tuple(spatial_shape)) final_result[slice_idx] = result d[key] = final_result diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py index 1f7e8a2488..2d57ed9325 100644 --- a/tests/test_deepgrow_transforms.py +++ b/tests/test_deepgrow_transforms.py @@ -263,12 +263,19 @@ ] ADD_GUIDANCE_FROM_POINTS_TEST_CASE_5 = [ - {"ref_image": "image", "dimensions": 2, "guidance": "guidance", "depth_first": True, "slice": "slice"}, + {"ref_image": "image", "dimensions": 2, "guidance": "guidance", "depth_first": True, "slice_key": "slice"}, DATA_7, [[5, 7]], [], ] +ADD_GUIDANCE_FROM_POINTS_TEST_CASE_6 = [ + {"ref_image": "image", "dimensions": 2, "guidance": "guidance", "depth_first": True}, + DATA_5, + [[2, 2]], + [], +] + SPATIAL_CROP_GUIDANCE_TEST_CASE_1 = [ {"keys": ["image"], "guidance": "guidance", "spatial_size": [1, 4, 4], "margin": 0}, DATA_8, @@ -408,6 +415,7 @@ class TestAddGuidanceFromPointsd(unittest.TestCase): ADD_GUIDANCE_FROM_POINTS_TEST_CASE_3, ADD_GUIDANCE_FROM_POINTS_TEST_CASE_4, ADD_GUIDANCE_FROM_POINTS_TEST_CASE_5, + ADD_GUIDANCE_FROM_POINTS_TEST_CASE_6, ] ) def test_correct_results(self, arguments, input_data, expected_pos, expected_neg):