diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py
index 0e6529703f..e1c915cc93 100644
--- a/monai/transforms/croppad/array.py
+++ b/monai/transforms/croppad/array.py
@@ -884,17 +884,19 @@ class RandCropByPosNegLabel(Randomizable, Transform):
 
     """
 
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
     def __init__(
         self,
         spatial_size: Union[Sequence[int], int],
-        label: Optional[np.ndarray] = None,
+        label: Optional[NdarrayOrTensor] = None,
         pos: float = 1.0,
         neg: float = 1.0,
         num_samples: int = 1,
-        image: Optional[np.ndarray] = None,
+        image: Optional[NdarrayOrTensor] = None,
         image_threshold: float = 0.0,
-        fg_indices: Optional[np.ndarray] = None,
-        bg_indices: Optional[np.ndarray] = None,
+        fg_indices: Optional[NdarrayOrTensor] = None,
+        bg_indices: Optional[NdarrayOrTensor] = None,
     ) -> None:
         self.spatial_size = ensure_tuple(spatial_size)
         self.label = label
@@ -906,41 +908,39 @@ def __init__(
         self.num_samples = num_samples
         self.image = image
         self.image_threshold = image_threshold
-        self.centers: Optional[List[List[np.ndarray]]] = None
+        self.centers: Optional[List[List[int]]] = None
         self.fg_indices = fg_indices
         self.bg_indices = bg_indices
 
     def randomize(
         self,
-        label: np.ndarray,
-        fg_indices: Optional[np.ndarray] = None,
-        bg_indices: Optional[np.ndarray] = None,
-        image: Optional[np.ndarray] = None,
+        label: NdarrayOrTensor,
+        fg_indices: Optional[NdarrayOrTensor] = None,
+        bg_indices: Optional[NdarrayOrTensor] = None,
+        image: Optional[NdarrayOrTensor] = None,
     ) -> None:
         self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
-        fg_indices_: np.ndarray
-        bg_indices_: np.ndarray
         if fg_indices is None or bg_indices is None:
             if self.fg_indices is not None and self.bg_indices is not None:
                 fg_indices_ = self.fg_indices
                 bg_indices_ = self.bg_indices
             else:
-                fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)  # type: ignore
+                fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)
         else:
             fg_indices_ = fg_indices
             bg_indices_ = bg_indices
-        self.centers = generate_pos_neg_label_crop_centers(  # type: ignore
+        self.centers = generate_pos_neg_label_crop_centers(
             self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R
         )
 
     def __call__(
         self,
-        img: np.ndarray,
-        label: Optional[np.ndarray] = None,
-        image: Optional[np.ndarray] = None,
-        fg_indices: Optional[np.ndarray] = None,
-        bg_indices: Optional[np.ndarray] = None,
-    ) -> List[np.ndarray]:
+        img: NdarrayOrTensor,
+        label: Optional[NdarrayOrTensor] = None,
+        image: Optional[NdarrayOrTensor] = None,
+        fg_indices: Optional[NdarrayOrTensor] = None,
+        bg_indices: Optional[NdarrayOrTensor] = None,
+    ) -> List[NdarrayOrTensor]:
         """
         Args:
             img: input data to crop samples from based on the pos/neg ratio of `label` and `image`.
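The hunks above declare the supported backends and relax `RandCropByPosNegLabel` from `np.ndarray` to `NdarrayOrTensor`; the `__call__` body is updated in the next hunk. A minimal usage sketch, not part of the patch, with illustrative shapes and assuming a MONAI build that includes this change:

```python
import torch

from monai.transforms import RandCropByPosNegLabel

# channel-first volume and a binary label map, both torch tensors
img = torch.rand(1, 32, 32, 32)
label = (torch.rand(1, 32, 32, 32) > 0.7).long()

cropper = RandCropByPosNegLabel(
    spatial_size=(16, 16, 16), label=label, pos=1, neg=1, num_samples=4
)
cropper.set_random_state(0)  # reproducible crop centers

patches = cropper(img)  # a list of `num_samples` crops
print(len(patches), patches[0].shape)  # 4 torch.Size([1, 16, 16, 16])
```

The same call works unchanged with numpy inputs, and the crops keep the container type of `img`; with `self.centers` now a plain `List[List[int]]`, the `convert_data_type` round trip removed in the next hunk is no longer needed.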
@@ -962,16 +962,12 @@ def __call__(
         if image is None:
             image = self.image
-        image, *_ = convert_data_type(image, np.ndarray)  # type: ignore
-        label, *_ = convert_data_type(label, np.ndarray)  # type: ignore
-
         self.randomize(label, fg_indices, bg_indices, image)
-        results: List[np.ndarray] = []
+        results: List[NdarrayOrTensor] = []
         if self.centers is not None:
             for center in self.centers:
-                cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
-                cropped: np.ndarray = cropper(img)  # type: ignore
-                results.append(cropped)
+                cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)
+                results.append(cropper(img))
 
         return results
 
@@ -1035,16 +1031,18 @@ class RandCropByLabelClasses(Randomizable, Transform):
 
     """
 
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
     def __init__(
         self,
         spatial_size: Union[Sequence[int], int],
         ratios: Optional[List[Union[float, int]]] = None,
-        label: Optional[np.ndarray] = None,
+        label: Optional[NdarrayOrTensor] = None,
        num_classes: Optional[int] = None,
        num_samples: int = 1,
-        image: Optional[np.ndarray] = None,
+        image: Optional[NdarrayOrTensor] = None,
        image_threshold: float = 0.0,
-        indices: Optional[List[np.ndarray]] = None,
+        indices: Optional[List[NdarrayOrTensor]] = None,
     ) -> None:
         self.spatial_size = ensure_tuple(spatial_size)
         self.ratios = ratios
@@ -1053,35 +1051,35 @@ def __init__(
         self.num_samples = num_samples
         self.image = image
         self.image_threshold = image_threshold
-        self.centers: Optional[List[List[np.ndarray]]] = None
+        self.centers: Optional[List[List[int]]] = None
         self.indices = indices
 
     def randomize(
         self,
-        label: np.ndarray,
-        indices: Optional[List[np.ndarray]] = None,
-        image: Optional[np.ndarray] = None,
+        label: NdarrayOrTensor,
+        indices: Optional[List[NdarrayOrTensor]] = None,
+        image: Optional[NdarrayOrTensor] = None,
     ) -> None:
         self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
-        indices_: Sequence[np.ndarray]
+        indices_: Sequence[NdarrayOrTensor]
         if indices is None:
             if self.indices is not None:
                 indices_ = self.indices
             else:
-                indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)  # type: ignore
+                indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)
         else:
             indices_ = indices
-        self.centers = generate_label_classes_crop_centers(  # type: ignore
+        self.centers = generate_label_classes_crop_centers(
             self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R
         )
 
     def __call__(
         self,
-        img: np.ndarray,
-        label: Optional[np.ndarray] = None,
-        image: Optional[np.ndarray] = None,
-        indices: Optional[List[np.ndarray]] = None,
-    ) -> List[np.ndarray]:
+        img: NdarrayOrTensor,
+        label: Optional[NdarrayOrTensor] = None,
+        image: Optional[NdarrayOrTensor] = None,
+        indices: Optional[List[NdarrayOrTensor]] = None,
+    ) -> List[NdarrayOrTensor]:
         """
         Args:
             img: input data to crop samples from based on the ratios of every class, assumes `img` is a
@@ -1099,16 +1097,12 @@ def __call__(
         if image is None:
             image = self.image
-        image, *_ = convert_data_type(image, np.ndarray)  # type: ignore
-        label, *_ = convert_data_type(label, np.ndarray)  # type: ignore
-
         self.randomize(label, indices, image)
-        results: List[np.ndarray] = []
+        results: List[NdarrayOrTensor] = []
         if self.centers is not None:
             for center in self.centers:
-                cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
-                cropped: np.ndarray = cropper(img)  # type: ignore
-                results.append(cropped)
+                cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)
+                results.append(cropper(img))
 
         return results
 
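That completes the array-transform changes: both random-crop transforms now accept numpy or torch data. A short sketch of `RandCropByLabelClasses` with torch inputs and precomputed indices (illustrative only, assuming the updated signatures above):

```python
import torch

from monai.transforms import ClassesToIndices, RandCropByLabelClasses

img = torch.rand(1, 32, 32, 32)
label = torch.randint(0, 3, (1, 32, 32, 32))  # "argmax"-style label with 3 classes

# indices can be computed once (e.g. by a dataset) and passed in at call time
indices = ClassesToIndices(num_classes=3)(label)

cropper = RandCropByLabelClasses(
    spatial_size=(16, 16, 16), ratios=[1, 2, 3], label=label, num_classes=3, num_samples=4
)
cropper.set_random_state(0)
patches = cropper(img, indices=indices)  # 4 crops, torch in -> torch out
```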
diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py
index 2d50ba0b34..488b832450 100644
--- a/monai/transforms/croppad/dictionary.py
+++ b/monai/transforms/croppad/dictionary.py
@@ -33,6 +33,8 @@
     CenterSpatialCrop,
     CropForeground,
     DivisiblePad,
+    RandCropByLabelClasses,
+    RandCropByPosNegLabel,
     ResizeWithPadOrCrop,
     SpatialCrop,
     SpatialPad,
@@ -1061,6 +1063,8 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform):
 
     """
 
+    backend = RandCropByPosNegLabel.backend
+
     def __init__(
         self,
         keys: KeysCollection,
@@ -1094,28 +1098,26 @@ def __init__(
         if len(self.keys) != len(self.meta_keys):
             raise ValueError("meta_keys should have the same length as keys.")
         self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
-        self.centers: Optional[List[List[np.ndarray]]] = None
+        self.centers: Optional[List[List[int]]] = None
 
     def randomize(
         self,
-        label: np.ndarray,
-        fg_indices: Optional[np.ndarray] = None,
-        bg_indices: Optional[np.ndarray] = None,
-        image: Optional[np.ndarray] = None,
+        label: NdarrayOrTensor,
+        fg_indices: Optional[NdarrayOrTensor] = None,
+        bg_indices: Optional[NdarrayOrTensor] = None,
+        image: Optional[NdarrayOrTensor] = None,
     ) -> None:
-        fg_indices_: np.ndarray
-        bg_indices_: np.ndarray
         self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
         if fg_indices is None or bg_indices is None:
-            fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)  # type: ignore
+            fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)
         else:
             fg_indices_ = fg_indices
             bg_indices_ = bg_indices
-        self.centers = generate_pos_neg_label_crop_centers(  # type: ignore
+        self.centers = generate_pos_neg_label_crop_centers(
             self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R
         )
 
-    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]:
         d = dict(data)
         label = d[self.label_key]
         image = d[self.image_key] if self.image_key else None
@@ -1129,7 +1131,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
             raise ValueError("no available ROI centers to crop.")
 
         # initialize returned list with shallow copy to preserve key ordering
-        results: List[Dict[Hashable, np.ndarray]] = [dict(d) for _ in range(self.num_samples)]
+        results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(self.num_samples)]
 
         for i, center in enumerate(self.centers):
             # fill in the extra keys with unmodified data
@@ -1137,17 +1139,16 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
                 results[i][key] = deepcopy(d[key])
             for key in self.key_iterator(d):
                 img = d[key]
-                cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
+                cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)
                 orig_size = img.shape[1:]
-                cropped: np.ndarray = cropper(img)  # type: ignore
-                results[i][key] = cropped
+                results[i][key] = cropper(img)
                 self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size)
             # add `patch_index` to the meta data
             for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
                 meta_key = meta_key or f"{key}_{meta_key_postfix}"
                 if meta_key not in results[i]:
                     results[i][meta_key] = {}  # type: ignore
-                results[i][meta_key][Key.PATCH_INDEX] = i
+                results[i][meta_key][Key.PATCH_INDEX] = i  # type: ignore
 
         return results
 
@@ -1250,6 +1251,8 @@ class RandCropByLabelClassesd(Randomizable, MapTransform, InvertibleTransform):
 
     """
 
+    backend = RandCropByLabelClasses.backend
+
     def __init__(
         self,
         keys: KeysCollection,
@@ -1278,25 +1281,24 @@ def __init__(
         if len(self.keys) != len(self.meta_keys):
             raise ValueError("meta_keys should have the same length as keys.")
         self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
-        self.centers: Optional[List[List[np.ndarray]]] = None
+        self.centers: Optional[List[List[int]]] = None
 
     def randomize(
         self,
-        label: np.ndarray,
-        indices: Optional[List[np.ndarray]] = None,
-        image: Optional[np.ndarray] = None,
+        label: NdarrayOrTensor,
+        indices: Optional[List[NdarrayOrTensor]] = None,
+        image: Optional[NdarrayOrTensor] = None,
     ) -> None:
         self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
-        indices_: List[np.ndarray]
         if indices is None:
-            indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)  # type: ignore
+            indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)
         else:
             indices_ = indices
-        self.centers = generate_label_classes_crop_centers(  # type: ignore
+        self.centers = generate_label_classes_crop_centers(
             self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R
         )
 
-    def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarray]]:
+    def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, NdarrayOrTensor]]:
         d = dict(data)
         label = d[self.label_key]
         image = d[self.image_key] if self.image_key else None
@@ -1309,7 +1311,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarr
             raise ValueError("no available ROI centers to crop.")
 
         # initialize returned list with shallow copy to preserve key ordering
-        results: List[Dict[Hashable, np.ndarray]] = [dict(d) for _ in range(self.num_samples)]
+        results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(self.num_samples)]
 
         for i, center in enumerate(self.centers):
             # fill in the extra keys with unmodified data
@@ -1317,17 +1319,16 @@ def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarr
                 results[i][key] = deepcopy(d[key])
             for key in self.key_iterator(d):
                 img = d[key]
-                cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
+                cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)
                 orig_size = img.shape[1:]
-                cropped: np.ndarray = cropper(img)  # type: ignore
-                results[i][key] = cropped
+                results[i][key] = cropper(img)
                 self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size)
             # add `patch_index` to the meta data
             for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
                 meta_key = meta_key or f"{key}_{meta_key_postfix}"
                 if meta_key not in results[i]:
                     results[i][meta_key] = {}  # type: ignore
-                results[i][meta_key][Key.PATCH_INDEX] = i
+                results[i][meta_key][Key.PATCH_INDEX] = i  # type: ignore
 
         return results
 
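End of the dictionary-transform changes. A hedged end-to-end sketch of `RandCropByPosNegLabeld` on a torch-tensor data dictionary; the keys, shapes and meta dict below are made up for illustration and assume a MONAI build with this patch:

```python
import torch

from monai.transforms import RandCropByPosNegLabeld

data = {
    "image": torch.rand(1, 32, 32, 32),
    "label": (torch.rand(1, 32, 32, 32) > 0.7).long(),
    "image_meta_dict": {"affine": torch.eye(4)},
}

cropper = RandCropByPosNegLabeld(
    keys=["image", "label"],
    label_key="label",
    spatial_size=(16, 16, 16),
    pos=1,
    neg=1,
    num_samples=4,
)
cropper.set_random_state(0)
samples = cropper(data)  # list of 4 dicts, one per sampled patch

print(samples[0]["image"].shape)                     # torch.Size([1, 16, 16, 16])
print(samples[0]["image_meta_dict"]["patch_index"])  # 0
```

Each returned dictionary carries the cropped keys plus a `patch_index` entry in the matching meta dictionary, which is what the `Key.PATCH_INDEX` lines above write.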
diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py
index b21f971042..d562a44a6d 100644
--- a/tests/test_rand_crop_by_label_classes.py
+++ b/tests/test_rand_crop_by_label_classes.py
@@ -15,68 +15,77 @@
 from parameterized import parameterized
 
 from monai.transforms import ClassesToIndices, RandCropByLabelClasses
+from tests.utils import TEST_NDARRAYS
 
-TEST_CASE_0 = [
+TESTS_INDICES, TESTS_SHAPE = [], []
+for p in TEST_NDARRAYS:
     # One-Hot label
-    {
-        "label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "num_classes": None,
-        "spatial_size": [2, 2, -1],
-        "ratios": [1, 1, 1],
-        "num_samples": 2,
-        "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "image_threshold": 0,
-    },
-    {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
-    list,
-    (3, 2, 2, 3),
-]
+    TESTS_INDICES.append(
+        [
+            {
+                "label": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
+                "num_classes": None,
+                "spatial_size": [2, 2, -1],
+                "ratios": [1, 1, 1],
+                "num_samples": 2,
+                "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
+                "image_threshold": 0,
+            },
+            {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))},
+            list,
+            (3, 2, 2, 3),
+        ]
+    )
 
-TEST_CASE_1 = [
-    # Argmax label
-    {
-        "label": np.random.randint(0, 2, size=[1, 3, 3, 3]),
-        "num_classes": 2,
-        "spatial_size": [2, 2, 2],
-        "ratios": [1, 1],
-        "num_samples": 2,
-        "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "image_threshold": 0,
-    },
-    {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
-    list,
-    (3, 2, 2, 2),
-]
+    TESTS_INDICES.append(
+        [
+            # Argmax label
+            {
+                "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])),
+                "num_classes": 2,
+                "spatial_size": [2, 2, 2],
+                "ratios": [1, 1],
+                "num_samples": 2,
+                "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
+                "image_threshold": 0,
+            },
+            {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))},
+            list,
+            (3, 2, 2, 2),
+        ]
+    )
 
-TEST_CASE_2 = [
-    # provide label at runtime
-    {
-        "label": None,
-        "num_classes": 2,
-        "spatial_size": [2, 2, 2],
-        "ratios": [1, 1],
-        "num_samples": 2,
-        "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "image_threshold": 0,
-    },
-    {
-        "img": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "label": np.random.randint(0, 2, size=[1, 3, 3, 3]),
-        "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-    },
-    list,
-    (3, 2, 2, 2),
-]
+    TESTS_SHAPE.append(
+        [
+            # provide label at runtime
+            {
+                "label": None,
+                "num_classes": 2,
+                "spatial_size": [2, 2, 2],
+                "ratios": [1, 1],
+                "num_samples": 2,
+                "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
+                "image_threshold": 0,
+            },
+            {
+                "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
+                "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])),
+                "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
+            },
+            list,
+            (3, 2, 2, 2),
+        ]
+    )
 
 
 class TestRandCropByLabelClasses(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])
+    @parameterized.expand(TESTS_INDICES + TESTS_SHAPE)
     def test_type_shape(self, input_param, input_data, expected_type, expected_shape):
         result = RandCropByLabelClasses(**input_param)(**input_data)
         self.assertIsInstance(result, expected_type)
         self.assertTupleEqual(result[0].shape, expected_shape)
 
-    @parameterized.expand([TEST_CASE_0, TEST_CASE_1])
+    @parameterized.expand(TESTS_INDICES)
     def test_indices(self, input_param, input_data, expected_type, expected_shape):
         input_param["indices"] = ClassesToIndices(num_classes=input_param["num_classes"])(input_param["label"])
         result = RandCropByLabelClasses(**input_param)(**input_data)
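The rewritten test builds its cases inside a `for p in TEST_NDARRAYS` loop so each case runs once per supported container. `TEST_NDARRAYS` is imported from `tests.utils`; roughly, it is a tuple of converters from numpy to each backend. A self-contained approximation (an assumption about that helper, not its actual definition):

```python
from typing import Callable, Tuple

import numpy as np
import torch

# rough stand-in for tests.utils.TEST_NDARRAYS: numpy, torch CPU, and torch CUDA when available
TEST_NDARRAYS: Tuple[Callable, ...] = (np.array, torch.as_tensor)
if torch.cuda.is_available():
    TEST_NDARRAYS = TEST_NDARRAYS + (lambda x: torch.as_tensor(x, device="cuda"),)

for p in TEST_NDARRAYS:
    print(type(p(np.zeros((1, 3, 3, 3)))))  # ndarray, Tensor, (cuda) Tensor
```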
diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py
index 829096953b..27fe3425dd 100644
--- a/tests/test_rand_crop_by_label_classesd.py
+++ b/tests/test_rand_crop_by_label_classesd.py
@@ -15,52 +15,59 @@
 from parameterized import parameterized
 
 from monai.transforms import ClassesToIndicesd, RandCropByLabelClassesd
+from tests.utils import TEST_NDARRAYS
 
-TEST_CASE_0 = [
-    # One-Hot label
-    {
-        "keys": "img",
-        "label_key": "label",
-        "num_classes": None,
-        "spatial_size": [2, 2, -1],
-        "ratios": [1, 1, 1],
-        "num_samples": 2,
-        "image_key": "image",
-        "image_threshold": 0,
-    },
-    {
-        "img": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-    },
-    list,
-    (3, 2, 2, 3),
-]
+TESTS = []
+for p in TEST_NDARRAYS:
+    TESTS.append(
+        [
+            # One-Hot label
+            {
+                "keys": "img",
+                "label_key": "label",
+                "num_classes": None,
+                "spatial_size": [2, 2, -1],
+                "ratios": [1, 1, 1],
+                "num_samples": 2,
+                "image_key": "image",
+                "image_threshold": 0,
+            },
+            {
+                "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
+                "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
+                "label": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
+            },
+            list,
+            (3, 2, 2, 3),
+        ]
+    )
 
-TEST_CASE_1 = [
-    # Argmax label
-    {
-        "keys": "img",
-        "label_key": "label",
-        "num_classes": 2,
-        "spatial_size": [2, 2, 2],
-        "ratios": [1, 1],
-        "num_samples": 2,
-        "image_key": "image",
-        "image_threshold": 0,
-    },
-    {
-        "img": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "label": np.random.randint(0, 2, size=[1, 3, 3, 3]),
-    },
-    list,
-    (3, 2, 2, 2),
-]
+    TESTS.append(
+        [
+            # Argmax label
+            {
+                "keys": "img",
+                "label_key": "label",
+                "num_classes": 2,
+                "spatial_size": [2, 2, 2],
+                "ratios": [1, 1],
+                "num_samples": 2,
+                "image_key": "image",
+                "image_threshold": 0,
+            },
+            {
+                "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
+                "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
+                "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])),
+            },
+            list,
+            (3, 2, 2, 2),
+        ]
+    )
 
 
 class TestRandCropByLabelClassesd(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_0, TEST_CASE_1])
+    @parameterized.expand(TESTS)
     def test_type_shape(self, input_param, input_data, expected_type, expected_shape):
         result = RandCropByLabelClassesd(**input_param)(input_data)
         self.assertIsInstance(result, expected_type)
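With the `backend` declarations in place, the dictionary transforms return patches in whatever container they were given. A small stand-alone check (not part of the test suite; the deterministic label just guarantees both classes are present):

```python
import numpy as np
import torch

from monai.transforms import RandCropByLabelClassesd

label = np.zeros([1, 3, 3, 3], dtype=np.int64)
label[0, :, :, 1] = 1  # guarantee that classes 0 and 1 both occur
base = {"img": np.random.randint(0, 2, size=[3, 3, 3, 3]), "label": label}
params = dict(
    keys="img", label_key="label", num_classes=2, spatial_size=[2, 2, 2], ratios=[1, 1], num_samples=2
)

for convert in (np.asarray, torch.as_tensor):
    data = {k: convert(v) for k, v in base.items()}
    out = RandCropByLabelClassesd(**params)(data)
    # patches keep the input container: numpy in -> ndarray out, torch in -> Tensor out
    print(type(out[0]["img"]), out[0]["img"].shape)
```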
diff --git a/tests/test_rand_crop_by_pos_neg_label.py b/tests/test_rand_crop_by_pos_neg_label.py
index e0f669ab3f..a81976dea1 100644
--- a/tests/test_rand_crop_by_pos_neg_label.py
+++ b/tests/test_rand_crop_by_pos_neg_label.py
@@ -10,68 +10,93 @@
 # limitations under the License.
 
 import unittest
+from copy import deepcopy
 
 import numpy as np
 from parameterized import parameterized
 
 from monai.transforms import RandCropByPosNegLabel
+from tests.utils import TEST_NDARRAYS
 
-TEST_CASE_0 = [
-    {
-        "label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "spatial_size": [2, 2, -1],
-        "pos": 1,
-        "neg": 1,
-        "num_samples": 2,
-        "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "image_threshold": 0,
-    },
-    {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
-    list,
-    (3, 2, 2, 3),
-]
+TESTS = []
+TESTS.append(
+    [
+        {
+            "label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
+            "spatial_size": [2, 2, -1],
+            "pos": 1,
+            "neg": 1,
+            "num_samples": 2,
+            "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
+            "image_threshold": 0,
+        },
+        {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
+        (3, 2, 2, 3),
+    ]
+)
+TESTS.append(
+    [
+        {
+            "label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
+            "spatial_size": [2, 2, 2],
+            "pos": 1,
+            "neg": 1,
+            "num_samples": 2,
+            "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
+            "image_threshold": 0,
+        },
+        {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
+        (3, 2, 2, 2),
+    ]
+)
+TESTS.append(
+    [
+        {
+            "label": None,
+            "spatial_size": [2, 2, 2],
+            "pos": 1,
+            "neg": 1,
+            "num_samples": 2,
+            "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
+            "image_threshold": 0,
+        },
+        {
+            "img": np.random.randint(0, 2, size=[3, 3, 3, 3]),
+            "label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
+            "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
+        },
+        (3, 2, 2, 2),
+    ]
+)
 
-TEST_CASE_1 = [
-    {
-        "label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "spatial_size": [2, 2, 2],
-        "pos": 1,
-        "neg": 1,
-        "num_samples": 2,
-        "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "image_threshold": 0,
-    },
-    {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
-    list,
-    (3, 2, 2, 2),
-]
 
-TEST_CASE_2 = [
-    {
-        "label": None,
-        "spatial_size": [2, 2, 2],
-        "pos": 1,
-        "neg": 1,
-        "num_samples": 2,
-        "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "image_threshold": 0,
-    },
-    {
-        "img": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-    },
-    list,
-    (3, 2, 2, 2),
-]
 
+class TestRandCropByPosNegLabel(unittest.TestCase):
+    @staticmethod
+    def convert_data_type(im_type, d, keys=("img", "image", "label")):
+        out = deepcopy(d)
+        for k, v in out.items():
+            if k in keys and isinstance(v, np.ndarray):
+                out[k] = im_type(v)
+        return out
+
+    @parameterized.expand(TESTS)
+    def test_type_shape(self, input_param, input_data, expected_shape):
+        results = []
+        for p in TEST_NDARRAYS:
+            input_param_mod = self.convert_data_type(p, input_param)
+            input_data_mod = self.convert_data_type(p, input_data)
+            cropper = RandCropByPosNegLabel(**input_param_mod)
+            cropper.set_random_state(0)
+            result = cropper(**input_data_mod)
 
-class TestRandCropByPosNegLabel(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])
-    def test_type_shape(self, input_param, input_data, expected_type, expected_shape):
-        result = RandCropByPosNegLabel(**input_param)(**input_data)
-        self.assertIsInstance(result, expected_type)
-        self.assertTupleEqual(result[0].shape, expected_shape)
+            self.assertIsInstance(result, list)
+            self.assertTupleEqual(result[0].shape, expected_shape)
+
+            # check for same results across numpy, torch.Tensor and torch.cuda.Tensor
+            result = np.asarray([i if isinstance(i, np.ndarray) else i.cpu().numpy() for i in result])
+            results.append(np.asarray(result))
+            if len(results) > 1:
+                np.testing.assert_allclose(results[0], results[-1])
 
 
 if __name__ == "__main__":
diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py
index 17a3e117bb..6d2f39cc54 100644
--- a/tests/test_rand_crop_by_pos_neg_labeld.py
+++ b/tests/test_rand_crop_by_pos_neg_labeld.py
@@ -10,90 +10,101 @@
 # limitations under the License.
 
 import unittest
+from copy import deepcopy
 
 import numpy as np
 from parameterized import parameterized
 
 from monai.transforms import RandCropByPosNegLabeld
+from tests.utils import TEST_NDARRAYS
 
-TEST_CASE_0 = [
-    {
-        "keys": ["image", "extra", "label"],
-        "label_key": "label",
-        "spatial_size": [-1, 2, 2],
-        "pos": 1,
-        "neg": 1,
-        "num_samples": 2,
-        "image_key": None,
-        "image_threshold": 0,
-    },
-    {
-        "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "image_meta_dict": {"affine": np.eye(3), "shape": "CHWD"},
-    },
-    list,
-    (3, 3, 2, 2),
+TESTS = [
+    [
+        {
+            "keys": ["image", "extra", "label"],
+            "label_key": "label",
+            "spatial_size": [-1, 2, 2],
+            "pos": 1,
+            "neg": 1,
+            "num_samples": 2,
+            "image_key": None,
+            "image_threshold": 0,
+        },
+        {
+            "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
+            "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]),
+            "label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
+            "image_meta_dict": {"affine": np.eye(3), "shape": "CHWD"},
+        },
+        (3, 3, 2, 2),
+    ],
+    [
+        {
+            "keys": ["image", "extra", "label"],
+            "label_key": "label",
+            "spatial_size": [2, 2, 2],
+            "pos": 1,
+            "neg": 1,
+            "num_samples": 2,
+            "image_key": None,
+            "image_threshold": 0,
+        },
+        {
+            "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
+            "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]),
+            "label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
+            "label_meta_dict": {"affine": np.eye(3), "shape": "CHWD"},
+        },
+        (3, 2, 2, 2),
+    ],
+    [
+        {
+            "keys": ["image", "extra", "label"],
+            "label_key": "label",
+            "spatial_size": [2, 2, 2],
+            "pos": 1,
+            "neg": 1,
+            "num_samples": 2,
+            "image_key": None,
+            "image_threshold": 0,
+        },
+        {
+            "image": np.zeros([3, 3, 3, 3]) - 1,
+            "extra": np.zeros([3, 3, 3, 3]),
+            "label": np.ones([3, 3, 3, 3]),
+            "extra_meta_dict": {"affine": np.eye(3), "shape": "CHWD"},
+        },
+        (3, 2, 2, 2),
+    ],
 ]
-TEST_CASE_1 = [
-    {
-        "keys": ["image", "extra", "label"],
-        "label_key": "label",
-        "spatial_size": [2, 2, 2],
-        "pos": 1,
-        "neg": 1,
-        "num_samples": 2,
-        "image_key": None,
-        "image_threshold": 0,
-    },
-    {
-        "image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
-        "label_meta_dict": {"affine": np.eye(3), "shape": "CHWD"},
-    },
-    list,
-    (3, 2, 2, 2),
-]
-TEST_CASE_2 = [
-    {
-        "keys": ["image", "extra", "label"],
-        "label_key": "label",
-        "spatial_size": [2, 2, 2],
-        "pos": 1,
-        "neg": 1,
-        "num_samples": 2,
-        "image_key": None,
-        "image_threshold": 0,
-    },
-    {
-        "image": np.zeros([3, 3, 3, 3]) - 1,
-        "extra": np.zeros([3, 3, 3, 3]),
-        "label": np.ones([3, 3, 3, 3]),
-        "extra_meta_dict": {"affine": np.eye(3), "shape": "CHWD"},
-    },
-    list,
-    (3, 2, 2, 2),
-]
 
+class TestRandCropByPosNegLabeld(unittest.TestCase):
+    @staticmethod
+    def convert_data_type(im_type, d, keys=("img", "image", "label")):
+        out = deepcopy(d)
+        for k, v in out.items():
+            if k in keys and isinstance(v, np.ndarray):
+                out[k] = im_type(v)
+        return out
+
+    @parameterized.expand(TESTS)
+    def test_type_shape(self, input_param, input_data, expected_shape):
+        for p in TEST_NDARRAYS:
+            input_param_mod = self.convert_data_type(p, input_param)
+            input_data_mod = self.convert_data_type(p, input_data)
+            cropper = RandCropByPosNegLabeld(**input_param_mod)
+            cropper.set_random_state(0)
+            result = cropper(input_data_mod)
 
-class TestRandCropByPosNegLabeld(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])
-    def test_type_shape(self, input_param, input_data, expected_type, expected_shape):
-        result = RandCropByPosNegLabeld(**input_param)(input_data)
-        self.assertIsInstance(result, expected_type)
-        self.assertTupleEqual(result[0]["image"].shape, expected_shape)
-        self.assertTupleEqual(result[0]["extra"].shape, expected_shape)
-        self.assertTupleEqual(result[0]["label"].shape, expected_shape)
-        _len = len(tuple(input_data.keys()))
-        self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys()))
-        for i, item in enumerate(result):
-            self.assertEqual(item["image_meta_dict"]["patch_index"], i)
-            self.assertEqual(item["label_meta_dict"]["patch_index"], i)
-            self.assertEqual(item["extra_meta_dict"]["patch_index"], i)
+            self.assertIsInstance(result, list)
+
+            _len = len(tuple(input_data.keys()))
+            self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys()))
+            for k in ("image", "extra", "label"):
+                self.assertTupleEqual(result[0][k].shape, expected_shape)
+                for i, item in enumerate(result):
+                    self.assertEqual(item[k + "_meta_dict"]["patch_index"], i)
 
 
 if __name__ == "__main__":
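The array-transform test above (`test_rand_crop_by_pos_neg_label.py`) already asserts that numpy and torch inputs yield identical crops for a fixed random state. The same determinism can be spot-checked for the dictionary version with a sketch like the following (illustrative, not part of the PR):

```python
import numpy as np
import torch

from monai.transforms import RandCropByPosNegLabeld

data_np = {
    "image": np.random.randint(0, 2, size=[3, 3, 3, 3]).astype(np.float32),
    "label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
}
data_pt = {k: torch.as_tensor(v) for k, v in data_np.items()}


def crop(d):
    t = RandCropByPosNegLabeld(
        keys="image", label_key="label", spatial_size=[2, 2, 2], pos=1, neg=1, num_samples=2
    )
    t.set_random_state(0)  # same seed -> same sampled centers for both backends
    return t(d)


for s_np, s_pt in zip(crop(data_np), crop(data_pt)):
    np.testing.assert_allclose(s_np["image"], s_pt["image"].cpu().numpy())
```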