From 28686fd6abe4a64046f46e4a36f84634d5557210 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 13 Aug 2021 20:52:41 +0800 Subject: [PATCH] [DLMED] enhance partition_data_classes Signed-off-by: Nic Ma --- monai/data/utils.py | 2 +- tests/test_partition_dataset_classes.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 25b3c24e4a..aab23217dc 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -960,7 +960,7 @@ def partition_dataset_classes( [[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]] """ - if not classes or len(classes) != len(data): + if not issequenceiterable(classes) or len(classes) != len(data): raise ValueError(f"length of classes {classes} must match the dataset length {len(data)}.") datasets = [] class_indices = defaultdict(list) diff --git a/tests/test_partition_dataset_classes.py b/tests/test_partition_dataset_classes.py index 0e28b8f76a..3aef47107a 100644 --- a/tests/test_partition_dataset_classes.py +++ b/tests/test_partition_dataset_classes.py @@ -11,6 +11,7 @@ import unittest +import numpy as np from parameterized import parameterized from monai.data import partition_dataset_classes @@ -59,8 +60,8 @@ TEST_CASE_4 = [ { - "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], - "classes": [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3], + "data": np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]), + "classes": np.array([2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]), "ratios": [1, 2], "num_partitions": None, "shuffle": True,