diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 5b550f7885..a144c8c138 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -27,6 +27,11 @@ Generic Interfaces .. autoclass:: Randomizable :members: +`RandomizableTransform` +^^^^^^^^^^^^^^^^^^^^^^^ +.. autoclass:: RandomizableTransform + :members: + `Compose` ^^^^^^^^^ .. autoclass:: Compose diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 80b0d1648d..3f4031fade 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -17,7 +17,7 @@ from monai.config import IndexSelection, KeysCollection from monai.networks.layers import GaussianFilter from monai.transforms import SpatialCrop -from monai.transforms.transform import MapTransform, Randomizable, Transform +from monai.transforms.transform import MapTransform, RandomizableTransform, Transform from monai.transforms.utils import generate_spatial_bounding_box from monai.utils import min_version, optional_import @@ -62,7 +62,7 @@ def __call__(self, data): return d -class AddInitialSeedPointd(Randomizable, Transform): +class AddInitialSeedPointd(RandomizableTransform): """ Add random guidance as initial seed point for a given label. @@ -279,7 +279,7 @@ def __call__(self, data): return d -class AddRandomGuidanced(Randomizable, Transform): +class AddRandomGuidanced(RandomizableTransform): """ Add random guidance based on discrepancies that were found between label and prediction. diff --git a/monai/data/dataset.py b/monai/data/dataset.py index db1ecc2b1f..bb5a98ba1e 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -26,6 +26,7 @@ from monai.data.utils import pickle_hashing from monai.transforms import Compose, Randomizable, Transform, apply_transform +from monai.transforms.transform import RandomizableTransform from monai.utils import MAX_SEED, get_seed, min_version, optional_import if TYPE_CHECKING: @@ -161,7 +162,7 @@ def _pre_transform(self, item_transformed): raise ValueError("transform must be an instance of monai.transforms.Compose.") for _transform in self.transform.transforms: # execute all the deterministic transforms - if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform): + if isinstance(_transform, RandomizableTransform) or not isinstance(_transform, Transform): break item_transformed = apply_transform(_transform, item_transformed) return item_transformed @@ -183,7 +184,7 @@ def _post_transform(self, item_transformed): for _transform in self.transform.transforms: if ( start_post_randomize_run - or isinstance(_transform, Randomizable) + or isinstance(_transform, RandomizableTransform) or not isinstance(_transform, Transform) ): start_post_randomize_run = True @@ -522,7 +523,7 @@ def _load_cache_item(self, idx: int): raise ValueError("transform must be an instance of monai.transforms.Compose.") for _transform in self.transform.transforms: # execute all the deterministic transforms - if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform): + if isinstance(_transform, RandomizableTransform) or not isinstance(_transform, Transform): break item = apply_transform(_transform, item) return item @@ -539,7 +540,7 @@ def __getitem__(self, index): if not isinstance(self.transform, Compose): raise ValueError("transform must be an instance of monai.transforms.Compose.") for _transform in self.transform.transforms: - if start_run or isinstance(_transform, Randomizable) or not isinstance(_transform, Transform): + if start_run or isinstance(_transform, RandomizableTransform) or not isinstance(_transform, Transform): start_run = True data = apply_transform(_transform, data) return data @@ -924,9 +925,9 @@ def __getitem__(self, index: int): # set transforms of each zip component for dataset in self.dataset.data: transform = getattr(dataset, "transform", None) - if isinstance(transform, Randomizable): + if isinstance(transform, RandomizableTransform): transform.set_random_state(seed=self._seed) transform = getattr(self.dataset, "transform", None) - if isinstance(transform, Randomizable): + if isinstance(transform, RandomizableTransform): transform.set_random_state(seed=self._seed) return self.dataset[index] diff --git a/monai/data/image_dataset.py b/monai/data/image_dataset.py index 1568e082ee..1074105508 100644 --- a/monai/data/image_dataset.py +++ b/monai/data/image_dataset.py @@ -17,6 +17,7 @@ from monai.config import DtypeLike from monai.data.image_reader import ImageReader from monai.transforms import LoadImage, Randomizable, apply_transform +from monai.transforms.transform import RandomizableTransform from monai.utils import MAX_SEED, get_seed @@ -106,14 +107,14 @@ def __getitem__(self, index: int): label = self.labels[index] if self.transform is not None: - if isinstance(self.transform, Randomizable): + if isinstance(self.transform, RandomizableTransform): self.transform.set_random_state(seed=self._seed) img = apply_transform(self.transform, img) data = [img] if self.seg_transform is not None: - if isinstance(self.seg_transform, Randomizable): + if isinstance(self.seg_transform, RandomizableTransform): self.seg_transform.set_random_state(seed=self._seed) seg = apply_transform(self.seg_transform, seg) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index f5c7c826e9..8b30d76bec 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -237,7 +237,7 @@ ZoomD, ZoomDict, ) -from .transform import MapTransform, Randomizable, Transform +from .transform import MapTransform, Randomizable, RandomizableTransform, Transform from .utility.array import ( AddChannel, AddExtremePointsChannel, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 2d612ad2e3..a9f66b12a0 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -18,15 +18,14 @@ import numpy as np # For backwards compatiblity (so this still works: from monai.transforms.compose import MapTransform) -from monai.transforms.transform import MapTransform # noqa: F401 -from monai.transforms.transform import Randomizable, Transform +from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform, Transform # noqa: F401 from monai.transforms.utils import apply_transform from monai.utils import MAX_SEED, ensure_tuple, get_seed __all__ = ["Compose"] -class Compose(Randomizable, Transform): +class Compose(RandomizableTransform): """ ``Compose`` provides the ability to chain a series of calls together in a sequence. Each transform in the sequence must take a single argument and @@ -96,14 +95,14 @@ def __init__(self, transforms: Optional[Union[Sequence[Callable], Callable]] = N def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose": super().set_random_state(seed=seed, state=state) for _transform in self.transforms: - if not isinstance(_transform, Randomizable): + if not isinstance(_transform, RandomizableTransform): continue _transform.set_random_state(seed=self.R.randint(MAX_SEED, dtype="uint32")) return self def randomize(self, data: Optional[Any] = None) -> None: for _transform in self.transforms: - if not isinstance(_transform, Randomizable): + if not isinstance(_transform, RandomizableTransform): continue try: _transform.randomize(data) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index ef5e0019bd..a3d36ad903 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -20,7 +20,7 @@ from monai.config import IndexSelection from monai.data.utils import get_random_patch, get_valid_patch_size -from monai.transforms.transform import Randomizable, Transform +from monai.transforms.transform import Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, @@ -276,7 +276,7 @@ def __call__(self, img: np.ndarray): return cropper(img) -class RandSpatialCrop(Randomizable, Transform): +class RandSpatialCrop(RandomizableTransform): """ Crop image with random size or specific size ROI. It can crop at a random position as center or at the image center. And allows to set the minimum size to limit the randomly generated ROI. @@ -321,7 +321,7 @@ def __call__(self, img: np.ndarray): return cropper(img) -class RandSpatialCropSamples(Randomizable, Transform): +class RandSpatialCropSamples(RandomizableTransform): """ Crop image with random size or specific size ROI to generate a list of N samples. It can crop at a random position as center or at the image center. And allows to set @@ -429,7 +429,7 @@ def __call__(self, img: np.ndarray): return cropped -class RandWeightedCrop(Randomizable, Transform): +class RandWeightedCrop(RandomizableTransform): """ Samples a list of `num_samples` image patches according to the provided `weight_map`. @@ -481,7 +481,7 @@ def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> return results -class RandCropByPosNegLabel(Randomizable, Transform): +class RandCropByPosNegLabel(RandomizableTransform): """ Crop random fixed sized regions with the center being a foreground or background voxel based on the Pos Neg Ratio. diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 20ae6ac1ed..9739c6322f 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -30,7 +30,7 @@ SpatialCrop, SpatialPad, ) -from monai.transforms.transform import MapTransform, Randomizable +from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, @@ -258,7 +258,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class RandSpatialCropd(Randomizable, MapTransform): +class RandSpatialCropd(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandSpatialCrop`. Crop image with random size or specific size ROI. It can crop at a random position as @@ -283,7 +283,8 @@ def __init__( random_center: bool = True, random_size: bool = True, ) -> None: - super().__init__(keys) + RandomizableTransform.__init__(self) + MapTransform.__init__(self, keys) self.roi_size = roi_size self.random_center = random_center self.random_size = random_size @@ -312,7 +313,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class RandSpatialCropSamplesd(Randomizable, MapTransform): +class RandSpatialCropSamplesd(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandSpatialCropSamples`. Crop image with random size or specific size ROI to generate a list of N samples. @@ -344,7 +345,8 @@ def __init__( random_center: bool = True, random_size: bool = True, ) -> None: - super().__init__(keys) + RandomizableTransform.__init__(self) + MapTransform.__init__(self, keys) if num_samples < 1: raise ValueError(f"num_samples must be positive, got {num_samples}.") self.num_samples = num_samples @@ -420,7 +422,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class RandWeightedCropd(Randomizable, MapTransform): +class RandWeightedCropd(RandomizableTransform, MapTransform): """ Samples a list of `num_samples` image patches according to the provided `weight_map`. @@ -446,7 +448,8 @@ def __init__( num_samples: int = 1, center_coord_key: Optional[str] = None, ): - super().__init__(keys) + RandomizableTransform.__init__(self) + MapTransform.__init__(self, keys) self.spatial_size = ensure_tuple(spatial_size) self.w_key = w_key self.num_samples = int(num_samples) @@ -484,7 +487,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n return results -class RandCropByPosNegLabeld(Randomizable, MapTransform): +class RandCropByPosNegLabeld(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandCropByPosNegLabel`. Crop random fixed sized regions with the center being a foreground or background voxel @@ -534,7 +537,8 @@ def __init__( fg_indices_key: Optional[str] = None, bg_indices_key: Optional[str] = None, ) -> None: - super().__init__(keys) + RandomizableTransform.__init__(self) + MapTransform.__init__(self, keys) self.label_key = label_key self.spatial_size: Union[Tuple[int, ...], Sequence[int], int] = spatial_size if pos < 0 or neg < 0: diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 40bef064eb..1bddc0137d 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -22,7 +22,7 @@ from monai.config import DtypeLike from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter -from monai.transforms.transform import Randomizable, Transform +from monai.transforms.transform import RandomizableTransform, Transform from monai.transforms.utils import rescale_array from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, dtype_torch_to_numpy, ensure_tuple_size @@ -49,7 +49,7 @@ ] -class RandGaussianNoise(Randomizable, Transform): +class RandGaussianNoise(RandomizableTransform): """ Add Gaussian noise to image. @@ -60,14 +60,13 @@ class RandGaussianNoise(Randomizable, Transform): """ def __init__(self, prob: float = 0.1, mean: Union[Sequence[float], float] = 0.0, std: float = 0.1) -> None: - self.prob = prob + RandomizableTransform.__init__(self, prob) self.mean = mean self.std = std - self._do_transform = False self._noise = None def randomize(self, im_shape: Sequence[int]) -> None: - self._do_transform = self.R.random() < self.prob + super().randomize(None) self._noise = self.R.normal(self.mean, self.R.uniform(0, self.std), size=im_shape) def __call__(self, img: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]: @@ -101,7 +100,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return np.asarray((img + self.offset), dtype=img.dtype) -class RandShiftIntensity(Randomizable, Transform): +class RandShiftIntensity(RandomizableTransform): """ Randomly shift intensity with randomly picked offset. """ @@ -113,6 +112,7 @@ def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1 if single number, offset value is picked from (-offsets, offsets). prob: probability of shift. """ + RandomizableTransform.__init__(self, prob) if isinstance(offsets, (int, float)): self.offsets = (min(-offsets, offsets), max(-offsets, offsets)) else: @@ -120,12 +120,9 @@ def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1 raise AssertionError("offsets should be a number or pair of numbers.") self.offsets = (min(offsets), max(offsets)) - self.prob = prob - self._do_transform = False - def randomize(self, data: Optional[Any] = None) -> None: self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) - self._do_transform = self.R.random() < self.prob + super().randomize(None) def __call__(self, img: np.ndarray) -> np.ndarray: """ @@ -172,7 +169,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.") -class RandScaleIntensity(Randomizable, Transform): +class RandScaleIntensity(RandomizableTransform): """ Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor` is randomly picked from (-factors[0], factors[0]). @@ -186,6 +183,7 @@ def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1 prob: probability of scale. """ + RandomizableTransform.__init__(self, prob) if isinstance(factors, (int, float)): self.factors = (min(-factors, factors), max(-factors, factors)) else: @@ -193,12 +191,9 @@ def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1 raise AssertionError("factors should be a number or pair of numbers.") self.factors = (min(factors), max(factors)) - self.prob = prob - self._do_transform = False - def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) - self._do_transform = self.R.random() < self.prob + super().randomize(None) def __call__(self, img: np.ndarray) -> np.ndarray: """ @@ -243,7 +238,7 @@ def __init__( self.dtype = dtype def _normalize(self, img: np.ndarray, sub=None, div=None) -> np.ndarray: - slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=np.bool_) + slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=bool) if not np.any(slices): return img @@ -370,7 +365,7 @@ def __call__(self, img: np.ndarray): return np.power(((img - img_min) / float(img_range + epsilon)), self.gamma) * img_range + img_min -class RandAdjustContrast(Randomizable, Transform): +class RandAdjustContrast(RandomizableTransform): """ Randomly changes image intensity by gamma. Each pixel/voxel intensity is updated as:: @@ -383,7 +378,7 @@ class RandAdjustContrast(Randomizable, Transform): """ def __init__(self, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0.5, 4.5)) -> None: - self.prob = prob + RandomizableTransform.__init__(self, prob) if isinstance(gamma, (int, float)): if gamma <= 0.5: @@ -396,11 +391,10 @@ def __init__(self, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0. raise AssertionError("gamma should be a number or pair of numbers.") self.gamma = (min(gamma), max(gamma)) - self._do_transform = False self.gamma_value = None def randomize(self, data: Optional[Any] = None) -> None: - self._do_transform = self.R.random_sample() < self.prob + super().randomize(None) self.gamma_value = self.R.uniform(low=self.gamma[0], high=self.gamma[1]) def __call__(self, img: np.ndarray) -> np.ndarray: @@ -657,7 +651,7 @@ def __call__(self, img: np.ndarray): return gaussian_filter(input_data).squeeze(0).detach().numpy() -class RandGaussianSmooth(Randomizable, Transform): +class RandGaussianSmooth(RandomizableTransform): """ Apply Gaussian smooth to the input data based on randomly selected `sigma` parameters. @@ -679,15 +673,14 @@ def __init__( prob: float = 0.1, approx: str = "erf", ) -> None: + RandomizableTransform.__init__(self, prob) self.sigma_x = sigma_x self.sigma_y = sigma_y self.sigma_z = sigma_z - self.prob = prob self.approx = approx - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: - self._do_transform = self.R.random_sample() < self.prob + super().randomize(None) self.x = self.R.uniform(low=self.sigma_x[0], high=self.sigma_x[1]) self.y = self.R.uniform(low=self.sigma_y[0], high=self.sigma_y[1]) self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1]) @@ -748,7 +741,7 @@ def __call__(self, img: np.ndarray): return (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0).detach().numpy() -class RandGaussianSharpen(Randomizable, Transform): +class RandGaussianSharpen(RandomizableTransform): """ Sharpen images using the Gaussian Blur filter based on randomly selected `sigma1`, `sigma2` and `alpha`. The algorithm is :py:class:`monai.transforms.GaussianSharpen`. @@ -782,6 +775,7 @@ def __init__( approx: str = "erf", prob: float = 0.1, ) -> None: + RandomizableTransform.__init__(self, prob) self.sigma1_x = sigma1_x self.sigma1_y = sigma1_y self.sigma1_z = sigma1_z @@ -790,11 +784,9 @@ def __init__( self.sigma2_z = sigma2_z self.alpha = alpha self.approx = approx - self.prob = prob - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: - self._do_transform = self.R.random_sample() < self.prob + super().randomize(None) self.x1 = self.R.uniform(low=self.sigma1_x[0], high=self.sigma1_x[1]) self.y1 = self.R.uniform(low=self.sigma1_y[0], high=self.sigma1_y[1]) self.z1 = self.R.uniform(low=self.sigma1_z[0], high=self.sigma1_z[1]) @@ -815,7 +807,7 @@ def __call__(self, img: np.ndarray): return GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(img) -class RandHistogramShift(Randomizable, Transform): +class RandHistogramShift(RandomizableTransform): """ Apply random nonlinear transform to the image's intensity histogram. @@ -827,6 +819,7 @@ class RandHistogramShift(Randomizable, Transform): """ def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1) -> None: + RandomizableTransform.__init__(self, prob) if isinstance(num_control_points, int): if num_control_points <= 2: @@ -838,11 +831,9 @@ def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: f if min(num_control_points) <= 2: raise AssertionError("num_control_points should be greater than or equal to 3") self.num_control_points = (min(num_control_points), max(num_control_points)) - self.prob = prob - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: - self._do_transform = self.R.random() < self.prob + super().randomize(None) num_control_point = self.R.randint(self.num_control_points[0], self.num_control_points[1] + 1) self.reference_control_points = np.linspace(0, 1, num_control_point) self.floating_control_points = np.copy(self.reference_control_points) diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 54a85a57b0..7d0d66d2ba 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -34,7 +34,7 @@ ShiftIntensity, ThresholdIntensity, ) -from monai.transforms.transform import MapTransform, Randomizable +from monai.transforms.transform import MapTransform, RandomizableTransform from monai.utils import dtype_torch_to_numpy, ensure_tuple_rep, ensure_tuple_size __all__ = [ @@ -92,7 +92,7 @@ ] -class RandGaussianNoised(Randomizable, MapTransform): +class RandGaussianNoised(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandGaussianNoise`. Add Gaussian noise to image. This transform assumes all the expected fields have same shape. @@ -108,15 +108,14 @@ class RandGaussianNoised(Randomizable, MapTransform): def __init__( self, keys: KeysCollection, prob: float = 0.1, mean: Union[Sequence[float], float] = 0.0, std: float = 0.1 ) -> None: - super().__init__(keys) - self.prob = prob + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) self.mean = ensure_tuple_rep(mean, len(self.keys)) self.std = std - self._do_transform = False self._noise: List[np.ndarray] = [] def randomize(self, im_shape: Sequence[int]) -> None: - self._do_transform = self.R.random() < self.prob + super().randomize(None) self._noise.clear() for m in self.mean: self._noise.append(self.R.normal(m, self.R.uniform(0, self.std), size=im_shape)) @@ -158,7 +157,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class RandShiftIntensityd(Randomizable, MapTransform): +class RandShiftIntensityd(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandShiftIntensity`. """ @@ -173,7 +172,8 @@ def __init__(self, keys: KeysCollection, offsets: Union[Tuple[float, float], flo prob: probability of rotating. (Default 0.1, with 10% probability it returns a rotated array.) """ - super().__init__(keys) + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) if isinstance(offsets, (int, float)): self.offsets = (min(-offsets, offsets), max(-offsets, offsets)) @@ -182,12 +182,9 @@ def __init__(self, keys: KeysCollection, offsets: Union[Tuple[float, float], flo raise AssertionError("offsets should be a number or pair of numbers.") self.offsets = (min(offsets), max(offsets)) - self.prob = prob - self._do_transform = False - def randomize(self, data: Optional[Any] = None) -> None: self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) - self._do_transform = self.R.random() < self.prob + super().randomize(None) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) @@ -229,7 +226,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class RandScaleIntensityd(Randomizable, MapTransform): +class RandScaleIntensityd(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandScaleIntensity`. """ @@ -245,7 +242,8 @@ def __init__(self, keys: KeysCollection, factors: Union[Tuple[float, float], flo (Default 0.1, with 10% probability it returns a rotated array.) """ - super().__init__(keys) + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) if isinstance(factors, (int, float)): self.factors = (min(-factors, factors), max(-factors, factors)) @@ -254,12 +252,9 @@ def __init__(self, keys: KeysCollection, factors: Union[Tuple[float, float], flo raise AssertionError("factors should be a number or pair of numbers.") self.factors = (min(factors), max(factors)) - self.prob = prob - self._do_transform = False - def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) - self._do_transform = self.R.random() < self.prob + super().randomize(None) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) @@ -382,7 +377,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class RandAdjustContrastd(Randomizable, MapTransform): +class RandAdjustContrastd(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandAdjustContrast`. Randomly changes image intensity by gamma. Each pixel/voxel intensity is updated as: @@ -400,8 +395,8 @@ class RandAdjustContrastd(Randomizable, MapTransform): def __init__( self, keys: KeysCollection, prob: float = 0.1, gamma: Union[Tuple[float, float], float] = (0.5, 4.5) ) -> None: - super().__init__(keys) - self.prob: float = prob + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) if isinstance(gamma, (int, float)): if gamma <= 0.5: @@ -414,11 +409,10 @@ def __init__( raise AssertionError("gamma should be a number or pair of numbers.") self.gamma = (min(gamma), max(gamma)) - self._do_transform = False self.gamma_value: Optional[float] = None def randomize(self, data: Optional[Any] = None) -> None: - self._do_transform = self.R.random_sample() < self.prob + super().randomize(None) self.gamma_value = self.R.uniform(low=self.gamma[0], high=self.gamma[1]) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -529,7 +523,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class RandGaussianSmoothd(Randomizable, MapTransform): +class RandGaussianSmoothd(RandomizableTransform, MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.GaussianSmooth`. @@ -554,16 +548,15 @@ def __init__( approx: str = "erf", prob: float = 0.1, ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) self.sigma_x = sigma_x self.sigma_y = sigma_y self.sigma_z = sigma_z self.approx = approx - self.prob = prob - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: - self._do_transform = self.R.random_sample() < self.prob + super().randomize(None) self.x = self.R.uniform(low=self.sigma_x[0], high=self.sigma_x[1]) self.y = self.R.uniform(low=self.sigma_y[0], high=self.sigma_y[1]) self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1]) @@ -616,7 +609,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class RandGaussianSharpend(Randomizable, MapTransform): +class RandGaussianSharpend(RandomizableTransform, MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.GaussianSharpen`. @@ -652,7 +645,8 @@ def __init__( approx: str = "erf", prob: float = 0.1, ): - super().__init__(keys) + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) self.sigma1_x = sigma1_x self.sigma1_y = sigma1_y self.sigma1_z = sigma1_z @@ -661,11 +655,9 @@ def __init__( self.sigma2_z = sigma2_z self.alpha = alpha self.approx = approx - self.prob = prob - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: - self._do_transform = self.R.random_sample() < self.prob + super().randomize(None) self.x1 = self.R.uniform(low=self.sigma1_x[0], high=self.sigma1_x[1]) self.y1 = self.R.uniform(low=self.sigma1_y[0], high=self.sigma1_y[1]) self.z1 = self.R.uniform(low=self.sigma1_z[0], high=self.sigma1_z[1]) @@ -689,7 +681,7 @@ def __call__(self, data): return d -class RandHistogramShiftd(Randomizable, MapTransform): +class RandHistogramShiftd(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandHistogramShift`. Apply random nonlinear transform the the image's intensity histogram. @@ -706,7 +698,8 @@ class RandHistogramShiftd(Randomizable, MapTransform): def __init__( self, keys: KeysCollection, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1 ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) if isinstance(num_control_points, int): if num_control_points <= 2: raise AssertionError("num_control_points should be greater than or equal to 3") @@ -717,11 +710,9 @@ def __init__( if min(num_control_points) <= 2: raise AssertionError("num_control_points should be greater than or equal to 3") self.num_control_points = (min(num_control_points), max(num_control_points)) - self.prob = prob - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: - self._do_transform = self.R.random() < self.prob + super().randomize(None) num_control_point = self.R.randint(self.num_control_points[0], self.num_control_points[1] + 1) self.reference_control_points = np.linspace(0, 1, num_control_point) self.floating_control_points = np.copy(self.reference_control_points) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index d6dbe56f01..3559d0eb3c 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -23,7 +23,7 @@ from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.transforms.croppad.array import CenterSpatialCrop -from monai.transforms.transform import Randomizable, Transform +from monai.transforms.transform import RandomizableTransform, Transform from monai.transforms.utils import ( create_control_grid, create_grid, @@ -604,7 +604,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return result.astype(img.dtype) -class RandRotate90(Randomizable, Transform): +class RandRotate90(RandomizableTransform): """ With probability `prob`, input arrays are rotated by 90 degrees in the plane specified by `spatial_axes`. @@ -619,16 +619,15 @@ def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, i spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. """ - self.prob = min(max(prob, 0.0), 1.0) + RandomizableTransform.__init__(self, min(max(prob, 0.0), 1.0)) self.max_k = max_k self.spatial_axes = spatial_axes - self._do_transform = False self._rand_k = 0 def randomize(self, data: Optional[Any] = None) -> None: self._rand_k = self.R.randint(self.max_k) + 1 - self._do_transform = self.R.random() < self.prob + super().randomize(None) def __call__(self, img: np.ndarray) -> np.ndarray: """ @@ -642,7 +641,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return rotator(img) -class RandRotate(Randomizable, Transform): +class RandRotate(RandomizableTransform): """ Randomly rotate the input arrays. @@ -682,6 +681,7 @@ def __init__( align_corners: bool = False, dtype: DtypeLike = np.float64, ) -> None: + RandomizableTransform.__init__(self, prob) self.range_x = ensure_tuple(range_x) if len(self.range_x) == 1: self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]])) @@ -692,20 +692,18 @@ def __init__( if len(self.range_z) == 1: self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]])) - self.prob = prob self.keep_size = keep_size self.mode: GridSampleMode = GridSampleMode(mode) self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) self.align_corners = align_corners self.dtype = dtype - self._do_transform = False self.x = 0.0 self.y = 0.0 self.z = 0.0 def randomize(self, data: Optional[Any] = None) -> None: - self._do_transform = self.R.random_sample() < self.prob + super().randomize(None) self.x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) @@ -747,7 +745,7 @@ def __call__( return rotator(img) -class RandFlip(Randomizable, Transform): +class RandFlip(RandomizableTransform): """ Randomly flips the image along axes. Preserves shape. See numpy.flip for additional details. @@ -759,25 +757,21 @@ class RandFlip(Randomizable, Transform): """ def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: - self.prob = prob + RandomizableTransform.__init__(self, min(max(prob, 0.0), 1.0)) self.flipper = Flip(spatial_axis=spatial_axis) - self._do_transform = False - - def randomize(self, data: Optional[Any] = None) -> None: - self._do_transform = self.R.random_sample() < self.prob def __call__(self, img: np.ndarray) -> np.ndarray: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - self.randomize() + self.randomize(None) if not self._do_transform: return img return self.flipper(img) -class RandZoom(Randomizable, Transform): +class RandZoom(RandomizableTransform): """ Randomly zooms input arrays with given probability within given zoom range. @@ -816,21 +810,20 @@ def __init__( align_corners: Optional[bool] = None, keep_size: bool = True, ) -> None: + RandomizableTransform.__init__(self, prob) self.min_zoom = ensure_tuple(min_zoom) self.max_zoom = ensure_tuple(max_zoom) if len(self.min_zoom) != len(self.max_zoom): raise AssertionError("min_zoom and max_zoom must have same length.") - self.prob = prob self.mode: InterpolateMode = InterpolateMode(mode) self.padding_mode: NumpyPadMode = NumpyPadMode(padding_mode) self.align_corners = align_corners self.keep_size = keep_size - self._do_transform = False self._zoom: Sequence[float] = [1.0] def randomize(self, data: Optional[Any] = None) -> None: - self._do_transform = self.R.random_sample() < self.prob + super().randomize(None) self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)] def __call__( @@ -961,7 +954,7 @@ def __call__( return np.asarray(grid.cpu().numpy()) -class RandAffineGrid(Randomizable, Transform): +class RandAffineGrid(RandomizableTransform): """ Generate randomised affine grid. """ @@ -1050,7 +1043,7 @@ def __call__( return affine_grid(spatial_size, grid) -class RandDeformGrid(Randomizable, Transform): +class RandDeformGrid(RandomizableTransform): """ Generate random deformation grid. """ @@ -1279,7 +1272,7 @@ def __call__( ) -class RandAffine(Randomizable, Transform): +class RandAffine(RandomizableTransform): """ Random affine transform. """ @@ -1331,6 +1324,7 @@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. """ + RandomizableTransform.__init__(self, prob) self.rand_affine_grid = RandAffineGrid( rotate_range=rotate_range, @@ -1346,9 +1340,6 @@ def __init__( self.mode: GridSampleMode = GridSampleMode(mode) self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) - self.do_transform = False - self.prob = prob - def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None ) -> "RandAffine": @@ -1357,7 +1348,7 @@ def set_random_state( return self def randomize(self, data: Optional[Any] = None) -> None: - self.do_transform = self.R.rand() < self.prob + super().randomize(None) self.rand_affine_grid.randomize() def __call__( @@ -1385,7 +1376,7 @@ def __call__( self.randomize() sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - if self.do_transform: + if self._do_transform: grid = self.rand_affine_grid(spatial_size=sp_size) else: grid = create_grid(spatial_size=sp_size) @@ -1394,7 +1385,7 @@ def __call__( ) -class Rand2DElastic(Randomizable, Transform): +class Rand2DElastic(RandomizableTransform): """ Random elastic deformation and affine in 2D """ @@ -1451,6 +1442,7 @@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. """ + RandomizableTransform.__init__(self, prob) self.deform_grid = RandDeformGrid( spacing=spacing, magnitude_range=magnitude_range, as_tensor_output=True, device=device ) @@ -1467,8 +1459,6 @@ def __init__( self.spatial_size = spatial_size self.mode: GridSampleMode = GridSampleMode(mode) self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) - self.prob = prob - self.do_transform = False def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -1479,7 +1469,7 @@ def set_random_state( return self def randomize(self, spatial_size: Sequence[int]) -> None: - self.do_transform = self.R.rand() < self.prob + super().randomize(None) self.deform_grid.randomize(spatial_size) self.rand_affine_grid.randomize() @@ -1505,7 +1495,7 @@ def __call__( """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) self.randomize(spatial_size=sp_size) - if self.do_transform: + if self._do_transform: grid = self.deform_grid(spatial_size=sp_size) grid = self.rand_affine_grid(grid=grid) grid = torch.nn.functional.interpolate( # type: ignore @@ -1521,7 +1511,7 @@ def __call__( return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) -class Rand3DElastic(Randomizable, Transform): +class Rand3DElastic(RandomizableTransform): """ Random elastic deformation and affine in 3D """ @@ -1580,6 +1570,7 @@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. """ + RandomizableTransform.__init__(self, prob) self.rand_affine_grid = RandAffineGrid(rotate_range, shear_range, translate_range, scale_range, True, device) self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) @@ -1590,8 +1581,6 @@ def __init__( self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) self.device = device - self.prob = prob - self.do_transform = False self.rand_offset = None self.magnitude = 1.0 self.sigma = 1.0 @@ -1604,8 +1593,8 @@ def set_random_state( return self def randomize(self, grid_size: Sequence[int]) -> None: - self.do_transform = self.R.rand() < self.prob - if self.do_transform: + super().randomize(None) + if self._do_transform: self.rand_offset = self.R.uniform(-1.0, 1.0, [3] + list(grid_size)).astype(np.float32) self.magnitude = self.R.uniform(self.magnitude_range[0], self.magnitude_range[1]) self.sigma = self.R.uniform(self.sigma_range[0], self.sigma_range[1]) @@ -1634,7 +1623,7 @@ def __call__( sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) self.randomize(grid_size=sp_size) grid = create_grid(spatial_size=sp_size) - if self.do_transform: + if self._do_transform: if self.rand_offset is None: raise AssertionError grid = torch.as_tensor(np.ascontiguousarray(grid), device=self.device) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 2c66cd5f50..6693d75bcd 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -35,7 +35,7 @@ Spacing, Zoom, ) -from monai.transforms.transform import MapTransform, Randomizable +from monai.transforms.transform import MapTransform, RandomizableTransform from monai.transforms.utils import create_grid from monai.utils import ( GridSampleMode, @@ -274,7 +274,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class RandRotate90d(Randomizable, MapTransform): +class RandRotate90d(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate90`. With probability `prob`, input arrays are rotated by 90 degrees @@ -299,28 +299,26 @@ def __init__( spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. """ - super().__init__(keys) + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) - self.prob = min(max(prob, 0.0), 1.0) self.max_k = max_k self.spatial_axes = spatial_axes - self._do_transform = False self._rand_k = 0 def randomize(self, data: Optional[Any] = None) -> None: self._rand_k = self.R.randint(self.max_k) + 1 - self._do_transform = self.R.random() < self.prob + super().randomize(None) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np.ndarray]: self.randomize() - if not self._do_transform: - return data + d = dict(data) rotator = Rotate90(self._rand_k, self.spatial_axes) - d = dict(data) for key in self.keys: - d[key] = rotator(d[key]) + if self._do_transform: + d[key] = rotator(d[key]) return d @@ -364,7 +362,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class RandAffined(Randomizable, MapTransform): +class RandAffined(RandomizableTransform, MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`. """ @@ -420,9 +418,10 @@ def __init__( - :py:class:`monai.transforms.compose.MapTransform` - :py:class:`RandAffineGrid` for the random affine parameters configurations. """ - super().__init__(keys) + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) self.rand_affine = RandAffine( - prob=prob, + prob=1.0, # because probability handled in this class rotate_range=rotate_range, shear_range=shear_range, translate_range=translate_range, @@ -442,6 +441,7 @@ def set_random_state( return self def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) self.rand_affine.randomize() def __call__( @@ -451,7 +451,7 @@ def __call__( self.randomize() sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) - if self.rand_affine.do_transform: + if self._do_transform: grid = self.rand_affine.rand_affine_grid(spatial_size=sp_size) else: grid = create_grid(spatial_size=sp_size) @@ -461,7 +461,7 @@ def __call__( return d -class Rand2DElasticd(Randomizable, MapTransform): +class Rand2DElasticd(RandomizableTransform, MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rand2DElastic`. """ @@ -523,11 +523,12 @@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. """ - super().__init__(keys) + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) self.rand_2d_elastic = Rand2DElastic( spacing=spacing, magnitude_range=magnitude_range, - prob=prob, + prob=1.0, # because probability controlled by this class rotate_range=rotate_range, shear_range=shear_range, translate_range=translate_range, @@ -547,6 +548,7 @@ def set_random_state( return self def randomize(self, spatial_size: Sequence[int]) -> None: + super().randomize(None) self.rand_2d_elastic.randomize(spatial_size) def __call__( @@ -557,7 +559,7 @@ def __call__( sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, data[self.keys[0]].shape[1:]) self.randomize(spatial_size=sp_size) - if self.rand_2d_elastic.do_transform: + if self._do_transform: grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) grid = self.rand_2d_elastic.rand_affine_grid(grid=grid) grid = torch.nn.functional.interpolate( # type: ignore @@ -578,7 +580,7 @@ def __call__( return d -class Rand3DElasticd(Randomizable, MapTransform): +class Rand3DElasticd(RandomizableTransform, MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rand3DElastic`. """ @@ -641,11 +643,12 @@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. """ - super().__init__(keys) + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) self.rand_3d_elastic = Rand3DElastic( sigma_range=sigma_range, magnitude_range=magnitude_range, - prob=prob, + prob=1.0, # because probability controlled by this class rotate_range=rotate_range, shear_range=shear_range, translate_range=translate_range, @@ -665,6 +668,7 @@ def set_random_state( return self def randomize(self, grid_size: Sequence[int]) -> None: + super().randomize(None) self.rand_3d_elastic.randomize(grid_size) def __call__( @@ -675,7 +679,7 @@ def __call__( self.randomize(grid_size=sp_size) grid = create_grid(spatial_size=sp_size) - if self.rand_3d_elastic.do_transform: + if self._do_transform: device = self.rand_3d_elastic.device grid = torch.tensor(grid).to(device) gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device) @@ -713,7 +717,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class RandFlipd(Randomizable, MapTransform): +class RandFlipd(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandFlip`. @@ -732,23 +736,18 @@ def __init__( prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None, ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) self.spatial_axis = spatial_axis - self.prob = prob - self._do_transform = False self.flipper = Flip(spatial_axis=spatial_axis) - def randomize(self, data: Optional[Any] = None) -> None: - self._do_transform = self.R.random_sample() < self.prob - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - self.randomize() + self.randomize(None) d = dict(data) - if not self._do_transform: - return d for key in self.keys: - d[key] = self.flipper(d[key]) + if self._do_transform: + d[key] = self.flipper(d[key]) return d @@ -810,7 +809,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class RandRotated(Randomizable, MapTransform): +class RandRotated(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate` Randomly rotates the input arrays. @@ -857,7 +856,8 @@ def __init__( align_corners: Union[Sequence[bool], bool] = False, dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) self.range_x = ensure_tuple(range_x) if len(self.range_x) == 1: self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]])) @@ -868,20 +868,18 @@ def __init__( if len(self.range_z) == 1: self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]])) - self.prob = prob self.keep_size = keep_size self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - self._do_transform = False self.x = 0.0 self.y = 0.0 self.z = 0.0 def randomize(self, data: Optional[Any] = None) -> None: - self._do_transform = self.R.random_sample() < self.prob + super().randomize(None) self.x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) @@ -957,7 +955,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class RandZoomd(Randomizable, MapTransform): +class RandZoomd(RandomizableTransform, MapTransform): """ Dict-based version :py:class:`monai.transforms.RandZoom`. @@ -1000,23 +998,22 @@ def __init__( align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, keep_size: bool = True, ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) self.min_zoom = ensure_tuple(min_zoom) self.max_zoom = ensure_tuple(max_zoom) if len(self.min_zoom) != len(self.max_zoom): raise AssertionError("min_zoom and max_zoom must have same length.") - self.prob = prob self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.keep_size = keep_size - self._do_transform = False self._zoom: Sequence[float] = [1.0] def randomize(self, data: Optional[Any] = None) -> None: - self._do_transform = self.R.random_sample() < self.prob + super().randomize(None) self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)] def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index e5841cbe97..9c9729d250 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -20,25 +20,13 @@ from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple -__all__ = ["Randomizable", "Transform", "MapTransform"] +__all__ = ["Randomizable", "RandomizableTransform", "Transform", "MapTransform"] class Randomizable(ABC): """ An interface for handling random state locally, currently based on a class variable `R`, which is an instance of `np.random.RandomState`. - This is mainly for randomized data augmentation transforms. For example:: - - class RandShiftIntensity(Randomizable): - def randomize(): - self._offset = self.R.uniform(low=0, high=100) - def __call__(self, img): - self.randomize() - return img + self._offset - - transform = RandShiftIntensity() - transform.set_random_state(seed=0) - """ R: np.random.RandomState = np.random.RandomState() @@ -77,7 +65,6 @@ def set_random_state( self.R = np.random.RandomState() return self - @abstractmethod def randomize(self, data: Any) -> None: """ Within this method, :py:attr:`self.R` should be used, instead of `np.random`, to introduce random factors. @@ -89,7 +76,6 @@ def randomize(self, data: Any) -> None: Raises: NotImplementedError: When the subclass does not override this method. - """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @@ -142,6 +128,40 @@ def __call__(self, data: Any): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") +class RandomizableTransform(Randomizable, Transform): + """ + An interface for handling random state locally, currently based on a class variable `R`, + which is an instance of `np.random.RandomState`. + This is mainly for randomized data augmentation transforms. For example:: + + class RandShiftIntensity(RandomizableTransform): + def randomize(): + self._offset = self.R.uniform(low=0, high=100) + def __call__(self, img): + self.randomize() + return img + self._offset + + transform = RandShiftIntensity() + transform.set_random_state(seed=0) + + """ + + def __init__(self, prob=1.0, do_transform=False): + self._do_transform = do_transform + self.prob = min(max(prob, 0.0), 1.0) + + def randomize(self, data: Any) -> None: + """ + Within this method, :py:attr:`self.R` should be used, instead of `np.random`, to introduce random factors. + + all :py:attr:`self.R` calls happen here so that we have a better chance to + identify errors of sync the random state. + + This method can generate the random factors based on properties of the input data. + """ + self._do_transform = self.R.rand() < self.prob + + class MapTransform(Transform): """ A subclass of :py:class:`monai.transforms.Transform` with an assumption diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 0ee88e1a6c..24d2feb781 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -21,7 +21,7 @@ import torch from monai.config import DtypeLike, NdarrayTensor -from monai.transforms.transform import Randomizable, Transform +from monai.transforms.transform import RandomizableTransform, Transform from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices from monai.utils import ensure_tuple, min_version, optional_import @@ -631,7 +631,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return np.stack(result, axis=0).astype(np.float32) -class AddExtremePointsChannel(Transform, Randomizable): +class AddExtremePointsChannel(RandomizableTransform): """ Add extreme points of label to the image as a new channel. This transform generates extreme point from label and applies a gaussian filter. The pixel values in points image are rescaled diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index f9612c2408..e9d923d0fd 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -23,7 +23,7 @@ import torch from monai.config import DtypeLike, KeysCollection, NdarrayTensor -from monai.transforms.transform import MapTransform, Randomizable +from monai.transforms.transform import MapTransform, RandomizableTransform from monai.transforms.utility.array import ( AddChannel, AsChannelFirst, @@ -682,9 +682,9 @@ def __call__(self, data): return d -class RandLambdad(Lambdad, Randomizable): +class RandLambdad(Lambdad, RandomizableTransform): """ - Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic. + RandomizableTransform version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic. It's a randomizable transform so `CacheDataset` will not execute it and cache the results. Args: @@ -800,7 +800,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class AddExtremePointsChanneld(Randomizable, MapTransform): +class AddExtremePointsChanneld(RandomizableTransform, MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.AddExtremePointsChannel`. @@ -828,7 +828,7 @@ def __init__( rescale_min: float = -1.0, rescale_max: float = 1.0, ): - super().__init__(keys) + MapTransform.__init__(self, keys) self.background = background self.pert = pert self.points: List[Tuple[int, ...]] = [] diff --git a/tests/test_compose.py b/tests/test_compose.py index 3a0a6ea5bb..bb8a5f08c5 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -13,11 +13,12 @@ import unittest from monai.data import DataLoader, Dataset -from monai.transforms import AddChannel, Compose, Randomizable +from monai.transforms import AddChannel, Compose +from monai.transforms.transform import RandomizableTransform from monai.utils import set_determinism -class _RandXform(Randomizable): +class _RandXform(RandomizableTransform): def randomize(self): self.val = self.R.random_sample() @@ -79,7 +80,7 @@ def c(d): # transform to handle dict data self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2}) def test_random_compose(self): - class _Acc(Randomizable): + class _Acc(RandomizableTransform): self.rand = 0.0 def randomize(self, data=None): @@ -98,10 +99,13 @@ def __call__(self, data): self.assertAlmostEqual(c(1), 1.90734751) def test_randomize_warn(self): - class _RandomClass(Randomizable): + class _RandomClass(RandomizableTransform): def randomize(self, foo1, foo2): pass + def __call__(self, data): + pass + c = Compose([_RandomClass(), _RandomClass()]) with self.assertWarns(Warning): c.randomize() @@ -168,7 +172,7 @@ def test_flatten_and_len(self): self.assertEqual(len(t1), 8) def test_backwards_compatible_imports(self): - from monai.transforms.compose import MapTransform, Randomizable, Transform # noqa: F401 + from monai.transforms.compose import MapTransform, RandomizableTransform, Transform # noqa: F401 if __name__ == "__main__": diff --git a/tests/test_image_dataset.py b/tests/test_image_dataset.py index d79a7d884c..ec2cf77cd8 100644 --- a/tests/test_image_dataset.py +++ b/tests/test_image_dataset.py @@ -17,12 +17,12 @@ import numpy as np from monai.data import ImageDataset -from monai.transforms import Randomizable +from monai.transforms.transform import RandomizableTransform FILENAMES = ["test1.nii.gz", "test2.nii", "test3.nii.gz"] -class RandTest(Randomizable): +class RandTest(RandomizableTransform): """ randomisable transform for testing. """ diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py index 359da8857a..2ddfeefae0 100644 --- a/tests/test_rand_lambdad.py +++ b/tests/test_rand_lambdad.py @@ -13,11 +13,11 @@ import numpy as np -from monai.transforms import Randomizable +from monai.transforms.transform import RandomizableTransform from monai.transforms.utility.dictionary import RandLambdad -class RandTest(Randomizable): +class RandTest(RandomizableTransform): """ randomisable transform for testing. """ diff --git a/tests/test_randomizable.py b/tests/test_randomizable.py index a7a30124df..9972bded0f 100644 --- a/tests/test_randomizable.py +++ b/tests/test_randomizable.py @@ -13,7 +13,7 @@ import numpy as np -from monai.transforms import Randomizable +from monai.transforms.transform import Randomizable class RandTest(Randomizable):