diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py
index 1c3bbf6833..c904631bed 100644
--- a/monai/transforms/croppad/array.py
+++ b/monai/transforms/croppad/array.py
@@ -419,7 +419,7 @@ def __init__(
         if roi_start_torch.numel() == 1:
             self.slices = [slice(int(roi_start_torch.item()), int(roi_end_torch.item()))]
         else:
-            self.slices = [slice(int(s.item()), int(e.item())) for s, e in zip(roi_start_torch, roi_end_torch)]
+            self.slices = [slice(int(s), int(e)) for s, e in zip(roi_start_torch.tolist(), roi_end_torch.tolist())]
 
     def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
         """
@@ -966,7 +966,7 @@ def __call__(
         results: List[NdarrayOrTensor] = []
         if self.centers is not None:
             for center in self.centers:
-                cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)
+                cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size)
                 results.append(cropper(img))
 
         return results
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index 01a62e36ff..26110081bb 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -291,6 +291,9 @@ def map_binary_to_indices(
     else:
         bg_indices = nonzero(~label_flat)
 
+    # no need to keep the indices on GPU; otherwise they still have to be moved to CPU at runtime when cropping by indices
+    fg_indices, *_ = convert_data_type(fg_indices, device=torch.device("cpu"))
+    bg_indices, *_ = convert_data_type(bg_indices, device=torch.device("cpu"))
     return fg_indices, bg_indices
 
 
@@ -389,12 +392,12 @@ def correct_crop_centers(
     centers: List[Union[int, torch.Tensor]],
     spatial_size: Union[Sequence[int], int],
     label_spatial_shape: Sequence[int],
-) -> List[int]:
+):
     """
     Utility to correct the crop center if the crop size is bigger than the image size.
 
     Args:
-        ceters: pre-computed crop centers, will correct based on the valid region.
+        centers: pre-computed crop centers of every dim, will correct based on the valid region.
         spatial_size: spatial size of the ROIs to be sampled.
         label_spatial_shape: spatial shape of the original label data to compare with ROI.
 
@@ -422,9 +425,7 @@ def correct_crop_centers(
             center_i = valid_end[i] - 1
         centers[i] = center_i
 
-    corrected_centers: List[int] = [c.item() if isinstance(c, torch.Tensor) else c for c in centers]  # type: ignore
-
-    return corrected_centers
+    return centers
 
 
 def generate_pos_neg_label_crop_centers(
@@ -476,8 +477,7 @@ def generate_pos_neg_label_crop_centers(
         idx = indices_to_use[random_int]
         center = unravel_index(idx, label_spatial_shape)
         # shift center to range of valid centers
-        center_ori = list(center)
-        centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape))
+        centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape))
 
     return centers
 
diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py
index 3fe2402504..2c370a4707 100644
--- a/monai/transforms/utils_pytorch_numpy_unification.py
+++ b/monai/transforms/utils_pytorch_numpy_unification.py
@@ -178,7 +178,7 @@ def unravel_index(idx, shape):
             coord.append(idx % dim)
             idx = floor_divide(idx, dim)
         return torch.stack(coord[::-1])
-    return np.unravel_index(np.asarray(idx, dtype=int), shape)
+    return np.asarray(np.unravel_index(idx, shape))
 
 
 def unravel_indices(idx, shape):
diff --git a/tests/test_generate_label_classes_crop_centers.py b/tests/test_generate_label_classes_crop_centers.py
index cc068504bf..0e40750276 100644
--- a/tests/test_generate_label_classes_crop_centers.py
+++ b/tests/test_generate_label_classes_crop_centers.py
@@ -61,7 +61,8 @@ def test_type_shape(self, input_data, expected_type, expected_count, expected_sh
             # check for consistency between numpy, torch and torch.cuda
            results.append(result)
             if len(results) > 1:
-                assert_allclose(results[0], results[-1])
+                for x, y in zip(results[0], results[-1]):
+                    assert_allclose(x, y, type_test=False)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_generate_pos_neg_label_crop_centers.py b/tests/test_generate_pos_neg_label_crop_centers.py
index b263f10e55..b8f2840757 100644
--- a/tests/test_generate_pos_neg_label_crop_centers.py
+++ b/tests/test_generate_pos_neg_label_crop_centers.py
@@ -53,7 +53,9 @@ def test_type_shape(self, input_data, expected_type, expected_count, expected_sh
             # check for consistency between numpy, torch and torch.cuda
             results.append(result)
             if len(results) > 1:
-                assert_allclose(results[0], results[-1])
+                # compare every crop center
+                for x, y in zip(results[0], results[-1]):
+                    assert_allclose(x, y, type_test=False)
 
 
 if __name__ == "__main__":
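
For reference, a minimal sketch (not part of the patch) of how the touched helpers compose after this change: `correct_crop_centers` shifts out-of-range centers into the valid region and no longer forces tensor centers back to Python ints, and `SpatialCrop` consumes the corrected centers directly, as the updated `__call__` hunk above does. The shapes and center values below are made up for illustration and assume the helpers behave as the docstrings above describe.

```python
import torch

from monai.transforms import SpatialCrop
from monai.transforms.utils import correct_crop_centers

label_spatial_shape = (32, 32, 32)  # spatial shape of the label volume
spatial_size = (16, 16, 16)         # ROI size to sample

# centers deliberately placed too close to the borders; correct_crop_centers
# shifts each coordinate into the valid range for this ROI size
centers = [1, 30, 16]
corrected = correct_crop_centers(centers, spatial_size, label_spatial_shape)

# the corrected centers feed straight into SpatialCrop, no tuple()/item() round-trip
cropper = SpatialCrop(roi_center=corrected, roi_size=spatial_size)
img = torch.zeros(1, *label_spatial_shape)  # channel-first dummy image
print(cropper(img).shape)  # expected: torch.Size([1, 16, 16, 16])
```

This is also why the two test updates compare the crop centers one by one: depending on the input type, each center may come back as a NumPy array or a torch Tensor, so `assert_allclose(..., type_test=False)` is applied per center rather than to the whole nested result.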