Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,4 +524,15 @@
weighted_patch_samples,
zero_margins,
)
from .utils_pytorch_numpy_unification import clip, in1d, moveaxis, percentile, where
from .utils_pytorch_numpy_unification import (
any_np_pt,
clip,
floor_divide,
in1d,
moveaxis,
nonzero,
percentile,
ravel,
unravel_index,
where,
)
12 changes: 7 additions & 5 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,16 +910,18 @@ def randomize(
image: Optional[np.ndarray] = None,
) -> None:
self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
fg_indices_: np.ndarray
bg_indices_: np.ndarray
if fg_indices is None or bg_indices is None:
if self.fg_indices is not None and self.bg_indices is not None:
fg_indices_ = self.fg_indices
bg_indices_ = self.bg_indices
else:
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) # type: ignore
else:
fg_indices_ = fg_indices
bg_indices_ = bg_indices
self.centers = generate_pos_neg_label_crop_centers(
self.centers = generate_pos_neg_label_crop_centers( # type: ignore
self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R
)

Expand Down Expand Up @@ -1052,15 +1054,15 @@ def randomize(
image: Optional[np.ndarray] = None,
) -> None:
self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
indices_: List[np.ndarray]
indices_: Sequence[np.ndarray]
if indices is None:
if self.indices is not None:
indices_ = self.indices
else:
indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)
indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) # type: ignore
else:
indices_ = indices
self.centers = generate_label_classes_crop_centers(
self.centers = generate_label_classes_crop_centers( # type: ignore
self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R
)

Expand Down
10 changes: 6 additions & 4 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,13 +1100,15 @@ def randomize(
bg_indices: Optional[np.ndarray] = None,
image: Optional[np.ndarray] = None,
) -> None:
fg_indices_: np.ndarray
bg_indices_: np.ndarray
self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
if fg_indices is None or bg_indices is None:
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) # type: ignore
else:
fg_indices_ = fg_indices
bg_indices_ = bg_indices
self.centers = generate_pos_neg_label_crop_centers(
self.centers = generate_pos_neg_label_crop_centers( # type: ignore
self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R
)

Expand Down Expand Up @@ -1283,10 +1285,10 @@ def randomize(
self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
indices_: List[np.ndarray]
if indices is None:
indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)
indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) # type: ignore
else:
indices_ = indices
self.centers = generate_label_classes_crop_centers(
self.centers = generate_label_classes_crop_centers( # type: ignore
self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R
)

Expand Down
7 changes: 5 additions & 2 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,12 +810,14 @@ def __call__(
output_shape: expected shape of output indices. if None, use `self.output_shape` instead.

"""
fg_indices: np.ndarray
bg_indices: np.ndarray
label, *_ = convert_data_type(label, np.ndarray) # type: ignore
if image is not None:
image, *_ = convert_data_type(image, np.ndarray) # type: ignore
if output_shape is None:
output_shape = self.output_shape
fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold)
fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) # type: ignore
if output_shape is not None:
fg_indices = np.stack([np.unravel_index(i, output_shape) for i in fg_indices])
bg_indices = np.stack([np.unravel_index(i, output_shape) for i in bg_indices])
Expand Down Expand Up @@ -868,7 +870,8 @@ def __call__(

if output_shape is None:
output_shape = self.output_shape
indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)
indices: List[np.ndarray]
indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) # type: ignore
if output_shape is not None:
indices = [np.stack([np.unravel_index(i, output_shape) for i in array]) for array in indices]

Expand Down
74 changes: 41 additions & 33 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from monai.networks.layers import GaussianFilter
from monai.transforms.compose import Compose, OneOf
from monai.transforms.transform import MapTransform, Transform
from monai.transforms.utils_pytorch_numpy_unification import any_np_pt, nonzero, ravel, unravel_index
from monai.utils import (
GridSampleMode,
InterpolateMode,
Expand Down Expand Up @@ -261,10 +262,10 @@ def resize_center(img: np.ndarray, *resize_dims: Optional[int], fill_value: floa


def map_binary_to_indices(
label: np.ndarray,
image: Optional[np.ndarray] = None,
label: NdarrayOrTensor,
image: Optional[NdarrayOrTensor] = None,
image_threshold: float = 0.0,
) -> Tuple[np.ndarray, np.ndarray]:
) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]:
"""
Compute the foreground and background of input label data, return the indices after fattening.
For example:
Expand All @@ -277,28 +278,31 @@ def map_binary_to_indices(
to define background. so the output items will not map to all the voxels in the label.
image_threshold: if enabled `image`, use ``image > image_threshold`` to
determine the valid image content area and select background only in this area.

"""

# Prepare fg/bg indices
if label.shape[0] > 1:
label = label[1:] # for One-Hot format data, remove the background channel
label_flat = np.any(label, axis=0).ravel() # in case label has multiple dimensions
fg_indices = np.nonzero(label_flat)[0]
label_flat = ravel(any_np_pt(label, 0)) # in case label has multiple dimensions
fg_indices = nonzero(label_flat)
if image is not None:
img_flat = np.any(image > image_threshold, axis=0).ravel()
bg_indices = np.nonzero(np.logical_and(img_flat, ~label_flat))[0]
img_flat = ravel(any_np_pt(image > image_threshold, 0))
img_flat, *_ = convert_data_type(
img_flat, type(label), device=label.device if isinstance(label, torch.Tensor) else None
)
bg_indices = nonzero(img_flat & ~label_flat)
else:
bg_indices = np.nonzero(~label_flat)[0]
bg_indices = nonzero(~label_flat)

return fg_indices, bg_indices


def map_classes_to_indices(
label: np.ndarray,
label: NdarrayOrTensor,
num_classes: Optional[int] = None,
image: Optional[np.ndarray] = None,
image: Optional[NdarrayOrTensor] = None,
image_threshold: float = 0.0,
) -> List[np.ndarray]:
) -> List[NdarrayOrTensor]:
"""
Filter out indices of every class of the input label data, return the indices after fattening.
It can handle both One-Hot format label and Argmax format label, must provide `num_classes` for
Expand All @@ -318,11 +322,11 @@ def map_classes_to_indices(
determine the valid image content area and select class indices only in this area.

"""
img_flat: Optional[np.ndarray] = None
img_flat: Optional[NdarrayOrTensor] = None
if image is not None:
img_flat = np.any(image > image_threshold, axis=0).ravel()
img_flat = ravel((image > image_threshold).any(0))

indices: List[np.ndarray] = []
indices: List[NdarrayOrTensor] = []
# assuming the first dimension is channel
channels = len(label)

Expand All @@ -333,9 +337,9 @@ def map_classes_to_indices(
num_classes_ = num_classes

for c in range(num_classes_):
label_flat = np.any(label[c : c + 1] if channels > 1 else label == c, axis=0).ravel()
label_flat = np.logical_and(img_flat, label_flat) if img_flat is not None else label_flat
indices.append(np.nonzero(label_flat)[0])
label_flat = ravel(any_np_pt(label[c : c + 1] if channels > 1 else label == c, 0))
label_flat = img_flat & label_flat if img_flat is not None else label_flat
indices.append(nonzero(label_flat))

return indices

Expand Down Expand Up @@ -385,8 +389,10 @@ def weighted_patch_samples(


def correct_crop_centers(
centers: List[np.ndarray], spatial_size: Union[Sequence[int], int], label_spatial_shape: Sequence[int]
) -> List[np.ndarray]:
centers: List[Union[int, torch.Tensor]],
spatial_size: Union[Sequence[int], int],
label_spatial_shape: Sequence[int],
) -> List[int]:
"""
Utility to correct the crop center if the crop size is bigger than the image size.

Expand Down Expand Up @@ -419,18 +425,20 @@ def correct_crop_centers(
center_i = valid_end[i] - 1
centers[i] = center_i

return centers
corrected_centers: List[int] = [c.item() if isinstance(c, torch.Tensor) else c for c in centers] # type: ignore

return corrected_centers


def generate_pos_neg_label_crop_centers(
spatial_size: Union[Sequence[int], int],
num_samples: int,
pos_ratio: float,
label_spatial_shape: Sequence[int],
fg_indices: np.ndarray,
bg_indices: np.ndarray,
fg_indices: NdarrayOrTensor,
bg_indices: NdarrayOrTensor,
rand_state: Optional[np.random.RandomState] = None,
) -> List[List[np.ndarray]]:
) -> List[List[int]]:
"""
Generate valid sample locations based on the label with option for specifying foreground ratio
Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W]
Expand All @@ -453,11 +461,12 @@ def generate_pos_neg_label_crop_centers(
rand_state = np.random.random.__self__ # type: ignore

centers = []
fg_indices, bg_indices = np.asarray(fg_indices), np.asarray(bg_indices)
if fg_indices.size == 0 and bg_indices.size == 0:
fg_indices = np.asarray(fg_indices) if isinstance(fg_indices, Sequence) else fg_indices
bg_indices = np.asarray(bg_indices) if isinstance(bg_indices, Sequence) else bg_indices
if len(fg_indices) == 0 and len(bg_indices) == 0:
raise ValueError("No sampling location available.")

if fg_indices.size == 0 or bg_indices.size == 0:
if len(fg_indices) == 0 or len(bg_indices) == 0:
warnings.warn(
f"N foreground {len(fg_indices)}, N background {len(bg_indices)},"
"unable to generate class balanced samples."
Expand All @@ -467,7 +476,8 @@ def generate_pos_neg_label_crop_centers(
for _ in range(num_samples):
indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices
random_int = rand_state.randint(len(indices_to_use))
center = np.unravel_index(indices_to_use[random_int], label_spatial_shape)
idx = indices_to_use[random_int]
center = unravel_index(idx, label_spatial_shape)
# shift center to range of valid centers
center_ori = list(center)
centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape))
Expand All @@ -479,10 +489,10 @@ def generate_label_classes_crop_centers(
spatial_size: Union[Sequence[int], int],
num_samples: int,
label_spatial_shape: Sequence[int],
indices: List[np.ndarray],
indices: Sequence[NdarrayOrTensor],
ratios: Optional[List[Union[float, int]]] = None,
rand_state: Optional[np.random.RandomState] = None,
) -> List[List[np.ndarray]]:
) -> List[List[int]]:
"""
Generate valid sample locations based on the specified ratios of label classes.
Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W]
Expand All @@ -508,8 +518,6 @@ def generate_label_classes_crop_centers(
if any(i < 0 for i in ratios_):
raise ValueError("ratios should not contain negative number.")

# ensure indices are numpy array
indices = [np.asarray(i) for i in indices]
for i, array in enumerate(indices):
if len(array) == 0:
warnings.warn(f"no available indices of class {i} to crop, set the crop ratio of this class to zero.")
Expand All @@ -521,7 +529,7 @@ def generate_label_classes_crop_centers(
# randomly select the indices of a class based on the ratios
indices_to_use = indices[i]
random_int = rand_state.randint(len(indices_to_use))
center = np.unravel_index(indices_to_use[random_int], label_spatial_shape)
center = unravel_index(indices_to_use[random_int], label_spatial_shape)
# shift center to range of valid centers
center_ori = list(center)
centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape))
Expand Down
Loading