Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ Generic Interfaces
.. autoclass:: Randomizable
:members:

`RandomizableTransform`
^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: RandomizableTransform
:members:

`Compose`
^^^^^^^^^
.. autoclass:: Compose
Expand Down
6 changes: 3 additions & 3 deletions monai/apps/deepgrow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from monai.config import IndexSelection, KeysCollection
from monai.networks.layers import GaussianFilter
from monai.transforms import SpatialCrop
from monai.transforms.transform import MapTransform, Randomizable, Transform
from monai.transforms.transform import MapTransform, RandomizableTransform, Transform
from monai.transforms.utils import generate_spatial_bounding_box
from monai.utils import min_version, optional_import

Expand Down Expand Up @@ -62,7 +62,7 @@ def __call__(self, data):
return d


class AddInitialSeedPointd(Randomizable, Transform):
class AddInitialSeedPointd(RandomizableTransform):
"""
Add random guidance as initial seed point for a given label.

Expand Down Expand Up @@ -279,7 +279,7 @@ def __call__(self, data):
return d


class AddRandomGuidanced(Randomizable, Transform):
class AddRandomGuidanced(RandomizableTransform):
"""
Add random guidance based on discrepancies that were found between label and prediction.

Expand Down
13 changes: 7 additions & 6 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from monai.data.utils import pickle_hashing
from monai.transforms import Compose, Randomizable, Transform, apply_transform
from monai.transforms.transform import RandomizableTransform
from monai.utils import MAX_SEED, get_seed, min_version, optional_import

if TYPE_CHECKING:
Expand Down Expand Up @@ -161,7 +162,7 @@ def _pre_transform(self, item_transformed):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for _transform in self.transform.transforms:
# execute all the deterministic transforms
if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform):
if isinstance(_transform, RandomizableTransform) or not isinstance(_transform, Transform):
break
item_transformed = apply_transform(_transform, item_transformed)
return item_transformed
Expand All @@ -183,7 +184,7 @@ def _post_transform(self, item_transformed):
for _transform in self.transform.transforms:
if (
start_post_randomize_run
or isinstance(_transform, Randomizable)
or isinstance(_transform, RandomizableTransform)
or not isinstance(_transform, Transform)
):
start_post_randomize_run = True
Expand Down Expand Up @@ -522,7 +523,7 @@ def _load_cache_item(self, idx: int):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for _transform in self.transform.transforms:
# execute all the deterministic transforms
if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform):
if isinstance(_transform, RandomizableTransform) or not isinstance(_transform, Transform):
break
item = apply_transform(_transform, item)
return item
Expand All @@ -539,7 +540,7 @@ def __getitem__(self, index):
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for _transform in self.transform.transforms:
if start_run or isinstance(_transform, Randomizable) or not isinstance(_transform, Transform):
if start_run or isinstance(_transform, RandomizableTransform) or not isinstance(_transform, Transform):
start_run = True
data = apply_transform(_transform, data)
return data
Expand Down Expand Up @@ -924,9 +925,9 @@ def __getitem__(self, index: int):
# set transforms of each zip component
for dataset in self.dataset.data:
transform = getattr(dataset, "transform", None)
if isinstance(transform, Randomizable):
if isinstance(transform, RandomizableTransform):
transform.set_random_state(seed=self._seed)
transform = getattr(self.dataset, "transform", None)
if isinstance(transform, Randomizable):
if isinstance(transform, RandomizableTransform):
transform.set_random_state(seed=self._seed)
return self.dataset[index]
5 changes: 3 additions & 2 deletions monai/data/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from monai.config import DtypeLike
from monai.data.image_reader import ImageReader
from monai.transforms import LoadImage, Randomizable, apply_transform
from monai.transforms.transform import RandomizableTransform
from monai.utils import MAX_SEED, get_seed


Expand Down Expand Up @@ -106,14 +107,14 @@ def __getitem__(self, index: int):
label = self.labels[index]

if self.transform is not None:
if isinstance(self.transform, Randomizable):
if isinstance(self.transform, RandomizableTransform):
self.transform.set_random_state(seed=self._seed)
img = apply_transform(self.transform, img)

data = [img]

if self.seg_transform is not None:
if isinstance(self.seg_transform, Randomizable):
if isinstance(self.seg_transform, RandomizableTransform):
self.seg_transform.set_random_state(seed=self._seed)
seg = apply_transform(self.seg_transform, seg)

Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@
ZoomD,
ZoomDict,
)
from .transform import MapTransform, Randomizable, Transform
from .transform import MapTransform, Randomizable, RandomizableTransform, Transform
from .utility.array import (
AddChannel,
AddExtremePointsChannel,
Expand Down
9 changes: 4 additions & 5 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@
import numpy as np

# For backwards compatiblity (so this still works: from monai.transforms.compose import MapTransform)
from monai.transforms.transform import MapTransform # noqa: F401
from monai.transforms.transform import Randomizable, Transform
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform, Transform # noqa: F401
from monai.transforms.utils import apply_transform
from monai.utils import MAX_SEED, ensure_tuple, get_seed

__all__ = ["Compose"]


class Compose(Randomizable, Transform):
class Compose(RandomizableTransform):
"""
``Compose`` provides the ability to chain a series of calls together in a
sequence. Each transform in the sequence must take a single argument and
Expand Down Expand Up @@ -96,14 +95,14 @@ def __init__(self, transforms: Optional[Union[Sequence[Callable], Callable]] = N
def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose":
super().set_random_state(seed=seed, state=state)
for _transform in self.transforms:
if not isinstance(_transform, Randomizable):
if not isinstance(_transform, RandomizableTransform):
continue
_transform.set_random_state(seed=self.R.randint(MAX_SEED, dtype="uint32"))
return self

def randomize(self, data: Optional[Any] = None) -> None:
for _transform in self.transforms:
if not isinstance(_transform, Randomizable):
if not isinstance(_transform, RandomizableTransform):
continue
try:
_transform.randomize(data)
Expand Down
10 changes: 5 additions & 5 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from monai.config import IndexSelection
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.transforms.transform import Randomizable, Transform
from monai.transforms.transform import Randomizable, RandomizableTransform, Transform
from monai.transforms.utils import (
generate_pos_neg_label_crop_centers,
generate_spatial_bounding_box,
Expand Down Expand Up @@ -276,7 +276,7 @@ def __call__(self, img: np.ndarray):
return cropper(img)


class RandSpatialCrop(Randomizable, Transform):
class RandSpatialCrop(RandomizableTransform):
"""
Crop image with random size or specific size ROI. It can crop at a random position as center
or at the image center. And allows to set the minimum size to limit the randomly generated ROI.
Expand Down Expand Up @@ -321,7 +321,7 @@ def __call__(self, img: np.ndarray):
return cropper(img)


class RandSpatialCropSamples(Randomizable, Transform):
class RandSpatialCropSamples(RandomizableTransform):
"""
Crop image with random size or specific size ROI to generate a list of N samples.
It can crop at a random position as center or at the image center. And allows to set
Expand Down Expand Up @@ -429,7 +429,7 @@ def __call__(self, img: np.ndarray):
return cropped


class RandWeightedCrop(Randomizable, Transform):
class RandWeightedCrop(RandomizableTransform):
"""
Samples a list of `num_samples` image patches according to the provided `weight_map`.

Expand Down Expand Up @@ -481,7 +481,7 @@ def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) ->
return results


class RandCropByPosNegLabel(Randomizable, Transform):
class RandCropByPosNegLabel(RandomizableTransform):
"""
Crop random fixed sized regions with the center being a foreground or background voxel
based on the Pos Neg Ratio.
Expand Down
22 changes: 13 additions & 9 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
SpatialCrop,
SpatialPad,
)
from monai.transforms.transform import MapTransform, Randomizable
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
from monai.transforms.utils import (
generate_pos_neg_label_crop_centers,
generate_spatial_bounding_box,
Expand Down Expand Up @@ -258,7 +258,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
return d


class RandSpatialCropd(Randomizable, MapTransform):
class RandSpatialCropd(RandomizableTransform, MapTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandSpatialCrop`.
Crop image with random size or specific size ROI. It can crop at a random position as
Expand All @@ -283,7 +283,8 @@ def __init__(
random_center: bool = True,
random_size: bool = True,
) -> None:
super().__init__(keys)
RandomizableTransform.__init__(self)
MapTransform.__init__(self, keys)
self.roi_size = roi_size
self.random_center = random_center
self.random_size = random_size
Expand Down Expand Up @@ -312,7 +313,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
return d


class RandSpatialCropSamplesd(Randomizable, MapTransform):
class RandSpatialCropSamplesd(RandomizableTransform, MapTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandSpatialCropSamples`.
Crop image with random size or specific size ROI to generate a list of N samples.
Expand Down Expand Up @@ -344,7 +345,8 @@ def __init__(
random_center: bool = True,
random_size: bool = True,
) -> None:
super().__init__(keys)
RandomizableTransform.__init__(self)
MapTransform.__init__(self, keys)
if num_samples < 1:
raise ValueError(f"num_samples must be positive, got {num_samples}.")
self.num_samples = num_samples
Expand Down Expand Up @@ -420,7 +422,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
return d


class RandWeightedCropd(Randomizable, MapTransform):
class RandWeightedCropd(RandomizableTransform, MapTransform):
"""
Samples a list of `num_samples` image patches according to the provided `weight_map`.

Expand All @@ -446,7 +448,8 @@ def __init__(
num_samples: int = 1,
center_coord_key: Optional[str] = None,
):
super().__init__(keys)
RandomizableTransform.__init__(self)
MapTransform.__init__(self, keys)
self.spatial_size = ensure_tuple(spatial_size)
self.w_key = w_key
self.num_samples = int(num_samples)
Expand Down Expand Up @@ -484,7 +487,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
return results


class RandCropByPosNegLabeld(Randomizable, MapTransform):
class RandCropByPosNegLabeld(RandomizableTransform, MapTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandCropByPosNegLabel`.
Crop random fixed sized regions with the center being a foreground or background voxel
Expand Down Expand Up @@ -534,7 +537,8 @@ def __init__(
fg_indices_key: Optional[str] = None,
bg_indices_key: Optional[str] = None,
) -> None:
super().__init__(keys)
RandomizableTransform.__init__(self)
MapTransform.__init__(self, keys)
self.label_key = label_key
self.spatial_size: Union[Tuple[int, ...], Sequence[int], int] = spatial_size
if pos < 0 or neg < 0:
Expand Down
Loading