diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 6dc75141af..7f15ff4109 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -409,7 +409,10 @@ def map_classes_to_indices( if img_flat is not None: label_flat = img_flat & label_flat # no need to save the indices in GPU, otherwise, still need to move to CPU at runtime when crop by indices - cls_indices: NdarrayOrTensor = convert_data_type(nonzero(label_flat), device=torch.device("cpu"))[0] + output_type = torch.Tensor if isinstance(label, monai.data.MetaTensor) else None + cls_indices: NdarrayOrTensor = convert_data_type( + nonzero(label_flat), output_type=output_type, device=torch.device("cpu") + )[0] if max_samples_per_class and len(cls_indices) > max_samples_per_class and len(cls_indices) > 1: sample_id = np.round(np.linspace(0, len(cls_indices) - 1, max_samples_per_class)).astype(int) indices.append(cls_indices[sample_id])