2 changes: 1 addition & 1 deletion monai/apps/detection/transforms/dictionary.py
@@ -1071,7 +1071,7 @@ def __init__(
if len(self.image_keys) != len(self.meta_keys):
raise ValueError("meta_keys should have the same length as keys.")
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.image_keys))
self.centers: list[list[int]] | None = None
self.centers: tuple[tuple] | None = None
self.allow_smaller = allow_smaller

def generate_fg_center_boxes_np(self, boxes: NdarrayOrTensor, image_size: Sequence[int]) -> np.ndarray:
2 changes: 1 addition & 1 deletion monai/apps/detection/utils/detector_utils.py
@@ -149,7 +149,7 @@ def pad_images(
max_spatial_size = compute_divisible_spatial_size(spatial_shape=list(max_spatial_size_t), k=size_divisible)

# allocate memory for the padded images
images = torch.zeros([len(image_sizes), in_channels] + max_spatial_size, dtype=dtype, device=device)
images = torch.zeros([len(image_sizes), in_channels] + list(max_spatial_size), dtype=dtype, device=device)

# Use `SpatialPad` to match sizes, padding in the end will not affect boxes
padder = SpatialPad(spatial_size=max_spatial_size, method="end", mode=mode, **kwargs)
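As a quick illustration of why the explicit `list(...)` coercion matters, a minimal sketch with hypothetical values, assuming `compute_divisible_spatial_size` can hand back a tuple:

```python
import torch

# hypothetical stand-ins for pad_images' inputs
image_sizes = [(32, 30, 24), (28, 32, 26)]
in_channels = 1
max_spatial_size = (32, 32, 26)  # as a tuple, plain list concatenation would fail

# [len(image_sizes), in_channels] + max_spatial_size  ->  TypeError: can only concatenate list (not "tuple") to list
images = torch.zeros([len(image_sizes), in_channels] + list(max_spatial_size))
print(images.shape)  # torch.Size([2, 1, 32, 32, 26])
```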
75 changes: 40 additions & 35 deletions monai/transforms/croppad/array.py
@@ -106,13 +106,13 @@ class Pad(InvertibleTransform, LazyTransform):
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self, to_pad: list[tuple[int, int]] | None = None, mode: str = PytorchPadMode.CONSTANT, **kwargs
self, to_pad: tuple[tuple[int, int]] | None = None, mode: str = PytorchPadMode.CONSTANT, **kwargs
) -> None:
self.to_pad = to_pad
self.mode = mode
self.kwargs = kwargs

def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]:
def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:
"""
dynamically compute the pad width according to the spatial shape.
the output is the amount of padding for all dimensions including the channel.
@@ -123,8 +123,8 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int
"""
raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.")

def __call__( # type: ignore
self, img: torch.Tensor, to_pad: list[tuple[int, int]] | None = None, mode: str | None = None, **kwargs
def __call__( # type: ignore[override]
self, img: torch.Tensor, to_pad: tuple[tuple[int, int]] | None = None, mode: str | None = None, **kwargs
) -> torch.Tensor:
"""
Args:
@@ -150,7 +150,7 @@ def __call__( # type: ignore
kwargs_.update(kwargs)

img_t = convert_to_tensor(data=img, track_meta=get_track_meta())
return pad_func(img_t, to_pad_, mode_, self.get_transform_info(), kwargs_) # type: ignore
return pad_func(img_t, to_pad_, mode_, self.get_transform_info(), kwargs_)

def inverse(self, data: MetaTensor) -> MetaTensor:
transform = self.pop_transform(data)
@@ -200,7 +200,7 @@ def __init__(
self.method: Method = look_up_option(method, Method)
super().__init__(mode=mode, **kwargs)

def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]:
def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:
"""
dynamically compute the pad width according to the spatial shape.

@@ -213,10 +213,10 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int
pad_width = []
for i, sp_i in enumerate(spatial_size):
width = max(sp_i - spatial_shape[i], 0)
pad_width.append((width // 2, width - (width // 2)))
pad_width.append((int(width // 2), int(width - (width // 2))))
else:
pad_width = [(0, max(sp_i - spatial_shape[i], 0)) for i, sp_i in enumerate(spatial_size)]
return [(0, 0)] + pad_width
pad_width = [(0, int(max(sp_i - spatial_shape[i], 0))) for i, sp_i in enumerate(spatial_size)]
return tuple([(0, 0)] + pad_width) # type: ignore


class BorderPad(Pad):
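A hedged illustration of `SpatialPad.compute_pad_width`, which now returns a tuple of `(before, after)` pairs; the shapes below are hypothetical:

```python
from monai.transforms import SpatialPad

padder = SpatialPad(spatial_size=(8, 8), method="symmetric")
# a (5, 6) spatial shape needs 3 and 2 voxels of padding; the leading (0, 0)
# pair is the untouched channel dimension
print(padder.compute_pad_width((5, 6)))  # ((0, 0), (1, 2), (1, 1))
```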
@@ -249,24 +249,26 @@ def __init__(self, spatial_border: Sequence[int] | int, mode: str = PytorchPadMo
self.spatial_border = spatial_border
super().__init__(mode=mode, **kwargs)

def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]:
def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:
spatial_border = ensure_tuple(self.spatial_border)
if not all(isinstance(b, int) for b in spatial_border):
raise ValueError(f"self.spatial_border must contain only ints, got {spatial_border}.")
spatial_border = tuple(max(0, b) for b in spatial_border)

if len(spatial_border) == 1:
data_pad_width = [(spatial_border[0], spatial_border[0]) for _ in spatial_shape]
data_pad_width = [(int(spatial_border[0]), int(spatial_border[0])) for _ in spatial_shape]
elif len(spatial_border) == len(spatial_shape):
data_pad_width = [(sp, sp) for sp in spatial_border[: len(spatial_shape)]]
data_pad_width = [(int(sp), int(sp)) for sp in spatial_border[: len(spatial_shape)]]
elif len(spatial_border) == len(spatial_shape) * 2:
data_pad_width = [(spatial_border[2 * i], spatial_border[2 * i + 1]) for i in range(len(spatial_shape))]
data_pad_width = [
(int(spatial_border[2 * i]), int(spatial_border[2 * i + 1])) for i in range(len(spatial_shape))
]
else:
raise ValueError(
f"Unsupported spatial_border length: {len(spatial_border)}, available options are "
f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2*len(spatial_shape)}]."
)
return [(0, 0)] + data_pad_width
return tuple([(0, 0)] + data_pad_width) # type: ignore


class DivisiblePad(Pad):
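Similarly for `BorderPad.compute_pad_width`, a small sketch of the three accepted `spatial_border` layouts (hypothetical 2D shape):

```python
from monai.transforms import BorderPad

shape = (10, 12)  # hypothetical 2D spatial shape
print(BorderPad(spatial_border=2).compute_pad_width(shape))             # ((0, 0), (2, 2), (2, 2))
print(BorderPad(spatial_border=(1, 3)).compute_pad_width(shape))        # ((0, 0), (1, 1), (3, 3))
print(BorderPad(spatial_border=(1, 2, 3, 4)).compute_pad_width(shape))  # ((0, 0), (1, 2), (3, 4))
```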
@@ -301,7 +303,7 @@ def __init__(
self.method: Method = Method(method)
super().__init__(mode=mode, **kwargs)

def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]:
def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:
new_size = compute_divisible_spatial_size(spatial_shape=spatial_shape, k=self.k)
spatial_pad = SpatialPad(spatial_size=new_size, method=self.method)
return spatial_pad.compute_pad_width(spatial_shape)
@@ -322,7 +324,7 @@ def compute_slices(
roi_start: Sequence[int] | NdarrayOrTensor | None = None,
roi_end: Sequence[int] | NdarrayOrTensor | None = None,
roi_slices: Sequence[slice] | None = None,
):
) -> tuple[slice]:
"""
Compute the crop slices based on specified `center & size` or `start & end` or `slices`.

@@ -340,8 +342,8 @@ def compute_slices(

if roi_slices:
if not all(s.step is None or s.step == 1 for s in roi_slices):
raise ValueError("only slice steps of 1/None are currently supported")
return list(roi_slices)
raise ValueError(f"only slice steps of 1/None are currently supported, got {roi_slices}.")
return ensure_tuple(roi_slices) # type: ignore
else:
if roi_center is not None and roi_size is not None:
roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True, device="cpu")
@@ -363,11 +365,12 @@ def compute_slices(
roi_end_t = torch.maximum(roi_end_t, roi_start_t)
# convert to slices (accounting for 1d)
if roi_start_t.numel() == 1:
return [slice(int(roi_start_t.item()), int(roi_end_t.item()))]
else:
return [slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())]
return ensure_tuple([slice(int(roi_start_t.item()), int(roi_end_t.item()))]) # type: ignore
return ensure_tuple( # type: ignore
[slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())]
)

def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor: # type: ignore
def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor: # type: ignore[override]
"""
Apply the transform to `img`, assuming `img` is channel-first and
slicing doesn't apply to the channel dim.
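For reference, a minimal sketch of the tuple of slices `Crop.compute_slices` now returns, assuming its usual staticmethod form and hypothetical ROI values:

```python
from monai.transforms.croppad.array import Crop

# center/size ROI: start = center - size // 2, end = start + size
print(Crop.compute_slices(roi_center=(10, 10), roi_size=(4, 6)))
# (slice(8, 12, None), slice(7, 13, None))
```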
@@ -378,10 +381,10 @@ def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor
if len(slices_) < sd:
slices_ += [slice(None)] * (sd - len(slices_))
# Add in the channel (no cropping)
slices = tuple([slice(None)] + slices_[:sd])
slices_ = list([slice(None)] + slices_[:sd])

img_t: MetaTensor = convert_to_tensor(data=img, track_meta=get_track_meta())
return crop_func(img_t, slices, self.get_transform_info()) # type: ignore
return crop_func(img_t, tuple(slices_), self.get_transform_info())

def inverse(self, img: MetaTensor) -> MetaTensor:
transform = self.pop_transform(img)
@@ -429,13 +432,13 @@ def __init__(
roi_center=roi_center, roi_size=roi_size, roi_start=roi_start, roi_end=roi_end, roi_slices=roi_slices
)

def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore
def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""
Apply the transform to `img`, assuming `img` is channel-first and
slicing doesn't apply to the channel dim.

"""
return super().__call__(img=img, slices=self.slices)
return super().__call__(img=img, slices=ensure_tuple(self.slices))


class CenterSpatialCrop(Crop):
@@ -456,12 +459,12 @@ class CenterSpatialCrop(Crop):
def __init__(self, roi_size: Sequence[int] | int) -> None:
self.roi_size = roi_size

def compute_slices(self, spatial_size: Sequence[int]): # type: ignore
def compute_slices(self, spatial_size: Sequence[int]) -> tuple[slice]: # type: ignore[override]
roi_size = fall_back_tuple(self.roi_size, spatial_size)
roi_center = [i // 2 for i in spatial_size]
return super().compute_slices(roi_center=roi_center, roi_size=roi_size)

def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore
def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""
Apply the transform to `img`, assuming `img` is channel-first and
slicing doesn't apply to the channel dim.
@@ -486,7 +489,7 @@ class CenterScaleCrop(Crop):
def __init__(self, roi_scale: Sequence[float] | float):
self.roi_scale = roi_scale

def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore
def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override]
img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
ndim = len(img_size)
roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)]
@@ -771,7 +774,7 @@ def lazy_evaluation(self, _val: bool):
self._lazy_evaluation = _val
self.padder.lazy_evaluation = _val

def compute_bounding_box(self, img: torch.Tensor):
def compute_bounding_box(self, img: torch.Tensor) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the start points and end points of bounding box to crop.
And adjust bounding box coords to be divisible by `k`.
@@ -794,7 +797,7 @@ def compute_bounding_box(self, img: torch.Tensor):

def crop_pad(
self, img: torch.Tensor, box_start: np.ndarray, box_end: np.ndarray, mode: str | None = None, **pad_kwargs
):
) -> torch.Tensor:
"""
Crop and pad based on the bounding box.

@@ -817,7 +820,9 @@ def crop_pad(
ret.applied_operations[-1][TraceKeys.EXTRA_INFO]["pad_info"] = ret.applied_operations.pop()
return ret

def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs): # type: ignore
def __call__( # type: ignore[override]
self, img: torch.Tensor, mode: str | None = None, **pad_kwargs
) -> torch.Tensor:
"""
Apply the transform to `img`, assuming `img` is channel-first and
slicing doesn't change the channel dim.
@@ -826,7 +831,7 @@ def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs): #
cropped = self.crop_pad(img, box_start, box_end, mode, **pad_kwargs)

if self.return_coords:
return cropped, box_start, box_end
return cropped, box_start, box_end # type: ignore[return-value]
return cropped

def inverse(self, img: MetaTensor) -> MetaTensor:
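A minimal usage sketch of the `return_coords` branch annotated above, using a hypothetical toy image:

```python
import torch
from monai.transforms import CropForeground

img = torch.zeros(1, 8, 8)
img[0, 2:6, 3:7] = 1  # foreground block

cropper = CropForeground(select_fn=lambda x: x > 0, return_coords=True)
cropped, box_start, box_end = cropper(img)
print(cropped.shape, box_start, box_end)  # torch.Size([1, 4, 4]) [2 3] [6 7]
```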
@@ -995,7 +1000,7 @@ def __init__(
self.num_samples = num_samples
self.image = image
self.image_threshold = image_threshold
self.centers: list[list[int]] | None = None
self.centers: tuple[tuple] | None = None
self.fg_indices = fg_indices
self.bg_indices = bg_indices
self.allow_smaller = allow_smaller
@@ -1173,7 +1178,7 @@ def __init__(
self.num_samples = num_samples
self.image = image
self.image_threshold = image_threshold
self.centers: list[list[int]] | None = None
self.centers: tuple[tuple] | None = None
self.indices = indices
self.allow_smaller = allow_smaller
self.warn = warn
4 changes: 2 additions & 2 deletions monai/transforms/croppad/dictionary.py
@@ -698,9 +698,9 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
self.cropper: CropForeground
box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key])
if self.start_coord_key is not None:
d[self.start_coord_key] = box_start
d[self.start_coord_key] = box_start # type: ignore
if self.end_coord_key is not None:
d[self.end_coord_key] = box_end
d[self.end_coord_key] = box_end # type: ignore
for key, m in self.key_iterator(d, self.mode):
d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m)
return d
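And the dictionary counterpart, sketched with a hypothetical sample, assuming the default `foreground_start_coord`/`foreground_end_coord` key names:

```python
import torch
from monai.transforms import CropForegroundd

data = {"image": torch.zeros(1, 8, 8)}
data["image"][0, 2:6, 3:7] = 1

out = CropForegroundd(keys="image", source_key="image")(data)
print(out["image"].shape)  # torch.Size([1, 4, 4])
print(out["foreground_start_coord"], out["foreground_end_coord"])  # [2 3] [6 7]
```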
28 changes: 15 additions & 13 deletions monai/transforms/croppad/functional.py
@@ -147,7 +149,7 @@ def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int,
return img


def pad_func(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, transform_info: dict, kwargs):
def pad_func(
img: torch.Tensor, to_pad: tuple[tuple[int, int]], mode: str, transform_info: dict, kwargs
) -> torch.Tensor:
"""
Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according
to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``).
@@ -166,17 +168,17 @@ def pad_func(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, transf
kwargs: other arguments for the `np.pad` or `torch.pad` function.
note that `np.pad` treats channel dimension as the first dimension.
"""
extra_info = {"padded": to_pad, "mode": str(mode)}
extra_info = {"padded": to_pad, "mode": f"{mode}"}
img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
spatial_rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else 3
do_pad = np.asarray(to_pad).any()
if do_pad:
to_pad = list(to_pad)
if len(to_pad) < len(img.shape):
to_pad = list(to_pad) + [(0, 0)] * (len(img.shape) - len(to_pad))
to_shift = [-s[0] for s in to_pad[1:]] # skipping the channel pad
to_pad_list = [(int(p[0]), int(p[1])) for p in to_pad]
if len(to_pad_list) < len(img.shape):
to_pad_list += [(0, 0)] * (len(img.shape) - len(to_pad_list))
to_shift = [-s[0] for s in to_pad_list[1:]] # skipping the channel pad
xform = create_translate(spatial_rank, to_shift)
shape = [d + s + e for d, (s, e) in zip(img_size, to_pad[1:])]
shape = [d + s + e for d, (s, e) in zip(img_size, to_pad_list[1:])]
else:
shape = img_size
xform = torch.eye(int(spatial_rank) + 1, device=torch.device("cpu"), dtype=torch.float64)
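A standalone (non-MONAI) sketch of the `to_pad_list` normalisation introduced above: possible numpy integers are coerced to plain ints and missing dimensions get a no-op pad pair:

```python
import numpy as np

def normalize_to_pad(to_pad, img_ndim):
    # coerce numpy ints to Python ints, then extend with (0, 0) entries so
    # every dimension (channel included) has a pad pair
    to_pad_list = [(int(p[0]), int(p[1])) for p in to_pad]
    if len(to_pad_list) < img_ndim:
        to_pad_list += [(0, 0)] * (img_ndim - len(to_pad_list))
    return to_pad_list

print(normalize_to_pad(((0, 0), (np.int64(1), np.int64(2))), img_ndim=4))
# [(0, 0), (1, 2), (0, 0), (0, 0)]
```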
@@ -191,13 +193,13 @@ def pad_func(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, transf
)
out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta())
if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
out = pad_nd(out, to_pad, mode, **kwargs) if do_pad else out
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore
out = pad_nd(out, to_pad_list, mode, **kwargs) if do_pad else out
out = convert_to_tensor(out, track_meta=get_track_meta())
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore


def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict):
def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict) -> torch.Tensor:
"""
Functional implementation of cropping a MetaTensor. This function operates eagerly or lazily according
to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``).
@@ -229,6 +231,6 @@ def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict
)
out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta())
if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore
out = out[slices]
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore
2 changes: 1 addition & 1 deletion monai/transforms/inverse.py
@@ -190,7 +190,7 @@ def track_transform_meta(
return data
return out_obj # return with data_t as tensor if get_track_meta() is False

info = transform_info
info = transform_info.copy()
# track the current spatial shape
if orig_size is not None:
info[TraceKeys.ORIG_SIZE] = orig_size
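A toy (non-MONAI) sketch of the aliasing that the added `.copy()` guards against: without it, keys written for one call would leak into the shared `transform_info` dict.

```python
def track(info, orig_size, copy_first):
    if copy_first:
        info = info.copy()  # work on a private copy
    info["orig_size"] = orig_size
    return info

shared = {"class": "SpatialPad"}
track(shared, (64, 64), copy_first=False)
print(shared)  # {'class': 'SpatialPad', 'orig_size': (64, 64)} -- caller's dict mutated

shared = {"class": "SpatialPad"}
track(shared, (64, 64), copy_first=True)
print(shared)  # {'class': 'SpatialPad'} -- unchanged
```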