Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions monai/apps/deepgrow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from monai.config import IndexSelection, KeysCollection
from monai.networks.layers import GaussianFilter
from monai.transforms import Resize, SpatialCrop
from monai.transforms.transform import MapTransform, RandomizableTransform, Transform
from monai.transforms.transform import MapTransform, Randomizable, Transform
from monai.transforms.utils import generate_spatial_bounding_box
from monai.utils import InterpolateMode, ensure_tuple_rep, min_version, optional_import

Expand Down Expand Up @@ -61,7 +61,7 @@ def __call__(self, data):
return d


class AddInitialSeedPointd(RandomizableTransform):
class AddInitialSeedPointd(Randomizable):
"""
Add random guidance as initial seed point for a given label.

Expand All @@ -86,7 +86,6 @@ def __init__(
sid: str = "sid",
connected_regions: int = 5,
):
super().__init__(prob=1.0, do_transform=True)
self.label = label
self.sids_key = sids
self.sid_key = sid
Expand Down Expand Up @@ -284,7 +283,7 @@ def __call__(self, data):
return d


class AddRandomGuidanced(RandomizableTransform):
class AddRandomGuidanced(Randomizable):
"""
Add random guidance based on discrepancies that were found between label and prediction.

Expand Down Expand Up @@ -320,7 +319,6 @@ def __init__(
probability: str = "probability",
batched: bool = True,
):
super().__init__(prob=1.0, do_transform=True)
self.guidance = guidance
self.discrepancy = discrepancy
self.probability = probability
Expand Down
13 changes: 6 additions & 7 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

from monai.data.utils import first, pickle_hashing
from monai.transforms import Compose, Randomizable, Transform, apply_transform
from monai.transforms.transform import RandomizableTransform
from monai.utils import MAX_SEED, get_seed, min_version, optional_import

if TYPE_CHECKING:
Expand Down Expand Up @@ -182,7 +181,7 @@ def _pre_transform(self, item_transformed):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for _transform in self.transform.transforms:
# execute all the deterministic transforms
if isinstance(_transform, RandomizableTransform) or not isinstance(_transform, Transform):
if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform):
break
item_transformed = apply_transform(_transform, item_transformed)
return item_transformed
Expand All @@ -204,7 +203,7 @@ def _post_transform(self, item_transformed):
for _transform in self.transform.transforms:
if (
start_post_randomize_run
or isinstance(_transform, RandomizableTransform)
or isinstance(_transform, Randomizable)
or not isinstance(_transform, Transform)
):
start_post_randomize_run = True
Expand Down Expand Up @@ -547,7 +546,7 @@ def _load_cache_item(self, idx: int):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for _transform in self.transform.transforms:
# execute all the deterministic transforms
if isinstance(_transform, RandomizableTransform) or not isinstance(_transform, Transform):
if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform):
break
item = apply_transform(_transform, item)
return item
Expand All @@ -564,7 +563,7 @@ def _transform(self, index: int):
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for _transform in self.transform.transforms:
if start_run or isinstance(_transform, RandomizableTransform) or not isinstance(_transform, Transform):
if start_run or isinstance(_transform, Randomizable) or not isinstance(_transform, Transform):
start_run = True
data = apply_transform(_transform, data)
return data
Expand Down Expand Up @@ -967,10 +966,10 @@ def __getitem__(self, index: int):
# set transforms of each zip component
for dataset in self.dataset.data:
transform = getattr(dataset, "transform", None)
if isinstance(transform, RandomizableTransform):
if isinstance(transform, Randomizable):
transform.set_random_state(seed=self._seed)
transform = getattr(self.dataset, "transform", None)
if isinstance(transform, RandomizableTransform):
if isinstance(transform, Randomizable):
transform.set_random_state(seed=self._seed)
return self.dataset[index]

Expand Down
5 changes: 2 additions & 3 deletions monai/data/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from monai.config import DtypeLike
from monai.data.image_reader import ImageReader
from monai.transforms import LoadImage, Randomizable, apply_transform
from monai.transforms.transform import RandomizableTransform
from monai.utils import MAX_SEED, get_seed


Expand Down Expand Up @@ -107,14 +106,14 @@ def __getitem__(self, index: int):
label = self.labels[index]

if self.transform is not None:
if isinstance(self.transform, RandomizableTransform):
if isinstance(self.transform, Randomizable):
self.transform.set_random_state(seed=self._seed)
img = apply_transform(self.transform, img)

data = [img]

if self.seg_transform is not None:
if isinstance(self.seg_transform, RandomizableTransform):
if isinstance(self.seg_transform, Randomizable):
self.seg_transform.set_random_state(seed=self._seed)
seg = apply_transform(self.seg_transform, seg)

Expand Down
6 changes: 3 additions & 3 deletions monai/data/test_time_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from monai.data.utils import list_data_collate, pad_list_data_collate
from monai.transforms.compose import Compose
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import RandomizableTransform
from monai.transforms.transform import Randomizable
from monai.transforms.utils import allow_missing_keys_mode
from monai.utils.enums import CommonKeys, InverseKeys

Expand All @@ -47,7 +47,7 @@ class TestTimeAugmentation:

Args:
transform: transform (or composed) to be applied to each realisation. At least one transform must be of type
`RandomizableTransform`. All random transforms must be of type `InvertibleTransform`.
`Randomizable`. All random transforms must be of type `InvertibleTransform`.
batch_size: number of realisations to infer at once.
num_workers: how many subprocesses to use for data.
inferrer_fn: function to use to perform inference.
Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(
def _check_transforms(self):
"""Should be at least 1 random transform, and all random transforms should be invertible."""
ts = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms
randoms = np.array([isinstance(t, RandomizableTransform) for t in ts])
randoms = np.array([isinstance(t, Randomizable) for t in ts])
invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts])
# check at least 1 random
if sum(randoms) == 0:
Expand Down
6 changes: 3 additions & 3 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
__all__ = ["Compose"]


class Compose(RandomizableTransform, InvertibleTransform):
class Compose(Randomizable, InvertibleTransform):
"""
``Compose`` provides the ability to chain a series of calls together in a
sequence. Each transform in the sequence must take a single argument and
Expand Down Expand Up @@ -102,14 +102,14 @@ def __init__(self, transforms: Optional[Union[Sequence[Callable], Callable]] = N
def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose":
super().set_random_state(seed=seed, state=state)
for _transform in self.transforms:
if not isinstance(_transform, RandomizableTransform):
if not isinstance(_transform, Randomizable):
continue
_transform.set_random_state(seed=self.R.randint(MAX_SEED, dtype="uint32"))
return self

def randomize(self, data: Optional[Any] = None) -> None:
for _transform in self.transforms:
if not isinstance(_transform, RandomizableTransform):
if not isinstance(_transform, Randomizable):
continue
try:
_transform.randomize(data)
Expand Down
10 changes: 5 additions & 5 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from monai.config import IndexSelection
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.transforms.transform import Randomizable, RandomizableTransform, Transform
from monai.transforms.transform import Randomizable, Transform
from monai.transforms.utils import (
generate_pos_neg_label_crop_centers,
generate_spatial_bounding_box,
Expand Down Expand Up @@ -279,7 +279,7 @@ def __call__(self, img: np.ndarray):
return cropper(img)


class RandSpatialCrop(RandomizableTransform):
class RandSpatialCrop(Randomizable):
"""
Crop image with random size or specific size ROI. It can crop at a random position as center
or at the image center. And allows to set the minimum size to limit the randomly generated ROI.
Expand Down Expand Up @@ -324,7 +324,7 @@ def __call__(self, img: np.ndarray):
return cropper(img)


class RandSpatialCropSamples(RandomizableTransform):
class RandSpatialCropSamples(Randomizable):
"""
Crop image with random size or specific size ROI to generate a list of N samples.
It can crop at a random position as center or at the image center. And allows to set
Expand Down Expand Up @@ -432,7 +432,7 @@ def __call__(self, img: np.ndarray):
return cropped


class RandWeightedCrop(RandomizableTransform):
class RandWeightedCrop(Randomizable):
"""
Samples a list of `num_samples` image patches according to the provided `weight_map`.

Expand Down Expand Up @@ -484,7 +484,7 @@ def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) ->
return results


class RandCropByPosNegLabel(RandomizableTransform):
class RandCropByPosNegLabel(Randomizable):
"""
Crop random fixed sized regions with the center being a foreground or background voxel
based on the Pos Neg Ratio.
Expand Down
14 changes: 5 additions & 9 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
SpatialPad,
)
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
from monai.transforms.transform import MapTransform, Randomizable
from monai.transforms.utils import (
generate_pos_neg_label_crop_centers,
generate_spatial_bounding_box,
Expand Down Expand Up @@ -386,7 +386,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
return d


class RandSpatialCropd(RandomizableTransform, MapTransform, InvertibleTransform):
class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandSpatialCrop`.
Crop image with random size or specific size ROI. It can crop at a random position as
Expand All @@ -413,7 +413,6 @@ def __init__(
random_size: bool = True,
allow_missing_keys: bool = False,
) -> None:
RandomizableTransform.__init__(self, prob=1.0, do_transform=True)
MapTransform.__init__(self, keys, allow_missing_keys)
self.roi_size = roi_size
self.random_center = random_center
Expand Down Expand Up @@ -477,7 +476,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
return d


class RandSpatialCropSamplesd(RandomizableTransform, MapTransform):
class RandSpatialCropSamplesd(Randomizable, MapTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandSpatialCropSamples`.
Crop image with random size or specific size ROI to generate a list of N samples.
Expand Down Expand Up @@ -515,7 +514,6 @@ def __init__(
meta_key_postfix: str = "meta_dict",
allow_missing_keys: bool = False,
) -> None:
RandomizableTransform.__init__(self, prob=1.0, do_transform=True)
MapTransform.__init__(self, keys, allow_missing_keys)
if num_samples < 1:
raise ValueError(f"num_samples must be positive, got {num_samples}.")
Expand Down Expand Up @@ -626,7 +624,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
return d


class RandWeightedCropd(RandomizableTransform, MapTransform):
class RandWeightedCropd(Randomizable, MapTransform):
"""
Samples a list of `num_samples` image patches according to the provided `weight_map`.

Expand Down Expand Up @@ -654,7 +652,6 @@ def __init__(
center_coord_key: Optional[str] = None,
allow_missing_keys: bool = False,
):
RandomizableTransform.__init__(self, prob=1.0, do_transform=True)
MapTransform.__init__(self, keys, allow_missing_keys)
self.spatial_size = ensure_tuple(spatial_size)
self.w_key = w_key
Expand Down Expand Up @@ -693,7 +690,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
return results


class RandCropByPosNegLabeld(RandomizableTransform, MapTransform):
class RandCropByPosNegLabeld(Randomizable, MapTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandCropByPosNegLabel`.
Crop random fixed sized regions with the center being a foreground or background voxel
Expand Down Expand Up @@ -751,7 +748,6 @@ def __init__(
meta_key_postfix: str = "meta_dict",
allow_missing_keys: bool = False,
) -> None:
RandomizableTransform.__init__(self)
MapTransform.__init__(self, keys, allow_missing_keys)
self.label_key = label_key
self.spatial_size: Union[Tuple[int, ...], Sequence[int], int] = spatial_size
Expand Down
7 changes: 7 additions & 0 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1
if len(offsets) != 2:
raise AssertionError("offsets should be a number or pair of numbers.")
self.offsets = (min(offsets), max(offsets))
self._offset = self.offsets[0]

def randomize(self, data: Optional[Any] = None) -> None:
self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1])
Expand Down Expand Up @@ -217,6 +218,7 @@ def __init__(
if len(factors) != 2:
raise AssertionError("factors should be a number or pair of numbers.")
self.factors = (min(factors), max(factors))
self.factor = self.factors[0]
self.nonzero = nonzero
self.channel_wise = channel_wise
self.dtype = dtype
Expand Down Expand Up @@ -294,6 +296,7 @@ def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1
if len(factors) != 2:
raise AssertionError("factors should be a number or pair of numbers.")
self.factors = (min(factors), max(factors))
self.factor = self.factors[0]

def randomize(self, data: Optional[Any] = None) -> None:
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])
Expand Down Expand Up @@ -874,6 +877,10 @@ def __init__(
self.sigma_z = sigma_z
self.approx = approx

self.x = self.sigma_x[0]
self.y = self.sigma_y[0]
self.z = self.sigma_z[0]

def randomize(self, data: Optional[Any] = None) -> None:
super().randomize(None)
self.x = self.R.uniform(low=self.sigma_x[0], high=self.sigma_x[1])
Expand Down
9 changes: 6 additions & 3 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def __init__(
if len(offsets) != 2:
raise AssertionError("offsets should be a number or pair of numbers.")
self.offsets = (min(offsets), max(offsets))
self._offset = self.offsets[0]

def randomize(self, data: Optional[Any] = None) -> None:
self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1])
Expand Down Expand Up @@ -293,6 +294,7 @@ def __init__(
if len(factors) != 2:
raise AssertionError("factors should be a number or pair of numbers.")
self.factors = (min(factors), max(factors))
self.factor = self.factors[0]
self.nonzero = nonzero
self.channel_wise = channel_wise
self.dtype = dtype
Expand Down Expand Up @@ -380,6 +382,7 @@ def __init__(
if len(factors) != 2:
raise AssertionError("factors should be a number or pair of numbers.")
self.factors = (min(factors), max(factors))
self.factor = self.factors[0]

def randomize(self, data: Optional[Any] = None) -> None:
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])
Expand Down Expand Up @@ -760,11 +763,11 @@ def __init__(
) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
RandomizableTransform.__init__(self, prob)
self.sigma_x = sigma_x
self.sigma_y = sigma_y
self.sigma_z = sigma_z
self.sigma_x, self.sigma_y, self.sigma_z = sigma_x, sigma_y, sigma_z
self.approx = approx

self.x, self.y, self.z = self.sigma_x[0], self.sigma_y[0], self.sigma_z[0]

def randomize(self, data: Optional[Any] = None) -> None:
super().randomize(None)
self.x = self.R.uniform(low=self.sigma_x[0], high=self.sigma_x[1])
Expand Down
Loading