From 0ec917c718fd04e431b6c46aa554ef7b5ba1140e Mon Sep 17 00:00:00 2001 From: Keno <37253540+kbressem@users.noreply.github.com> Date: Wed, 8 Mar 2023 20:42:57 -0500 Subject: [PATCH 1/4] add `warn` flag to RandCropByLabelClasses Signed-off-by: kbressem --- monai/transforms/croppad/array.py | 5 ++++- monai/transforms/utils.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index c96598c51f..2af6a5667f 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -1109,6 +1109,7 @@ class RandCropByLabelClasses(Randomizable, TraceableTransform, MultiSampleTrait) allow_smaller: if `False`, an exception will be raised if the image is smaller than the requested ROI in any dimension. If `True`, any smaller dimensions will remain unchanged. + warn: if `True` prints a warning if a class is not present in the label. """ @@ -1125,6 +1126,7 @@ def __init__( image_threshold: float = 0.0, indices: list[NdarrayOrTensor] | None = None, allow_smaller: bool = False, + warn: bool = True ) -> None: self.spatial_size = spatial_size self.ratios = ratios @@ -1136,6 +1138,7 @@ def __init__( self.centers: list[list[int]] | None = None self.indices = indices self.allow_smaller = allow_smaller + self.warn = warn def randomize( self, label: torch.Tensor, indices: list[NdarrayOrTensor] | None = None, image: torch.Tensor | None = None @@ -1149,7 +1152,7 @@ def randomize( else: indices_ = indices self.centers = generate_label_classes_crop_centers( - self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R, self.allow_smaller + self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R, self.allow_smaller, self.warn ) def __call__( diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index d3c8eb606f..9a852da298 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -536,6 +536,7 @@ def generate_label_classes_crop_centers( ratios: list[float | int] | None = None, rand_state: np.random.RandomState | None = None, allow_smaller: bool = False, + warn: bool = True ) -> list[list[int]]: """ Generate valid sample locations based on the specified ratios of label classes. @@ -552,6 +553,7 @@ def generate_label_classes_crop_centers( allow_smaller: if `False`, an exception will be raised if the image is smaller than the requested ROI in any dimension. If `True`, any smaller dimensions will be set to match the cropped size (i.e., no cropping in that dimension). + warn: if `True` prints a warning if a class is not present in the label. """ if rand_state is None: @@ -568,7 +570,7 @@ def generate_label_classes_crop_centers( raise ValueError(f"ratios should not contain negative number, got {ratios_}.") for i, array in enumerate(indices): - if len(array) == 0: + if len(array) == 0 and warn: warnings.warn(f"no available indices of class {i} to crop, set the crop ratio of this class to zero.") ratios_[i] = 0 From 31dd63e8af74e52aff75196853bb6323422a666d Mon Sep 17 00:00:00 2001 From: kbressem Date: Wed, 8 Mar 2023 21:43:46 -0500 Subject: [PATCH 2/4] fix codestyle Signed-off-by: kbressem --- monai/transforms/croppad/array.py | 11 +++++++++-- monai/transforms/utils.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index af1380e690..b96afd2be3 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -1164,7 +1164,7 @@ def __init__( image_threshold: float = 0.0, indices: list[NdarrayOrTensor] | None = None, allow_smaller: bool = False, - warn: bool = True + warn: bool = True, ) -> None: self.spatial_size = spatial_size self.ratios = ratios @@ -1201,7 +1201,14 @@ def randomize( if _shape is None: raise ValueError("label or image must be provided to infer the output spatial shape.") self.centers = generate_label_classes_crop_centers( - self.spatial_size, self.num_samples, _shape, indices_, self.ratios, self.R, self.allow_smaller, self.warn + self.spatial_size, + self.num_samples, + label.shape[1:], + indices_, + self.ratios, + self.R, + self.allow_smaller, + self.warn, ) @LazyTransform.lazy_evaluation.setter # type: ignore diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index a083b06274..ea3577159d 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -535,7 +535,7 @@ def generate_label_classes_crop_centers( ratios: list[float | int] | None = None, rand_state: np.random.RandomState | None = None, allow_smaller: bool = False, - warn: bool = True + warn: bool = True, ) -> list[list[int]]: """ Generate valid sample locations based on the specified ratios of label classes. From d76035b6356bc3c215fa85a3b095cf88a55a09ad Mon Sep 17 00:00:00 2001 From: kbressem Date: Wed, 8 Mar 2023 22:02:33 -0500 Subject: [PATCH 3/4] put back _shape arguemnt to generate_label_classes_crop_centers Signed-off-by: kbressem --- monai/transforms/croppad/array.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index b96afd2be3..318110848b 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -1201,14 +1201,7 @@ def randomize( if _shape is None: raise ValueError("label or image must be provided to infer the output spatial shape.") self.centers = generate_label_classes_crop_centers( - self.spatial_size, - self.num_samples, - label.shape[1:], - indices_, - self.ratios, - self.R, - self.allow_smaller, - self.warn, + self.spatial_size, self.num_samples, _shape, indices_, self.ratios, self.R, self.allow_smaller, self.warn ) @LazyTransform.lazy_evaluation.setter # type: ignore From d92cd3ad17979ac66383c21ba46ed5b2e62619d8 Mon Sep 17 00:00:00 2001 From: kbressem Date: Thu, 9 Mar 2023 08:00:35 -0500 Subject: [PATCH 4/4] update dict transform Signed-off-by: kbressem --- monai/transforms/croppad/dictionary.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 624e95e9b3..1aa710d018 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -961,6 +961,8 @@ class RandCropByLabelClassesd(Randomizable, MapTransform, LazyTransform, MultiSa the requested ROI in any dimension. If `True`, any smaller dimensions will remain unchanged. allow_missing_keys: don't raise exception if key is missing. + warn: if `True` prints a warning if a class is not present in the label. + """ @@ -979,6 +981,7 @@ def __init__( indices_key: str | None = None, allow_smaller: bool = False, allow_missing_keys: bool = False, + warn: bool = True, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) self.label_key = label_key @@ -991,6 +994,7 @@ def __init__( num_samples=num_samples, image_threshold=image_threshold, allow_smaller=allow_smaller, + warn=warn, ) def set_random_state(