Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/apps/deepgrow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from monai.config import IndexSelection, KeysCollection
from monai.networks.layers import GaussianFilter
from monai.transforms import SpatialCrop
from monai.transforms.compose import MapTransform, Randomizable, Transform
from monai.transforms.transform import MapTransform, Randomizable, Transform
from monai.transforms.utils import generate_spatial_bounding_box
from monai.utils import min_version, optional_import

Expand Down
3 changes: 2 additions & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs
from .compose import Compose, MapTransform, Randomizable, Transform
from .compose import Compose
from .croppad.array import (
BorderPad,
BoundingRect,
Expand Down Expand Up @@ -234,6 +234,7 @@
ZoomD,
ZoomDict,
)
from .transform import MapTransform, Randomizable, Transform
from .utility.array import (
AddChannel,
AddExtremePointsChannel,
Expand Down
192 changes: 5 additions & 187 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,135 +13,17 @@
"""

import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Hashable, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Optional, Sequence, Union

import numpy as np

from monai.config import KeysCollection
# For backwards compatiblity (so this still works: from monai.transforms.compose import MapTransform)
from monai.transforms.transform import MapTransform # noqa: F401
from monai.transforms.transform import Randomizable, Transform
from monai.transforms.utils import apply_transform
from monai.utils import MAX_SEED, ensure_tuple, get_seed

__all__ = ["Transform", "Randomizable", "Compose", "MapTransform"]


class Transform(ABC):
"""
An abstract class of a ``Transform``.
A transform is callable that processes ``data``.

It could be stateful and may modify ``data`` in place,
the implementation should be aware of:

#. thread safety when mutating its own states.
When used from a multi-process context, transform's instance variables are read-only.
#. ``data`` content unused by this transform may still be used in the
subsequent transforms in a composed transform.
#. storing too much information in ``data`` may not scale.

See Also

:py:class:`monai.transforms.Compose`
"""

@abstractmethod
def __call__(self, data: Any):
"""
``data`` is an element which often comes from an iteration over an
iterable, such as :py:class:`torch.utils.data.Dataset`. This method should
return an updated version of ``data``.
To simplify the input validations, most of the transforms assume that

- ``data`` is a Numpy ndarray, PyTorch Tensor or string
- the data shape can be:

#. string data without shape, `LoadImage` transform expects file paths
#. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``,
except that `AddChannel` expects (spatial_dim_1[, spatial_dim_2, ...]) and
`AsChannelFirst` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels)
#. most of the post-processing transforms expect
``(batch_size, num_channels, spatial_dim_1[, spatial_dim_2, ...])``

- the channel dimension is not omitted even if number of channels is one

This method can optionally take additional arguments to help execute transformation operation.

Raises:
NotImplementedError: When the subclass does not override this method.

"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")


class Randomizable(ABC):
"""
An interface for handling random state locally, currently based on a class variable `R`,
which is an instance of `np.random.RandomState`.
This is mainly for randomized data augmentation transforms. For example::

class RandShiftIntensity(Randomizable):
def randomize():
self._offset = self.R.uniform(low=0, high=100)
def __call__(self, img):
self.randomize()
return img + self._offset

transform = RandShiftIntensity()
transform.set_random_state(seed=0)

"""

R: np.random.RandomState = np.random.RandomState()

def set_random_state(
self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
) -> "Randomizable":
"""
Set the random state locally, to control the randomness, the derived
classes should use :py:attr:`self.R` instead of `np.random` to introduce random
factors.

Args:
seed: set the random state with an integer seed.
state: set the random state with a `np.random.RandomState` object.

Raises:
TypeError: When ``state`` is not an ``Optional[np.random.RandomState]``.

Returns:
a Randomizable instance.

"""
if seed is not None:
_seed = id(seed) if not isinstance(seed, (int, np.integer)) else seed
_seed = _seed % MAX_SEED
self.R = np.random.RandomState(_seed)
return self

if state is not None:
if not isinstance(state, np.random.RandomState):
raise TypeError(f"state must be None or a np.random.RandomState but is {type(state).__name__}.")
self.R = state
return self

self.R = np.random.RandomState()
return self

@abstractmethod
def randomize(self, data: Any) -> None:
"""
Within this method, :py:attr:`self.R` should be used, instead of `np.random`, to introduce random factors.

all :py:attr:`self.R` calls happen here so that we have a better chance to
identify errors of sync the random state.

This method can generate the random factors based on properties of the input data.

Raises:
NotImplementedError: When the subclass does not override this method.

"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
__all__ = ["Compose"]


class Compose(Randomizable, Transform):
Expand Down Expand Up @@ -255,67 +137,3 @@ def __call__(self, input_):
for _transform in self.transforms:
input_ = apply_transform(_transform, input_)
return input_


class MapTransform(Transform):
"""
A subclass of :py:class:`monai.transforms.Transform` with an assumption
that the ``data`` input of ``self.__call__`` is a MutableMapping such as ``dict``.

The ``keys`` parameter will be used to get and set the actual data
item to transform. That is, the callable of this transform should
follow the pattern:

.. code-block:: python

def __call__(self, data):
for key in self.keys:
if key in data:
# update output data with some_transform_function(data[key]).
else:
# do nothing or some exceptions handling.
return data

Raises:
ValueError: When ``keys`` is an empty iterable.
TypeError: When ``keys`` type is not in ``Union[Hashable, Iterable[Hashable]]``.

"""

def __init__(self, keys: KeysCollection) -> None:
self.keys: Tuple[Hashable, ...] = ensure_tuple(keys)
if not self.keys:
raise ValueError("keys must be non empty.")
for key in self.keys:
if not isinstance(key, Hashable):
raise TypeError(f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}.")

@abstractmethod
def __call__(self, data):
"""
``data`` often comes from an iteration over an iterable,
such as :py:class:`torch.utils.data.Dataset`.

To simplify the input validations, this method assumes:

- ``data`` is a Python dictionary
- ``data[key]`` is a Numpy ndarray, PyTorch Tensor or string, where ``key`` is an element
of ``self.keys``, the data shape can be:

#. string data without shape, `LoadImaged` transform expects file paths
#. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``,
except that `AddChanneld` expects (spatial_dim_1[, spatial_dim_2, ...]) and
`AsChannelFirstd` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels)
#. most of the post-processing transforms expect
``(batch_size, num_channels, spatial_dim_1[, spatial_dim_2, ...])``

- the channel dimension is not omitted even if number of channels is one

Raises:
NotImplementedError: When the subclass does not override this method.

returns:
An updated dictionary version of ``data`` by applying the transform.

"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
2 changes: 1 addition & 1 deletion monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from monai.config import IndexSelection
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.transform import Randomizable, Transform
from monai.transforms.utils import (
generate_pos_neg_label_crop_centers,
generate_spatial_bounding_box,
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from monai.config import IndexSelection, KeysCollection
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.transforms.compose import MapTransform, Randomizable
from monai.transforms.croppad.array import (
BorderPad,
BoundingRect,
Expand All @@ -31,6 +30,7 @@
SpatialCrop,
SpatialPad,
)
from monai.transforms.transform import MapTransform, Randomizable
from monai.transforms.utils import (
generate_pos_neg_label_crop_centers,
generate_spatial_bounding_box,
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from monai.config import DtypeLike
from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.transform import Randomizable, Transform
from monai.transforms.utils import rescale_array
from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, dtype_torch_to_numpy, ensure_tuple_size

Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import torch

from monai.config import DtypeLike, KeysCollection
from monai.transforms.compose import MapTransform, Randomizable
from monai.transforms.intensity.array import (
AdjustContrast,
GaussianSharpen,
Expand All @@ -35,6 +34,7 @@
ShiftIntensity,
ThresholdIntensity,
)
from monai.transforms.transform import MapTransform, Randomizable
from monai.utils import dtype_torch_to_numpy, ensure_tuple_rep, ensure_tuple_size

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader
from monai.data.nifti_saver import NiftiSaver
from monai.data.png_saver import PNGSaver
from monai.transforms.compose import Transform
from monai.transforms.transform import Transform
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import ImageMetaKey as Key
from monai.utils import InterpolateMode, ensure_tuple, optional_import
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from monai.config import DtypeLike, KeysCollection
from monai.data.image_reader import ImageReader
from monai.transforms.compose import MapTransform
from monai.transforms.io.array import LoadImage, SaveImage
from monai.transforms.transform import MapTransform
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch.nn.functional as F

from monai.networks import one_hot
from monai.transforms.compose import Transform
from monai.transforms.transform import Transform
from monai.transforms.utils import get_largest_connected_component_mask
from monai.utils import ensure_tuple

Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import torch

from monai.config import KeysCollection
from monai.transforms.compose import MapTransform
from monai.transforms.post.array import (
Activations,
AsDiscrete,
Expand All @@ -30,6 +29,7 @@
MeanEnsemble,
VoteEnsemble,
)
from monai.transforms.transform import MapTransform
from monai.utils import ensure_tuple_rep

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from monai.config import USE_COMPILED, DtypeLike
from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.croppad.array import CenterSpatialCrop
from monai.transforms.transform import Randomizable, Transform
from monai.transforms.utils import (
create_control_grid,
create_grid,
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from monai.config import DtypeLike, KeysCollection
from monai.networks.layers.simplelayers import GaussianFilter
from monai.transforms.compose import MapTransform, Randomizable
from monai.transforms.croppad.array import CenterSpatialCrop
from monai.transforms.spatial.array import (
Flip,
Expand All @@ -36,6 +35,7 @@
Spacing,
Zoom,
)
from monai.transforms.transform import MapTransform, Randomizable
from monai.transforms.utils import create_grid
from monai.utils import (
GridSampleMode,
Expand Down
Loading