1 change: 1 addition & 0 deletions monai/transforms/__init__.py
@@ -537,6 +537,7 @@
     clip,
     floor_divide,
     in1d,
+    maximum,
     moveaxis,
     nonzero,
     percentile,
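With `maximum` added to this export list, callers get one elementwise-max helper for both array libraries; it is what `SpatialCrop` below relies on. A minimal usage sketch, assuming the helper keeps the input's library (NumPy in, NumPy out; torch in, torch out):

```python
import numpy as np
import torch

from monai.transforms import maximum

print(maximum(np.array([1, 5]), np.array([3, 2])))          # [3 5], stays NumPy
print(maximum(torch.tensor([1, 5]), torch.tensor([3, 2])))  # tensor([3, 5]), stays torch
```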
78 changes: 44 additions & 34 deletions monai/transforms/croppad/array.py
@@ -22,7 +22,7 @@
 from torch.nn.functional import pad as pad_pt
 
 from monai.config import IndexSelection
-from monai.config.type_definitions import NdarrayTensor
+from monai.config.type_definitions import NdarrayOrTensor
 from monai.data.utils import get_random_patch, get_valid_patch_size
 from monai.transforms.transform import Randomizable, Transform
 from monai.transforms.utils import (
@@ -36,6 +36,7 @@
     map_classes_to_indices,
     weighted_patch_samples,
 )
+from monai.transforms.utils_pytorch_numpy_unification import floor_divide, maximum
 from monai.utils import (
     Method,
     NumpyPadMode,
@@ -98,8 +99,7 @@ def __init__(
 
     @staticmethod
     def _np_pad(img: np.ndarray, all_pad_width, mode, **kwargs) -> np.ndarray:
-        img_np, *_ = convert_data_type(img, np.ndarray)
-        return np.pad(img_np, all_pad_width, mode=mode, **kwargs)  # type: ignore
+        return np.pad(img, all_pad_width, mode=mode, **kwargs)  # type: ignore
 
     @staticmethod
     def _pt_pad(img: torch.Tensor, all_pad_width, mode, **kwargs) -> torch.Tensor:
@@ -109,9 +109,9 @@ def _pt_pad(img: torch.Tensor, all_pad_width, mode, **kwargs) -> torch.Tensor:
 
     def __call__(
         self,
-        img: NdarrayTensor,
+        img: NdarrayOrTensor,
         mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None,
-    ) -> NdarrayTensor:
+    ) -> NdarrayOrTensor:
         """
         Args:
             img: data to be transformed, assuming `img` is channel-first and
@@ -132,7 +132,7 @@ def __call__(
             pad = self._pt_pad
         else:
             pad = self._np_pad  # type: ignore
-        return pad(img, self.to_pad, mode, **self.kwargs)
+        return pad(img, self.to_pad, mode, **self.kwargs)  # type: ignore
 
 
 class SpatialPad(Transform):
@@ -190,9 +190,9 @@ def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int, int]]:
 
     def __call__(
         self,
-        img: NdarrayTensor,
+        img: NdarrayOrTensor,
         mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None,
-    ) -> NdarrayTensor:
+    ) -> NdarrayOrTensor:
         """
         Args:
             img: data to be transformed, assuming `img` is channel-first and
@@ -255,9 +255,9 @@ def __init__(
 
     def __call__(
         self,
-        img: NdarrayTensor,
+        img: NdarrayOrTensor,
         mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None,
-    ) -> NdarrayTensor:
+    ) -> NdarrayOrTensor:
         """
         Args:
             img: data to be transformed, assuming `img` is channel-first and
@@ -337,9 +337,9 @@ def __init__(
 
     def __call__(
         self,
-        img: NdarrayTensor,
+        img: NdarrayOrTensor,
         mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None,
-    ) -> NdarrayTensor:
+    ) -> NdarrayOrTensor:
         """
         Args:
             img: data to be transformed, assuming `img` is channel-first
@@ -377,12 +377,14 @@ class SpatialCrop(Transform):
         - the start and end coordinates of the ROI
     """
 
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
     def __init__(
         self,
-        roi_center: Union[Sequence[int], np.ndarray, None] = None,
-        roi_size: Union[Sequence[int], np.ndarray, None] = None,
-        roi_start: Union[Sequence[int], np.ndarray, None] = None,
-        roi_end: Union[Sequence[int], np.ndarray, None] = None,
+        roi_center: Union[Sequence[int], NdarrayOrTensor, None] = None,
+        roi_size: Union[Sequence[int], NdarrayOrTensor, None] = None,
+        roi_start: Union[Sequence[int], NdarrayOrTensor, None] = None,
+        roi_end: Union[Sequence[int], NdarrayOrTensor, None] = None,
         roi_slices: Optional[Sequence[slice]] = None,
     ) -> None:
         """
@@ -395,33 +397,38 @@ def __init__(
             use the end coordinate of image.
             roi_slices: list of slices for each of the spatial dimensions.
         """
+        roi_start_torch: torch.Tensor
+
         if roi_slices:
             if not all(s.step is None or s.step == 1 for s in roi_slices):
                 raise ValueError("Only slice steps of 1/None are currently supported")
             self.slices = list(roi_slices)
         else:
             if roi_center is not None and roi_size is not None:
-                roi_center = np.asarray(roi_center, dtype=np.int16)
-                roi_size = np.asarray(roi_size, dtype=np.int16)
-                roi_start_np = np.maximum(roi_center - np.floor_divide(roi_size, 2), 0)
-                roi_end_np = np.maximum(roi_start_np + roi_size, roi_start_np)
+                roi_center = torch.as_tensor(roi_center, dtype=torch.int16)
+                roi_size = torch.as_tensor(roi_size, dtype=torch.int16, device=roi_center.device)
+                roi_start_torch = maximum(  # type: ignore
+                    roi_center - floor_divide(roi_size, 2),
+                    torch.zeros_like(roi_center),
+                )
+                roi_end_torch = maximum(roi_start_torch + roi_size, roi_start_torch)
             else:
                 if roi_start is None or roi_end is None:
                     raise ValueError("Please specify either roi_center, roi_size or roi_start, roi_end.")
-                roi_start_np = np.maximum(np.asarray(roi_start, dtype=np.int16), 0)
-                roi_end_np = np.maximum(np.asarray(roi_end, dtype=np.int16), roi_start_np)
-            # Allow for 1D by converting back to np.array (since np.maximum will convert to int)
-            roi_start_np = roi_start_np if isinstance(roi_start_np, np.ndarray) else np.array([roi_start_np])
-            roi_end_np = roi_end_np if isinstance(roi_end_np, np.ndarray) else np.array([roi_end_np])
-            # convert to slices
-            self.slices = [slice(s, e) for s, e in zip(roi_start_np, roi_end_np)]
+                roi_start_torch = torch.as_tensor(roi_start, dtype=torch.int16)
+                roi_start_torch = maximum(roi_start_torch, torch.zeros_like(roi_start_torch))  # type: ignore
+                roi_end_torch = maximum(torch.as_tensor(roi_end, dtype=torch.int16), roi_start_torch)
+            # convert to slices (accounting for 1d)
+            if roi_start_torch.numel() == 1:
+                self.slices = [slice(int(roi_start_torch.item()), int(roi_end_torch.item()))]
+            else:
+                self.slices = [slice(int(s.item()), int(e.item())) for s, e in zip(roi_start_torch, roi_end_torch)]
 
-    def __call__(self, img: Union[np.ndarray, torch.Tensor]):
+    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
         """
         Apply the transform to `img`, assuming `img` is channel-first and
         slicing doesn't apply to the channel dim.
         """
-        img, *_ = convert_data_type(img, np.ndarray)
         sd = min(len(self.slices), len(img.shape[1:]))  # spatial dims
         slices = [slice(None)] + self.slices[:sd]
         return img[tuple(slices)]
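For each spatial dimension, the centre/size branch above computes `start = max(center - size // 2, 0)` and `end = max(start + size, start)`, then turns the pair into a `slice`. A standalone sketch of that arithmetic in plain PyTorch (`roi_to_slices` is a hypothetical helper, not part of this PR):

```python
import torch

def roi_to_slices(roi_center, roi_size):
    # Clamp the start at 0 and keep end >= start so that
    # non-positive sizes yield empty (not negative) slices.
    center = torch.as_tensor(roi_center, dtype=torch.int16)
    size = torch.as_tensor(roi_size, dtype=torch.int16)
    start = torch.maximum(center - torch.div(size, 2, rounding_mode="floor"), torch.zeros_like(center))
    end = torch.maximum(start + size, start)
    return [slice(int(s), int(e)) for s, e in zip(start, end)]

print(roi_to_slices([2, 2], [3, 3]))  # [slice(1, 4, None), slice(1, 4, None)]
```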
@@ -822,7 +829,8 @@ def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> List[np.ndarray]:
         results = []
         for center in self.centers:
             cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size)
-            results.append(cropper(img))
+            cropped: np.ndarray = cropper(img)  # type: ignore
+            results.append(cropped)
         return results
 
 
@@ -962,7 +970,8 @@ def __call__(
         if self.centers is not None:
             for center in self.centers:
                 cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
-                results.append(cropper(img))
+                cropped: np.ndarray = cropper(img)  # type: ignore
+                results.append(cropped)
 
         return results
 
@@ -1098,7 +1107,8 @@ def __call__(
         if self.centers is not None:
             for center in self.centers:
                 cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
-                results.append(cropper(img))
+                cropped: np.ndarray = cropper(img)  # type: ignore
+                results.append(cropped)
 
         return results
 
@@ -1146,7 +1156,7 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray:
                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
         """
         img, *_ = convert_data_type(img, np.ndarray)  # type: ignore
-        return self.padder(self.cropper(img), mode=mode)
+        return self.padder(self.cropper(img), mode=mode)  # type: ignore
 
 
 class BoundingRect(Transform):
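Net effect for `array.py`: `SpatialCrop.__call__` no longer routes through `convert_data_type`, so a crop now stays in the caller's array library (the pad transforms gain the same property via the `_pt_pad`/`_np_pad` dispatch). A quick sketch of the expected behaviour under that assumption:

```python
import numpy as np
import torch

from monai.transforms import SpatialCrop

crop = SpatialCrop(roi_center=(32, 32, 16), roi_size=(24, 24, 8))

img_np = np.zeros((1, 64, 64, 32), dtype=np.float32)  # channel-first: C, H, W, D
img_pt = torch.zeros((1, 64, 64, 32))

print(type(crop(img_np)), crop(img_np).shape)  # <class 'numpy.ndarray'> (1, 24, 24, 8)
print(type(crop(img_pt)), crop(img_pt).shape)  # <class 'torch.Tensor'> torch.Size([1, 24, 24, 8])
```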
47 changes: 26 additions & 21 deletions monai/transforms/croppad/dictionary.py
@@ -25,7 +25,7 @@
 import numpy as np
 
 from monai.config import IndexSelection, KeysCollection
-from monai.config.type_definitions import NdarrayTensor
+from monai.config.type_definitions import NdarrayOrTensor
 from monai.data.utils import get_random_patch, get_valid_patch_size
 from monai.transforms.croppad.array import (
     BorderPad,
@@ -147,14 +147,14 @@ def __init__(
         self.mode = ensure_tuple_rep(mode, len(self.keys))
         self.padder = SpatialPad(spatial_size, method, **kwargs)
 
-    def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]:
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
         for key, m in self.key_iterator(d, self.mode):
             self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m})
             d[key] = self.padder(d[key], mode=m)
         return d
 
-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = deepcopy(dict(data))
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
@@ -222,14 +222,14 @@ def __init__(
         self.mode = ensure_tuple_rep(mode, len(self.keys))
         self.padder = BorderPad(spatial_border=spatial_border, **kwargs)
 
-    def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]:
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
         for key, m in self.key_iterator(d, self.mode):
             self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m})
             d[key] = self.padder(d[key], mode=m)
         return d
 
-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = deepcopy(dict(data))
 
         for key in self.key_iterator(d):
@@ -298,14 +298,14 @@ def __init__(
         self.mode = ensure_tuple_rep(mode, len(self.keys))
         self.padder = DivisiblePad(k=k, method=method, **kwargs)
 
-    def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]:
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
         for key, m in self.key_iterator(d, self.mode):
             self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m})
             d[key] = self.padder(d[key], mode=m)
         return d
 
-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = deepcopy(dict(data))
 
         for key in self.key_iterator(d):
@@ -339,6 +339,8 @@ class SpatialCropd(MapTransform, InvertibleTransform):
         - the start and end coordinates of the ROI
     """
 
+    backend = SpatialCrop.backend
+
     def __init__(
         self,
         keys: KeysCollection,
@@ -365,14 +367,14 @@ def __init__(
         super().__init__(keys, allow_missing_keys)
         self.cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices)
 
-    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
         for key in self.key_iterator(d):
             self.push_transform(d, key)
             d[key] = self.cropper(d[key])
         return d
 
-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = deepcopy(dict(data))
 
         for key in self.key_iterator(d):
@@ -426,7 +428,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
             self.push_transform(d, key, orig_size=orig_size)
         return d
 
-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = deepcopy(dict(data))
 
         for key in self.key_iterator(d):
@@ -481,7 +483,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
 
         return d
 
-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = deepcopy(dict(data))
 
         for key in self.key_iterator(d):
@@ -576,7 +578,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
             d[key] = cropper(d[key])
         return d
 
-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = deepcopy(dict(data))
 
         for key in self.key_iterator(d):
@@ -772,7 +774,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
             ret.append(cropped)
         return ret
 
-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
         d = deepcopy(dict(data))
         # We changed the transform name from RandSpatialCropd to RandSpatialCropSamplesd
         # Need to revert that since we're calling RandSpatialCropd's inverse
@@ -859,7 +861,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
             d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m)
         return d
 
-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = deepcopy(dict(data))
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
@@ -964,7 +966,8 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
             for i, center in enumerate(self.centers):
                 cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size)
                 orig_size = img.shape[1:]
-                results[i][key] = cropper(img)
+                cropped: np.ndarray = cropper(img)  # type: ignore
+                results[i][key] = cropped
                 self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size)
             if self.center_coord_key:
                 results[i][self.center_coord_key] = center
@@ -979,7 +982,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
 
         return results
 
-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = deepcopy(dict(data))
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
@@ -1136,7 +1139,8 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
                 img = d[key]
                 cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
                 orig_size = img.shape[1:]
-                results[i][key] = cropper(img)
+                cropped: np.ndarray = cropper(img)  # type: ignore
+                results[i][key] = cropped
                 self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size)
             # add `patch_index` to the meta data
             for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
@@ -1147,7 +1151,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
 
         return results
 
-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = deepcopy(dict(data))
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
@@ -1315,7 +1319,8 @@ def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarray]]:
                 img = d[key]
                 cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
                 orig_size = img.shape[1:]
-                results[i][key] = cropper(img)
+                cropped: np.ndarray = cropper(img)  # type: ignore
+                results[i][key] = cropped
                 self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size)
             # add `patch_index` to the meta data
             for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
@@ -1326,7 +1331,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
 
         return results
 
-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = deepcopy(dict(data))
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
@@ -1399,7 +1404,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
             )
         return d
 
-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = deepcopy(dict(data))
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
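Since the dictionary wrappers delegate to the array transforms, the same type-preserving behaviour should carry over to `SpatialCropd` and friends. A short sketch (the `image`/`label` keys are illustrative):

```python
import torch

from monai.transforms import SpatialCropd

cropd = SpatialCropd(keys=["image", "label"], roi_start=(10, 10), roi_end=(50, 50))

data = {
    "image": torch.rand(1, 64, 64),
    "label": torch.randint(0, 2, (1, 64, 64)),
}
out = cropd(data)
print(out["image"].shape, out["label"].shape)  # torch.Size([1, 40, 40]) for both
```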