Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
beab17d
backends -> backend
rijobro Aug 25, 2021
c3192ba
code format
rijobro Aug 25, 2021
f790ad7
code format2
rijobro Aug 25, 2021
1d28704
AddChannel, AsChannelFirst, AsChannelLast, EnsureChannelFirst, Identi…
rijobro Aug 25, 2021
fe0c787
Merge remote-tracking branch 'MONAI/dev' into utility_transforms
rijobro Aug 25, 2021
be7eac2
moveaxis backwards compatible
rijobro Aug 25, 2021
0d27527
code format
rijobro Aug 25, 2021
69ff653
Merge branch 'dev' into utility_transforms
wyli Aug 25, 2021
95622bc
Merge branch 'dev' into utility_transforms
wyli Aug 26, 2021
c9dcd8d
EnsureType, RemoveRepeatedChannel, SplitChannel, ToCupy, ToNumpy, ToP…
rijobro Aug 26, 2021
754a684
trigger ci/cd
rijobro Aug 26, 2021
e38e7a3
permute requires positive indices
rijobro Aug 26, 2021
0a1b4a5
Merge branch 'utility_transforms' into utility_transforms2
rijobro Aug 26, 2021
24c136d
Merge branch 'dev' into utility_transforms
wyli Aug 26, 2021
2971cdb
Merge branch 'dev' into utility_transforms
rijobro Aug 27, 2021
af2d2ec
correct permute
rijobro Aug 27, 2021
b1e476d
correct permute
rijobro Aug 27, 2021
dce485e
Merge branch 'utility_transforms' into utility_transforms2
rijobro Aug 27, 2021
baecdf8
Merge branch 'dev' into utility_transforms
rijobro Aug 27, 2021
4c05342
Merge remote-tracking branch 'rijobro/utility_transforms' into utilit…
rijobro Aug 27, 2021
3a9e170
has_pil
rijobro Aug 27, 2021
6363721
Merge branch 'dev' into utility_transforms2
rijobro Aug 27, 2021
e17e778
Merge remote-tracking branch 'MONAI/dev' into utility_transforms2
rijobro Aug 27, 2021
6694800
fixes flake8
wyli Aug 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 32 additions & 35 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,7 @@
map_classes_to_indices,
)
from monai.transforms.utils_pytorch_numpy_unification import moveaxis
from monai.utils import (
convert_to_numpy,
convert_to_tensor,
ensure_tuple,
issequenceiterable,
look_up_option,
min_version,
optional_import,
)
from monai.utils import convert_to_numpy, convert_to_tensor, ensure_tuple, look_up_option, min_version, optional_import
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import convert_data_type

Expand Down Expand Up @@ -255,20 +247,22 @@ class RemoveRepeatedChannel(Transform):
repeats: the number of repetitions to be deleted for each element.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, repeats: int) -> None:
if repeats <= 0:
raise AssertionError("repeats count must be greater than 0.")

self.repeats = repeats

def __call__(self, img: np.ndarray) -> np.ndarray:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`, assuming `img` is a "channel-first" array.
"""
if np.shape(img)[0] < 2:
if img.shape[0] < 2:
raise AssertionError("Image must have more than one channel")

return np.array(img[:: self.repeats, :])
return img[:: self.repeats, :]


class SplitChannel(Transform):
Expand All @@ -281,10 +275,12 @@ class SplitChannel(Transform):

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, channel_dim: int = 0) -> None:
self.channel_dim = channel_dim

def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> List[Union[np.ndarray, torch.Tensor]]:
def __call__(self, img: NdarrayOrTensor) -> List[NdarrayOrTensor]:
n_classes = img.shape[self.channel_dim]
if n_classes <= 1:
raise RuntimeError("input image does not contain multiple channels.")
Expand Down Expand Up @@ -335,18 +331,13 @@ class ToTensor(Transform):
Converts the input image to a tensor without applying any other transformations.
"""

def __call__(self, img) -> torch.Tensor:
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, img: NdarrayOrTensor) -> torch.Tensor:
"""
Apply the transform to `img` and make it contiguous.
"""
if isinstance(img, torch.Tensor):
return img.contiguous()
if issequenceiterable(img):
# numpy array with 0 dims is also sequence iterable
if not (isinstance(img, np.ndarray) and img.ndim == 0):
# `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims
img = np.ascontiguousarray(img)
return torch.as_tensor(img)
return convert_to_tensor(img, wrap_sequence=True) # type: ignore


class EnsureType(Transform):
Expand All @@ -361,14 +352,16 @@ class EnsureType(Transform):

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, data_type: str = "tensor") -> None:
data_type = data_type.lower()
if data_type not in ("tensor", "numpy"):
raise ValueError("`data type` must be 'tensor' or 'numpy'.")

self.data_type = data_type

def __call__(self, data):
def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
Expand All @@ -377,46 +370,46 @@ def __call__(self, data):
if applicable.

"""
return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data)
return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data) # type: ignore


class ToNumpy(Transform):
"""
Converts the input data to numpy array, can support list or tuple of numbers and PyTorch Tensor.
"""

def __call__(self, img) -> np.ndarray:
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, img: NdarrayOrTensor) -> np.ndarray:
"""
Apply the transform to `img` and make it contiguous.
"""
if isinstance(img, torch.Tensor):
img = img.detach().cpu().numpy()
elif has_cp and isinstance(img, cp_ndarray):
img = cp.asnumpy(img)

array: np.ndarray = np.asarray(img)
return np.ascontiguousarray(array) if array.ndim > 0 else array
return convert_to_numpy(img) # type: ignore


class ToCupy(Transform):
"""
Converts the input data to CuPy array, can support list or tuple of numbers, NumPy and PyTorch Tensor.
"""

def __call__(self, img):
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img` and make it contiguous.
"""
if isinstance(img, torch.Tensor):
img = img.detach().cpu().numpy()
return cp.ascontiguousarray(cp.asarray(img))
return cp.ascontiguousarray(cp.asarray(img)) # type: ignore


class ToPIL(Transform):
"""
Converts the input image (in the form of NumPy array or PyTorch Tensor) to PIL image
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, img):
"""
Apply the transform to `img`.
Expand All @@ -433,13 +426,17 @@ class Transpose(Transform):
Transposes the input image based on the given `indices` dimension ordering.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, indices: Optional[Sequence[int]]) -> None:
self.indices = None if indices is None else tuple(indices)

def __call__(self, img: np.ndarray) -> np.ndarray:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
if isinstance(img, torch.Tensor):
return img.permute(self.indices or tuple(range(img.ndim)[::-1]))
return img.transpose(self.indices) # type: ignore


Expand Down
33 changes: 23 additions & 10 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@ class RemoveRepeatedChanneld(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.RemoveRepeatedChannel`.
"""

backend = RemoveRepeatedChannel.backend

def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool = False) -> None:
"""
Args:
Expand All @@ -345,7 +347,7 @@ def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool
super().__init__(keys, allow_missing_keys)
self.repeater = RemoveRepeatedChannel(repeats)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.repeater(d[key])
Expand All @@ -356,9 +358,10 @@ class SplitChanneld(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.SplitChannel`.
All the input specified by `keys` should be split into same count of data.

"""

backend = SplitChannel.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -382,9 +385,7 @@ def __init__(
self.output_postfixes = output_postfixes
self.splitter = SplitChannel(channel_dim=channel_dim)

def __call__(
self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]
) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
rets = self.splitter(d[key])
Expand Down Expand Up @@ -439,6 +440,8 @@ class ToTensord(MapTransform, InvertibleTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.ToTensor`.
"""

backend = ToTensor.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
"""
Args:
Expand All @@ -449,14 +452,14 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No
super().__init__(keys, allow_missing_keys)
self.converter = ToTensor()

def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
self.push_transform(d, key)
d[key] = self.converter(d[key])
return d

def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
# Create inverse transform
Expand All @@ -481,6 +484,8 @@ class EnsureTyped(MapTransform, InvertibleTransform):

"""

backend = EnsureType.backend

def __init__(self, keys: KeysCollection, data_type: str = "tensor", allow_missing_keys: bool = False) -> None:
"""
Args:
Expand All @@ -492,7 +497,7 @@ def __init__(self, keys: KeysCollection, data_type: str = "tensor", allow_missin
super().__init__(keys, allow_missing_keys)
self.converter = EnsureType(data_type=data_type)

def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
self.push_transform(d, key)
Expand All @@ -515,6 +520,8 @@ class ToNumpyd(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.ToNumpy`.
"""

backend = ToNumpy.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
"""
Args:
Expand All @@ -537,6 +544,8 @@ class ToCupyd(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.ToCupy`.
"""

backend = ToCupy.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
"""
Args:
Expand All @@ -547,7 +556,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No
super().__init__(keys, allow_missing_keys)
self.converter = ToCupy()

def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
Expand All @@ -559,6 +568,8 @@ class ToPILd(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.ToNumpy`.
"""

backend = ToPIL.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
"""
Args:
Expand All @@ -581,13 +592,15 @@ class Transposed(MapTransform, InvertibleTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.Transpose`.
"""

backend = Transpose.backend

def __init__(
self, keys: KeysCollection, indices: Optional[Sequence[int]], allow_missing_keys: bool = False
) -> None:
super().__init__(keys, allow_missing_keys)
self.transform = Transpose(indices)

def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.transform(d[key])
Expand Down
21 changes: 14 additions & 7 deletions monai/utils/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_dtype(data: Any):
return type(data)


def convert_to_tensor(data):
def convert_to_tensor(data, wrap_sequence: bool = False):
"""
Utility to convert the input data to a PyTorch Tensor. If passing a dictionary, list or tuple,
recursively check every item and convert it to PyTorch Tensor.
Expand All @@ -92,6 +92,8 @@ def convert_to_tensor(data):
data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
will convert Tensor, Numpy array, float, int, bool to Tensors, strings and objects keep the original.
for dictionary, list or tuple, convert every item to a Tensor if applicable.
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`.
If `True`, then `[1, 2]` -> `tensor([1, 2])`.

"""
if isinstance(data, torch.Tensor):
Expand All @@ -105,17 +107,19 @@ def convert_to_tensor(data):
return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data))
elif isinstance(data, (float, int, bool)):
return torch.as_tensor(data)
elif isinstance(data, dict):
return {k: convert_to_tensor(v) for k, v in data.items()}
elif isinstance(data, Sequence) and wrap_sequence:
return torch.as_tensor(data)
elif isinstance(data, list):
return [convert_to_tensor(i) for i in data]
elif isinstance(data, tuple):
return tuple(convert_to_tensor(i) for i in data)
elif isinstance(data, dict):
return {k: convert_to_tensor(v) for k, v in data.items()}

return data


def convert_to_numpy(data):
def convert_to_numpy(data, wrap_sequence: bool = False):
"""
Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple,
recursively check every item and convert it to numpy array.
Expand All @@ -124,20 +128,23 @@ def convert_to_numpy(data):
data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
will convert Tensor, Numpy array, float, int, bool to numpy arrays, strings and objects keep the original.
for dictionary, list or tuple, convert every item to a numpy array if applicable.

wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.
If `True`, then `[1, 2]` -> `array([1, 2])`.
"""
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()
elif has_cp and isinstance(data, cp_ndarray):
data = cp.asnumpy(data)
elif isinstance(data, (float, int, bool)):
data = np.asarray(data)
elif isinstance(data, dict):
return {k: convert_to_numpy(v) for k, v in data.items()}
elif isinstance(data, Sequence) and wrap_sequence:
return np.asarray(data)
elif isinstance(data, list):
return [convert_to_numpy(i) for i in data]
elif isinstance(data, tuple):
return tuple(convert_to_numpy(i) for i in data)
elif isinstance(data, dict):
return {k: convert_to_numpy(v) for k, v in data.items()}

if isinstance(data, np.ndarray) and data.ndim > 0:
data = np.ascontiguousarray(data)
Expand Down
Loading