From 41d3f9f18a600ea2a34a38174d97a39dd295fa27 Mon Sep 17 00:00:00 2001 From: myron Date: Thu, 4 May 2023 13:18:52 -0700 Subject: [PATCH 1/3] map_classes_to_indices output torch.Tensor Signed-off-by: myron --- monai/transforms/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 6dc75141af..02657fa054 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -409,7 +409,8 @@ def map_classes_to_indices( if img_flat is not None: label_flat = img_flat & label_flat # no need to save the indices in GPU, otherwise, still need to move to CPU at runtime when crop by indices - cls_indices: NdarrayOrTensor = convert_data_type(nonzero(label_flat), device=torch.device("cpu"))[0] + output_type = torch.Tensor if isinstance(label, monai.data.MetaTensor) else type(label) # output in tensor or ndarray + cls_indices: NdarrayOrTensor = convert_data_type(nonzero(label_flat), output_type=output_type, device=torch.device("cpu"))[0] if max_samples_per_class and len(cls_indices) > max_samples_per_class and len(cls_indices) > 1: sample_id = np.round(np.linspace(0, len(cls_indices) - 1, max_samples_per_class)).astype(int) indices.append(cls_indices[sample_id]) From a8cb2514284f13218000ce5c1d60ab06253c4699 Mon Sep 17 00:00:00 2001 From: myron Date: Thu, 4 May 2023 17:04:04 -0700 Subject: [PATCH 2/3] blake Signed-off-by: myron --- monai/transforms/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 02657fa054..6f2f372c7a 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -409,8 +409,12 @@ def map_classes_to_indices( if img_flat is not None: label_flat = img_flat & label_flat # no need to save the indices in GPU, otherwise, still need to move to CPU at runtime when crop by indices - output_type = torch.Tensor if isinstance(label, monai.data.MetaTensor) else type(label) # output in tensor or ndarray - cls_indices: NdarrayOrTensor = convert_data_type(nonzero(label_flat), output_type=output_type, device=torch.device("cpu"))[0] + output_type = ( + torch.Tensor if isinstance(label, monai.data.MetaTensor) else type(label) + ) # output in tensor or ndarray + cls_indices: NdarrayOrTensor = convert_data_type( + nonzero(label_flat), output_type=output_type, device=torch.device("cpu") + )[0] if max_samples_per_class and len(cls_indices) > max_samples_per_class and len(cls_indices) > 1: sample_id = np.round(np.linspace(0, len(cls_indices) - 1, max_samples_per_class)).astype(int) indices.append(cls_indices[sample_id]) From 537f510ca297bbc8ef6c10f448d7217b08ffa01d Mon Sep 17 00:00:00 2001 From: myron Date: Thu, 4 May 2023 23:26:06 -0700 Subject: [PATCH 3/3] mypy Signed-off-by: myron --- monai/transforms/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 6f2f372c7a..7f15ff4109 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -409,9 +409,7 @@ def map_classes_to_indices( if img_flat is not None: label_flat = img_flat & label_flat # no need to save the indices in GPU, otherwise, still need to move to CPU at runtime when crop by indices - output_type = ( - torch.Tensor if isinstance(label, monai.data.MetaTensor) else type(label) - ) # output in tensor or ndarray + output_type = torch.Tensor if isinstance(label, monai.data.MetaTensor) else None cls_indices: NdarrayOrTensor = convert_data_type( nonzero(label_flat), output_type=output_type, device=torch.device("cpu") )[0]