Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions monai/apps/deepgrow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ class SpatialCropForegroundd(MapTransform):
end_coord_key: key to record the end coordinate of spatial bounding box for foreground.
original_shape_key: key to record original shape for foreground.
cropped_shape_key: key to record cropped shape for foreground.
allow_missing_keys: don't raise exception if key is missing.
"""

def __init__(
Expand All @@ -452,8 +453,9 @@ def __init__(
end_coord_key: str = "foreground_end_coord",
original_shape_key: str = "foreground_original_shape",
cropped_shape_key: str = "foreground_cropped_shape",
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys)
super().__init__(keys, allow_missing_keys)

self.source_key = source_key
self.spatial_size = list(spatial_size)
Expand Down Expand Up @@ -482,7 +484,7 @@ def __call__(self, data):
else:
cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)

for key in self.keys:
for key in self.key_iterator(d):
meta_key = f"{key}_{self.meta_key_postfix}"
d[meta_key][self.start_coord_key] = box_start
d[meta_key][self.end_coord_key] = box_end
Expand Down Expand Up @@ -629,6 +631,7 @@ class SpatialCropGuidanced(MapTransform):
end_coord_key: key to record the end coordinate of spatial bounding box for foreground.
original_shape_key: key to record original shape for foreground.
cropped_shape_key: key to record cropped shape for foreground.
allow_missing_keys: don't raise exception if key is missing.
"""

def __init__(
Expand All @@ -642,8 +645,9 @@ def __init__(
end_coord_key: str = "foreground_end_coord",
original_shape_key: str = "foreground_original_shape",
cropped_shape_key: str = "foreground_cropped_shape",
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys)
super().__init__(keys, allow_missing_keys)

self.guidance = guidance
self.spatial_size = list(spatial_size)
Expand Down Expand Up @@ -697,7 +701,7 @@ def __call__(self, data):
cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
box_start, box_end = cropper.roi_start, cropper.roi_end

for key in self.keys:
for key in self.key_iterator(d):
if not np.array_equal(d[key].shape[1:], original_spatial_shape):
raise RuntimeError("All the image specified in keys should have same spatial shape")
meta_key = f"{key}_{self.meta_key_postfix}"
Expand Down Expand Up @@ -804,6 +808,7 @@ class RestoreLabeld(MapTransform):
end_coord_key: key that records the end coordinate of spatial bounding box for foreground.
original_shape_key: key that records original shape for foreground.
cropped_shape_key: key that records cropped shape for foreground.
allow_missing_keys: don't raise exception if key is missing.
"""

def __init__(
Expand All @@ -818,8 +823,9 @@ def __init__(
end_coord_key: str = "foreground_end_coord",
original_shape_key: str = "foreground_original_shape",
cropped_shape_key: str = "foreground_cropped_shape",
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys)
super().__init__(keys, allow_missing_keys)
self.ref_image = ref_image
self.slice_only = slice_only
self.mode = ensure_tuple_rep(mode, len(self.keys))
Expand All @@ -834,15 +840,15 @@ def __call__(self, data):
d = dict(data)
meta_dict: Dict = d[f"{self.ref_image}_{self.meta_key_postfix}"]

for idx, key in enumerate(self.keys):
for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners):
image = d[key]

# Undo Resize
current_shape = image.shape
cropped_shape = meta_dict[self.cropped_shape_key]
if np.any(np.not_equal(current_shape, cropped_shape)):
resizer = Resize(spatial_size=cropped_shape[1:], mode=self.mode[idx])
image = resizer(image, mode=self.mode[idx], align_corners=self.align_corners[idx])
resizer = Resize(spatial_size=cropped_shape[1:], mode=mode)
image = resizer(image, mode=mode, align_corners=align_corners)

# Undo Crop
original_shape = meta_dict[self.original_shape_key]
Expand All @@ -862,8 +868,8 @@ def __call__(self, data):
spatial_size = spatial_shape[-len(current_size) :]

if np.any(np.not_equal(current_size, spatial_size)):
resizer = Resize(spatial_size=spatial_size, mode=self.mode[idx])
result = resizer(result, mode=self.mode[idx], align_corners=self.align_corners[idx])
resizer = Resize(spatial_size=spatial_size, mode=mode)
result = resizer(result, mode=mode, align_corners=align_corners)

# Undo Slicing
slice_idx = meta_dict.get("slice_idx")
Expand Down Expand Up @@ -898,10 +904,18 @@ class Fetch2DSliced(MapTransform):
default is `meta_dict`, the meta data is a dictionary object.
For example, to handle key `image`, read/write affine matrices from the
metadata `image_meta_dict` dictionary's `affine` field.
allow_missing_keys: don't raise exception if key is missing.
"""

def __init__(self, keys, guidance="guidance", axis: int = 0, meta_key_postfix: str = "meta_dict"):
super().__init__(keys)
def __init__(
self,
keys,
guidance="guidance",
axis: int = 0,
meta_key_postfix: str = "meta_dict",
allow_missing_keys: bool = False,
):
super().__init__(keys, allow_missing_keys)
self.guidance = guidance
self.axis = axis
self.meta_key_postfix = meta_key_postfix
Expand All @@ -920,7 +934,7 @@ def __call__(self, data):
guidance = d[self.guidance]
if len(guidance) < 3:
raise RuntimeError("Guidance does not container slice_idx!")
for key in self.keys:
for key in self.key_iterator(d):
img_slice, idx = self._apply(d[key], guidance)
d[key] = img_slice
d[f"{key}_{self.meta_key_postfix}"]["slice_idx"] = idx
Expand Down
Loading