From 291987a9e1f340ac6ef3584a0e0c329821afb575 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 5 Apr 2021 22:32:31 +0100 Subject: [PATCH] remove unused RandomizableTransform Signed-off-by: Wenqi Li --- monai/apps/deepgrow/transforms.py | 8 +++----- monai/data/dataset.py | 13 ++++++------- monai/data/image_dataset.py | 5 ++--- monai/data/test_time_augmentation.py | 6 +++--- monai/transforms/compose.py | 6 +++--- monai/transforms/croppad/array.py | 10 +++++----- monai/transforms/croppad/dictionary.py | 14 +++++--------- monai/transforms/intensity/array.py | 7 +++++++ monai/transforms/intensity/dictionary.py | 9 ++++++--- monai/transforms/spatial/array.py | 8 ++++---- monai/transforms/transform.py | 16 +++++++++++++--- monai/transforms/utility/array.py | 4 ++-- monai/transforms/utility/dictionary.py | 8 ++++---- tests/test_compose.py | 8 ++++---- tests/test_rand_lambdad.py | 4 ++-- 15 files changed, 69 insertions(+), 57 deletions(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index c58d4c1123..3d8f08bc01 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -16,7 +16,7 @@ from monai.config import IndexSelection, KeysCollection from monai.networks.layers import GaussianFilter from monai.transforms import Resize, SpatialCrop -from monai.transforms.transform import MapTransform, RandomizableTransform, Transform +from monai.transforms.transform import MapTransform, Randomizable, Transform from monai.transforms.utils import generate_spatial_bounding_box from monai.utils import InterpolateMode, ensure_tuple_rep, min_version, optional_import @@ -61,7 +61,7 @@ def __call__(self, data): return d -class AddInitialSeedPointd(RandomizableTransform): +class AddInitialSeedPointd(Randomizable): """ Add random guidance as initial seed point for a given label. @@ -86,7 +86,6 @@ def __init__( sid: str = "sid", connected_regions: int = 5, ): - super().__init__(prob=1.0, do_transform=True) self.label = label self.sids_key = sids self.sid_key = sid @@ -284,7 +283,7 @@ def __call__(self, data): return d -class AddRandomGuidanced(RandomizableTransform): +class AddRandomGuidanced(Randomizable): """ Add random guidance based on discrepancies that were found between label and prediction. @@ -320,7 +319,6 @@ def __init__( probability: str = "probability", batched: bool = True, ): - super().__init__(prob=1.0, do_transform=True) self.guidance = guidance self.discrepancy = discrepancy self.probability = probability diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 12403bbff1..a09050e5bc 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -29,7 +29,6 @@ from monai.data.utils import first, pickle_hashing from monai.transforms import Compose, Randomizable, Transform, apply_transform -from monai.transforms.transform import RandomizableTransform from monai.utils import MAX_SEED, get_seed, min_version, optional_import if TYPE_CHECKING: @@ -182,7 +181,7 @@ def _pre_transform(self, item_transformed): raise ValueError("transform must be an instance of monai.transforms.Compose.") for _transform in self.transform.transforms: # execute all the deterministic transforms - if isinstance(_transform, RandomizableTransform) or not isinstance(_transform, Transform): + if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform): break item_transformed = apply_transform(_transform, item_transformed) return item_transformed @@ -204,7 +203,7 @@ def _post_transform(self, item_transformed): for _transform in self.transform.transforms: if ( start_post_randomize_run - or isinstance(_transform, RandomizableTransform) + or isinstance(_transform, Randomizable) or not isinstance(_transform, Transform) ): start_post_randomize_run = True @@ -547,7 +546,7 @@ def _load_cache_item(self, idx: int): raise ValueError("transform must be an instance of monai.transforms.Compose.") for _transform in self.transform.transforms: # execute all the deterministic transforms - if isinstance(_transform, RandomizableTransform) or not isinstance(_transform, Transform): + if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform): break item = apply_transform(_transform, item) return item @@ -564,7 +563,7 @@ def _transform(self, index: int): if not isinstance(self.transform, Compose): raise ValueError("transform must be an instance of monai.transforms.Compose.") for _transform in self.transform.transforms: - if start_run or isinstance(_transform, RandomizableTransform) or not isinstance(_transform, Transform): + if start_run or isinstance(_transform, Randomizable) or not isinstance(_transform, Transform): start_run = True data = apply_transform(_transform, data) return data @@ -967,10 +966,10 @@ def __getitem__(self, index: int): # set transforms of each zip component for dataset in self.dataset.data: transform = getattr(dataset, "transform", None) - if isinstance(transform, RandomizableTransform): + if isinstance(transform, Randomizable): transform.set_random_state(seed=self._seed) transform = getattr(self.dataset, "transform", None) - if isinstance(transform, RandomizableTransform): + if isinstance(transform, Randomizable): transform.set_random_state(seed=self._seed) return self.dataset[index] diff --git a/monai/data/image_dataset.py b/monai/data/image_dataset.py index 1074105508..1568e082ee 100644 --- a/monai/data/image_dataset.py +++ b/monai/data/image_dataset.py @@ -17,7 +17,6 @@ from monai.config import DtypeLike from monai.data.image_reader import ImageReader from monai.transforms import LoadImage, Randomizable, apply_transform -from monai.transforms.transform import RandomizableTransform from monai.utils import MAX_SEED, get_seed @@ -107,14 +106,14 @@ def __getitem__(self, index: int): label = self.labels[index] if self.transform is not None: - if isinstance(self.transform, RandomizableTransform): + if isinstance(self.transform, Randomizable): self.transform.set_random_state(seed=self._seed) img = apply_transform(self.transform, img) data = [img] if self.seg_transform is not None: - if isinstance(self.seg_transform, RandomizableTransform): + if isinstance(self.seg_transform, Randomizable): self.seg_transform.set_random_state(seed=self._seed) seg = apply_transform(self.seg_transform, seg) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 51b95adc58..06e1f63da5 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -20,7 +20,7 @@ from monai.data.utils import list_data_collate, pad_list_data_collate from monai.transforms.compose import Compose from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import RandomizableTransform +from monai.transforms.transform import Randomizable from monai.transforms.utils import allow_missing_keys_mode from monai.utils.enums import CommonKeys, InverseKeys @@ -47,7 +47,7 @@ class TestTimeAugmentation: Args: transform: transform (or composed) to be applied to each realisation. At least one transform must be of type - `RandomizableTransform`. All random transforms must be of type `InvertibleTransform`. + `Randomizable`. All random transforms must be of type `InvertibleTransform`. batch_size: number of realisations to infer at once. num_workers: how many subprocesses to use for data. inferrer_fn: function to use to perform inference. @@ -96,7 +96,7 @@ def __init__( def _check_transforms(self): """Should be at least 1 random transform, and all random transforms should be invertible.""" ts = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms - randoms = np.array([isinstance(t, RandomizableTransform) for t in ts]) + randoms = np.array([isinstance(t, Randomizable) for t in ts]) invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts]) # check at least 1 random if sum(randoms) == 0: diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index dd40663e2a..ce965b8b18 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -32,7 +32,7 @@ __all__ = ["Compose"] -class Compose(RandomizableTransform, InvertibleTransform): +class Compose(Randomizable, InvertibleTransform): """ ``Compose`` provides the ability to chain a series of calls together in a sequence. Each transform in the sequence must take a single argument and @@ -102,14 +102,14 @@ def __init__(self, transforms: Optional[Union[Sequence[Callable], Callable]] = N def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose": super().set_random_state(seed=seed, state=state) for _transform in self.transforms: - if not isinstance(_transform, RandomizableTransform): + if not isinstance(_transform, Randomizable): continue _transform.set_random_state(seed=self.R.randint(MAX_SEED, dtype="uint32")) return self def randomize(self, data: Optional[Any] = None) -> None: for _transform in self.transforms: - if not isinstance(_transform, RandomizableTransform): + if not isinstance(_transform, Randomizable): continue try: _transform.randomize(data) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 159fa1a5f4..c8f7136334 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -20,7 +20,7 @@ from monai.config import IndexSelection from monai.data.utils import get_random_patch, get_valid_patch_size -from monai.transforms.transform import Randomizable, RandomizableTransform, Transform +from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, @@ -279,7 +279,7 @@ def __call__(self, img: np.ndarray): return cropper(img) -class RandSpatialCrop(RandomizableTransform): +class RandSpatialCrop(Randomizable): """ Crop image with random size or specific size ROI. It can crop at a random position as center or at the image center. And allows to set the minimum size to limit the randomly generated ROI. @@ -324,7 +324,7 @@ def __call__(self, img: np.ndarray): return cropper(img) -class RandSpatialCropSamples(RandomizableTransform): +class RandSpatialCropSamples(Randomizable): """ Crop image with random size or specific size ROI to generate a list of N samples. It can crop at a random position as center or at the image center. And allows to set @@ -432,7 +432,7 @@ def __call__(self, img: np.ndarray): return cropped -class RandWeightedCrop(RandomizableTransform): +class RandWeightedCrop(Randomizable): """ Samples a list of `num_samples` image patches according to the provided `weight_map`. @@ -484,7 +484,7 @@ def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> return results -class RandCropByPosNegLabel(RandomizableTransform): +class RandCropByPosNegLabel(Randomizable): """ Crop random fixed sized regions with the center being a foreground or background voxel based on the Pos Neg Ratio. diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 1d4fcfdb1f..c8d5ceea40 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -34,7 +34,7 @@ SpatialPad, ) from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform +from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, @@ -386,7 +386,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class RandSpatialCropd(RandomizableTransform, MapTransform, InvertibleTransform): +class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandSpatialCrop`. Crop image with random size or specific size ROI. It can crop at a random position as @@ -413,7 +413,6 @@ def __init__( random_size: bool = True, allow_missing_keys: bool = False, ) -> None: - RandomizableTransform.__init__(self, prob=1.0, do_transform=True) MapTransform.__init__(self, keys, allow_missing_keys) self.roi_size = roi_size self.random_center = random_center @@ -477,7 +476,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class RandSpatialCropSamplesd(RandomizableTransform, MapTransform): +class RandSpatialCropSamplesd(Randomizable, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandSpatialCropSamples`. Crop image with random size or specific size ROI to generate a list of N samples. @@ -515,7 +514,6 @@ def __init__( meta_key_postfix: str = "meta_dict", allow_missing_keys: bool = False, ) -> None: - RandomizableTransform.__init__(self, prob=1.0, do_transform=True) MapTransform.__init__(self, keys, allow_missing_keys) if num_samples < 1: raise ValueError(f"num_samples must be positive, got {num_samples}.") @@ -626,7 +624,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class RandWeightedCropd(RandomizableTransform, MapTransform): +class RandWeightedCropd(Randomizable, MapTransform): """ Samples a list of `num_samples` image patches according to the provided `weight_map`. @@ -654,7 +652,6 @@ def __init__( center_coord_key: Optional[str] = None, allow_missing_keys: bool = False, ): - RandomizableTransform.__init__(self, prob=1.0, do_transform=True) MapTransform.__init__(self, keys, allow_missing_keys) self.spatial_size = ensure_tuple(spatial_size) self.w_key = w_key @@ -693,7 +690,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n return results -class RandCropByPosNegLabeld(RandomizableTransform, MapTransform): +class RandCropByPosNegLabeld(Randomizable, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandCropByPosNegLabel`. Crop random fixed sized regions with the center being a foreground or background voxel @@ -751,7 +748,6 @@ def __init__( meta_key_postfix: str = "meta_dict", allow_missing_keys: bool = False, ) -> None: - RandomizableTransform.__init__(self) MapTransform.__init__(self, keys, allow_missing_keys) self.label_key = label_key self.spatial_size: Union[Tuple[int, ...], Sequence[int], int] = spatial_size diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index f89e381daa..62350d4ab0 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -122,6 +122,7 @@ def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1 if len(offsets) != 2: raise AssertionError("offsets should be a number or pair of numbers.") self.offsets = (min(offsets), max(offsets)) + self._offset = self.offsets[0] def randomize(self, data: Optional[Any] = None) -> None: self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) @@ -217,6 +218,7 @@ def __init__( if len(factors) != 2: raise AssertionError("factors should be a number or pair of numbers.") self.factors = (min(factors), max(factors)) + self.factor = self.factors[0] self.nonzero = nonzero self.channel_wise = channel_wise self.dtype = dtype @@ -294,6 +296,7 @@ def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1 if len(factors) != 2: raise AssertionError("factors should be a number or pair of numbers.") self.factors = (min(factors), max(factors)) + self.factor = self.factors[0] def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) @@ -874,6 +877,10 @@ def __init__( self.sigma_z = sigma_z self.approx = approx + self.x = self.sigma_x[0] + self.y = self.sigma_y[0] + self.z = self.sigma_z[0] + def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) self.x = self.R.uniform(low=self.sigma_x[0], high=self.sigma_x[1]) diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 517c34cbf2..a35e5c8ea6 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -206,6 +206,7 @@ def __init__( if len(offsets) != 2: raise AssertionError("offsets should be a number or pair of numbers.") self.offsets = (min(offsets), max(offsets)) + self._offset = self.offsets[0] def randomize(self, data: Optional[Any] = None) -> None: self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) @@ -293,6 +294,7 @@ def __init__( if len(factors) != 2: raise AssertionError("factors should be a number or pair of numbers.") self.factors = (min(factors), max(factors)) + self.factor = self.factors[0] self.nonzero = nonzero self.channel_wise = channel_wise self.dtype = dtype @@ -380,6 +382,7 @@ def __init__( if len(factors) != 2: raise AssertionError("factors should be a number or pair of numbers.") self.factors = (min(factors), max(factors)) + self.factor = self.factors[0] def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) @@ -760,11 +763,11 @@ def __init__( ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.sigma_x = sigma_x - self.sigma_y = sigma_y - self.sigma_z = sigma_z + self.sigma_x, self.sigma_y, self.sigma_z = sigma_x, sigma_y, sigma_z self.approx = approx + self.x, self.y, self.z = self.sigma_x[0], self.sigma_y[0], self.sigma_z[0] + def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) self.x = self.R.uniform(low=self.sigma_x[0], high=self.sigma_x[1]) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 1c096ba743..a3eb055f7e 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -23,7 +23,7 @@ from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.transforms.croppad.array import CenterSpatialCrop -from monai.transforms.transform import RandomizableTransform, Transform +from monai.transforms.transform import Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( create_control_grid, create_grid, @@ -790,7 +790,7 @@ class RandAxisFlip(RandomizableTransform): """ def __init__(self, prob: float = 0.1) -> None: - RandomizableTransform.__init__(self, min(max(prob, 0.0), 1.0)) + RandomizableTransform.__init__(self, prob) self._axis: Optional[int] = None def randomize(self, data: np.ndarray) -> None: @@ -1004,7 +1004,7 @@ def __call__( return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()), affine -class RandAffineGrid(RandomizableTransform): +class RandAffineGrid(Randomizable): """ Generate randomised affine grid. """ @@ -1101,7 +1101,7 @@ def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]] return self.affine -class RandDeformGrid(RandomizableTransform): +class RandDeformGrid(Randomizable): """ Generate random deformation grid. """ diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 6a22db1076..ff5f021739 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -180,17 +180,27 @@ class RandomizableTransform(Randomizable, Transform): """ An interface for handling random state locally, currently based on a class variable `R`, which is an instance of `np.random.RandomState`. - This is mainly for randomized data augmentation transforms. For example:: + This class introduces a randomized flag `_do_transform`, is mainly for randomized data augmentation transforms. + For example: - class RandShiftIntensity(RandomizableTransform): - def randomize(): + .. code-block:: python + + from monai.transforms import RandomizableTransform + + class RandShiftIntensity100(RandomizableTransform): + def randomize(self): + super().randomize(None) self._offset = self.R.uniform(low=0, high=100) + def __call__(self, img): self.randomize() + if not self._do_transform: + return img return img + self._offset transform = RandShiftIntensity() transform.set_random_state(seed=0) + print(transform(10)) """ diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 8e0dabafb2..6903b2628d 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -22,7 +22,7 @@ import torch from monai.config import DtypeLike, NdarrayTensor -from monai.transforms.transform import RandomizableTransform, Transform +from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices from monai.utils import ensure_tuple, min_version, optional_import @@ -667,7 +667,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return np.stack(result, axis=0) -class AddExtremePointsChannel(RandomizableTransform): +class AddExtremePointsChannel(Randomizable): """ Add extreme points of label to the image as a new channel. This transform generates extreme point from label and applies a gaussian filter. The pixel values in points image are rescaled diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index c437cd055b..9464faa503 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -23,7 +23,7 @@ import torch from monai.config import DtypeLike, KeysCollection, NdarrayTensor -from monai.transforms.transform import MapTransform, RandomizableTransform +from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utility.array import ( AddChannel, AsChannelFirst, @@ -731,9 +731,9 @@ def __call__(self, data): return d -class RandLambdad(Lambdad, RandomizableTransform): +class RandLambdad(Lambdad, Randomizable): """ - RandomizableTransform version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic. + Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic. It's a randomizable transform so `CacheDataset` will not execute it and cache the results. Args: @@ -853,7 +853,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class AddExtremePointsChanneld(RandomizableTransform, MapTransform): +class AddExtremePointsChanneld(Randomizable, MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.AddExtremePointsChannel`. diff --git a/tests/test_compose.py b/tests/test_compose.py index bb8a5f08c5..97b044af8f 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -14,11 +14,11 @@ from monai.data import DataLoader, Dataset from monai.transforms import AddChannel, Compose -from monai.transforms.transform import RandomizableTransform +from monai.transforms.transform import Randomizable from monai.utils import set_determinism -class _RandXform(RandomizableTransform): +class _RandXform(Randomizable): def randomize(self): self.val = self.R.random_sample() @@ -80,7 +80,7 @@ def c(d): # transform to handle dict data self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2}) def test_random_compose(self): - class _Acc(RandomizableTransform): + class _Acc(Randomizable): self.rand = 0.0 def randomize(self, data=None): @@ -99,7 +99,7 @@ def __call__(self, data): self.assertAlmostEqual(c(1), 1.90734751) def test_randomize_warn(self): - class _RandomClass(RandomizableTransform): + class _RandomClass(Randomizable): def randomize(self, foo1, foo2): pass diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py index 2ddfeefae0..a450b67413 100644 --- a/tests/test_rand_lambdad.py +++ b/tests/test_rand_lambdad.py @@ -13,11 +13,11 @@ import numpy as np -from monai.transforms.transform import RandomizableTransform +from monai.transforms.transform import Randomizable from monai.transforms.utility.dictionary import RandLambdad -class RandTest(RandomizableTransform): +class RandTest(Randomizable): """ randomisable transform for testing. """