90 changes: 42 additions & 48 deletions monai/transforms/croppad/array.py
@@ -884,17 +884,19 @@ class RandCropByPosNegLabel(Randomizable, Transform):

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
spatial_size: Union[Sequence[int], int],
label: Optional[np.ndarray] = None,
label: Optional[NdarrayOrTensor] = None,
pos: float = 1.0,
neg: float = 1.0,
num_samples: int = 1,
image: Optional[np.ndarray] = None,
image: Optional[NdarrayOrTensor] = None,
image_threshold: float = 0.0,
fg_indices: Optional[np.ndarray] = None,
bg_indices: Optional[np.ndarray] = None,
fg_indices: Optional[NdarrayOrTensor] = None,
bg_indices: Optional[NdarrayOrTensor] = None,
) -> None:
self.spatial_size = ensure_tuple(spatial_size)
self.label = label
@@ -906,41 +908,39 @@ def __init__(
self.num_samples = num_samples
self.image = image
self.image_threshold = image_threshold
self.centers: Optional[List[List[np.ndarray]]] = None
self.centers: Optional[List[List[int]]] = None
self.fg_indices = fg_indices
self.bg_indices = bg_indices

def randomize(
self,
label: np.ndarray,
fg_indices: Optional[np.ndarray] = None,
bg_indices: Optional[np.ndarray] = None,
image: Optional[np.ndarray] = None,
label: NdarrayOrTensor,
fg_indices: Optional[NdarrayOrTensor] = None,
bg_indices: Optional[NdarrayOrTensor] = None,
image: Optional[NdarrayOrTensor] = None,
) -> None:
self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
fg_indices_: np.ndarray
bg_indices_: np.ndarray
if fg_indices is None or bg_indices is None:
if self.fg_indices is not None and self.bg_indices is not None:
fg_indices_ = self.fg_indices
bg_indices_ = self.bg_indices
else:
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) # type: ignore
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)
else:
fg_indices_ = fg_indices
bg_indices_ = bg_indices
self.centers = generate_pos_neg_label_crop_centers( # type: ignore
self.centers = generate_pos_neg_label_crop_centers(
self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R
)

def __call__(
self,
img: np.ndarray,
label: Optional[np.ndarray] = None,
image: Optional[np.ndarray] = None,
fg_indices: Optional[np.ndarray] = None,
bg_indices: Optional[np.ndarray] = None,
) -> List[np.ndarray]:
img: NdarrayOrTensor,
label: Optional[NdarrayOrTensor] = None,
image: Optional[NdarrayOrTensor] = None,
fg_indices: Optional[NdarrayOrTensor] = None,
bg_indices: Optional[NdarrayOrTensor] = None,
) -> List[NdarrayOrTensor]:
"""
Args:
img: input data to crop samples from based on the pos/neg ratio of `label` and `image`.
@@ -962,16 +962,12 @@ def __call__(
if image is None:
image = self.image

image, *_ = convert_data_type(image, np.ndarray) # type: ignore
label, *_ = convert_data_type(label, np.ndarray) # type: ignore

self.randomize(label, fg_indices, bg_indices, image)
results: List[np.ndarray] = []
results: List[NdarrayOrTensor] = []
if self.centers is not None:
for center in self.centers:
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore
cropped: np.ndarray = cropper(img) # type: ignore
results.append(cropped)
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)
results.append(cropper(img))

return results

@@ -1035,16 +1031,18 @@ class RandCropByLabelClasses(Randomizable, Transform):

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
spatial_size: Union[Sequence[int], int],
ratios: Optional[List[Union[float, int]]] = None,
label: Optional[np.ndarray] = None,
label: Optional[NdarrayOrTensor] = None,
num_classes: Optional[int] = None,
num_samples: int = 1,
image: Optional[np.ndarray] = None,
image: Optional[NdarrayOrTensor] = None,
image_threshold: float = 0.0,
indices: Optional[List[np.ndarray]] = None,
indices: Optional[List[NdarrayOrTensor]] = None,
) -> None:
self.spatial_size = ensure_tuple(spatial_size)
self.ratios = ratios
@@ -1053,35 +1051,35 @@ def __init__(
self.num_samples = num_samples
self.image = image
self.image_threshold = image_threshold
self.centers: Optional[List[List[np.ndarray]]] = None
self.centers: Optional[List[List[int]]] = None
self.indices = indices

def randomize(
self,
label: np.ndarray,
indices: Optional[List[np.ndarray]] = None,
image: Optional[np.ndarray] = None,
label: NdarrayOrTensor,
indices: Optional[List[NdarrayOrTensor]] = None,
image: Optional[NdarrayOrTensor] = None,
) -> None:
self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
indices_: Sequence[np.ndarray]
indices_: Sequence[NdarrayOrTensor]
if indices is None:
if self.indices is not None:
indices_ = self.indices
else:
indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) # type: ignore
indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)
else:
indices_ = indices
self.centers = generate_label_classes_crop_centers( # type: ignore
self.centers = generate_label_classes_crop_centers(
self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R
)

def __call__(
self,
img: np.ndarray,
label: Optional[np.ndarray] = None,
image: Optional[np.ndarray] = None,
indices: Optional[List[np.ndarray]] = None,
) -> List[np.ndarray]:
img: NdarrayOrTensor,
label: Optional[NdarrayOrTensor] = None,
image: Optional[NdarrayOrTensor] = None,
indices: Optional[List[NdarrayOrTensor]] = None,
) -> List[NdarrayOrTensor]:
"""
Args:
img: input data to crop samples from based on the ratios of every class, assumes `img` is a
@@ -1099,16 +1097,12 @@ def __call__(
if image is None:
image = self.image

image, *_ = convert_data_type(image, np.ndarray) # type: ignore
label, *_ = convert_data_type(label, np.ndarray) # type: ignore

self.randomize(label, indices, image)
results: List[np.ndarray] = []
results: List[NdarrayOrTensor] = []
if self.centers is not None:
for center in self.centers:
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore
cropped: np.ndarray = cropper(img) # type: ignore
results.append(cropped)
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)
results.append(cropper(img))

return results
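
For reference, a minimal usage sketch (not part of this diff; shapes and the label threshold are illustrative): with TransformBackends.TORCH now declared, RandCropByPosNegLabel can be fed torch tensors directly, without the previous convert_data_type round trip to np.ndarray.

import torch
from monai.transforms import RandCropByPosNegLabel

# illustrative channel-first volume and a sparse binary label mask
img = torch.rand(1, 64, 64, 64)
label = (torch.rand(1, 64, 64, 64) > 0.9).long()

cropper = RandCropByPosNegLabel(spatial_size=(32, 32, 32), pos=1.0, neg=1.0, num_samples=4)
samples = cropper(img, label=label)  # list of 4 crops, each of shape (1, 32, 32, 32)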

57 changes: 29 additions & 28 deletions monai/transforms/croppad/dictionary.py
@@ -33,6 +33,8 @@
CenterSpatialCrop,
CropForeground,
DivisiblePad,
RandCropByLabelClasses,
RandCropByPosNegLabel,
ResizeWithPadOrCrop,
SpatialCrop,
SpatialPad,
@@ -1061,6 +1063,8 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform):

"""

backend = RandCropByPosNegLabel.backend

def __init__(
self,
keys: KeysCollection,
@@ -1094,28 +1098,26 @@ def __init__(
if len(self.keys) != len(self.meta_keys):
raise ValueError("meta_keys should have the same length as keys.")
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
self.centers: Optional[List[List[np.ndarray]]] = None
self.centers: Optional[List[List[int]]] = None

def randomize(
self,
label: np.ndarray,
fg_indices: Optional[np.ndarray] = None,
bg_indices: Optional[np.ndarray] = None,
image: Optional[np.ndarray] = None,
label: NdarrayOrTensor,
fg_indices: Optional[NdarrayOrTensor] = None,
bg_indices: Optional[NdarrayOrTensor] = None,
image: Optional[NdarrayOrTensor] = None,
) -> None:
fg_indices_: np.ndarray
bg_indices_: np.ndarray
self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
if fg_indices is None or bg_indices is None:
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) # type: ignore
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)
else:
fg_indices_ = fg_indices
bg_indices_ = bg_indices
self.centers = generate_pos_neg_label_crop_centers( # type: ignore
self.centers = generate_pos_neg_label_crop_centers(
self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R
)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]:
d = dict(data)
label = d[self.label_key]
image = d[self.image_key] if self.image_key else None
@@ -1129,25 +1131,24 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
raise ValueError("no available ROI centers to crop.")

# initialize returned list with shallow copy to preserve key ordering
results: List[Dict[Hashable, np.ndarray]] = [dict(d) for _ in range(self.num_samples)]
results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(self.num_samples)]

for i, center in enumerate(self.centers):
# fill in the extra keys with unmodified data
for key in set(d.keys()).difference(set(self.keys)):
results[i][key] = deepcopy(d[key])
for key in self.key_iterator(d):
img = d[key]
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)
orig_size = img.shape[1:]
cropped: np.ndarray = cropper(img) # type: ignore
results[i][key] = cropped
results[i][key] = cropper(img)
self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size)
# add `patch_index` to the meta data
for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
meta_key = meta_key or f"{key}_{meta_key_postfix}"
if meta_key not in results[i]:
results[i][meta_key] = {} # type: ignore
results[i][meta_key][Key.PATCH_INDEX] = i
results[i][meta_key][Key.PATCH_INDEX] = i # type: ignore

return results

@@ -1250,6 +1251,8 @@ class RandCropByLabelClassesd(Randomizable, MapTransform, InvertibleTransform):

"""

backend = RandCropByLabelClasses.backend

def __init__(
self,
keys: KeysCollection,
@@ -1278,25 +1281,24 @@ def __init__(
if len(self.keys) != len(self.meta_keys):
raise ValueError("meta_keys should have the same length as keys.")
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
self.centers: Optional[List[List[np.ndarray]]] = None
self.centers: Optional[List[List[int]]] = None

def randomize(
self,
label: np.ndarray,
indices: Optional[List[np.ndarray]] = None,
image: Optional[np.ndarray] = None,
label: NdarrayOrTensor,
indices: Optional[List[NdarrayOrTensor]] = None,
image: Optional[NdarrayOrTensor] = None,
) -> None:
self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
indices_: List[np.ndarray]
if indices is None:
indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) # type: ignore
indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)
else:
indices_ = indices
self.centers = generate_label_classes_crop_centers( # type: ignore
self.centers = generate_label_classes_crop_centers(
self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R
)

def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarray]]:
def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, NdarrayOrTensor]]:
d = dict(data)
label = d[self.label_key]
image = d[self.image_key] if self.image_key else None
@@ -1309,25 +1311,24 @@ def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarr
raise ValueError("no available ROI centers to crop.")

# initialize returned list with shallow copy to preserve key ordering
results: List[Dict[Hashable, np.ndarray]] = [dict(d) for _ in range(self.num_samples)]
results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(self.num_samples)]

for i, center in enumerate(self.centers):
# fill in the extra keys with unmodified data
for key in set(d.keys()).difference(set(self.keys)):
results[i][key] = deepcopy(d[key])
for key in self.key_iterator(d):
img = d[key]
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)
orig_size = img.shape[1:]
cropped: np.ndarray = cropper(img) # type: ignore
results[i][key] = cropped
results[i][key] = cropper(img)
self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size)
# add `patch_index` to the meta data
for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
meta_key = meta_key or f"{key}_{meta_key_postfix}"
if meta_key not in results[i]:
results[i][meta_key] = {} # type: ignore
results[i][meta_key][Key.PATCH_INDEX] = i
results[i][meta_key][Key.PATCH_INDEX] = i # type: ignore

return results
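
Similarly, a minimal sketch (illustrative keys and shapes, not part of this diff) of the dictionary-based wrapper after this change, with torch tensors flowing through unconverted:

import torch
from monai.transforms import RandCropByPosNegLabeld

data = {
    "image": torch.rand(1, 64, 64, 64),
    "label": (torch.rand(1, 64, 64, 64) > 0.9).long(),
}
cropper = RandCropByPosNegLabeld(
    keys=["image", "label"], label_key="label", spatial_size=(32, 32, 32), num_samples=2
)
samples = cropper(data)  # list of 2 dicts, each holding cropped "image"/"label" entries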
