Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def __init__(
if roi_start_torch.numel() == 1:
self.slices = [slice(int(roi_start_torch.item()), int(roi_end_torch.item()))]
else:
self.slices = [slice(int(s.item()), int(e.item())) for s, e in zip(roi_start_torch, roi_end_torch)]
self.slices = [slice(int(s), int(e)) for s, e in zip(roi_start_torch.tolist(), roi_end_torch.tolist())]

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Expand Down Expand Up @@ -966,7 +966,7 @@ def __call__(
results: List[NdarrayOrTensor] = []
if self.centers is not None:
for center in self.centers:
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)
cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size)
results.append(cropper(img))

return results
Expand Down
14 changes: 7 additions & 7 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,9 @@ def map_binary_to_indices(
else:
bg_indices = nonzero(~label_flat)

# no need to save the indices in GPU, otherwise, still need to move to CPU at runtime when crop by indices
fg_indices, *_ = convert_data_type(fg_indices, device=torch.device("cpu"))
bg_indices, *_ = convert_data_type(bg_indices, device=torch.device("cpu"))
return fg_indices, bg_indices


Expand Down Expand Up @@ -389,12 +392,12 @@ def correct_crop_centers(
centers: List[Union[int, torch.Tensor]],
spatial_size: Union[Sequence[int], int],
label_spatial_shape: Sequence[int],
) -> List[int]:
):
"""
Utility to correct the crop center if the crop size is bigger than the image size.

Args:
ceters: pre-computed crop centers, will correct based on the valid region.
centers: pre-computed crop centers of every dim, will correct based on the valid region.
spatial_size: spatial size of the ROIs to be sampled.
label_spatial_shape: spatial shape of the original label data to compare with ROI.

Expand Down Expand Up @@ -422,9 +425,7 @@ def correct_crop_centers(
center_i = valid_end[i] - 1
centers[i] = center_i

corrected_centers: List[int] = [c.item() if isinstance(c, torch.Tensor) else c for c in centers] # type: ignore

return corrected_centers
return centers


def generate_pos_neg_label_crop_centers(
Expand Down Expand Up @@ -476,8 +477,7 @@ def generate_pos_neg_label_crop_centers(
idx = indices_to_use[random_int]
center = unravel_index(idx, label_spatial_shape)
# shift center to range of valid centers
center_ori = list(center)
centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape))
centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape))

return centers

Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def unravel_index(idx, shape):
coord.append(idx % dim)
idx = floor_divide(idx, dim)
return torch.stack(coord[::-1])
return np.unravel_index(np.asarray(idx, dtype=int), shape)
return np.asarray(np.unravel_index(idx, shape))


def unravel_indices(idx, shape):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_generate_label_classes_crop_centers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def test_type_shape(self, input_data, expected_type, expected_count, expected_sh
# check for consistency between numpy, torch and torch.cuda
results.append(result)
if len(results) > 1:
assert_allclose(results[0], results[-1])
for x, y in zip(result[0], result[-1]):
assert_allclose(x, y, type_test=False)


if __name__ == "__main__":
Expand Down
4 changes: 3 additions & 1 deletion tests/test_generate_pos_neg_label_crop_centers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def test_type_shape(self, input_data, expected_type, expected_count, expected_sh
# check for consistency between numpy, torch and torch.cuda
results.append(result)
if len(results) > 1:
assert_allclose(results[0], results[-1])
# compare every crop center
for x, y in zip(results[0], results[-1]):
assert_allclose(x, y, type_test=False)


if __name__ == "__main__":
Expand Down