diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 79e066303e..871b523289 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -36,15 +36,7 @@ from monai.data.meta_tensor import MetaTensor from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing -from monai.transforms import ( - Compose, - Randomizable, - RandomizableTrait, - Transform, - apply_transform, - convert_to_contiguous, - reset_ops_id, -) +from monai.transforms import Compose, Randomizable, RandomizableTrait, Transform, convert_to_contiguous, reset_ops_id from monai.utils import MAX_SEED, convert_to_tensor, get_seed, look_up_option, min_version, optional_import from monai.utils.misc import first @@ -77,15 +69,19 @@ class Dataset(_TorchDataset): }, }, }] """ - def __init__(self, data: Sequence, transform: Callable | None = None) -> None: + def __init__(self, data: Sequence, transform: Sequence[Callable] | Callable | None = None) -> None: """ Args: data: input data to load and transform to generate dataset for model. - transform: a callable data transform on input data. - + transform: a callable, sequence of callables or None. If transform is not + a `Compose` instance, it will be wrapped in a `Compose` instance. Sequences + of callables are applied in order and if `None` is passed, the data is returned as is. """ self.data = data - self.transform: Any = transform + try: + self.transform = Compose(transform) if not isinstance(transform, Compose) else transform + except Exception as e: + raise ValueError("`transform` must be a callable or a list of callables that is Composable") from e def __len__(self) -> int: return len(self.data) @@ -95,7 +91,7 @@ def _transform(self, index: int): Fetch single data item from `self.data`. """ data_i = self.data[index] - return apply_transform(self.transform, data_i) if self.transform is not None else data_i + return self.transform(data_i) def __getitem__(self, index: int | slice | Sequence[int]): """ @@ -264,8 +260,6 @@ def __init__( using the cached content and with re-created transform instances. """ - if not isinstance(transform, Compose): - transform = Compose(transform) super().__init__(data=data, transform=transform) self.cache_dir = Path(cache_dir) if cache_dir is not None else None self.hash_func = hash_func @@ -323,9 +317,6 @@ def _pre_transform(self, item_transformed): random transform object """ - if not isinstance(self.transform, Compose): - raise ValueError("transform must be an instance of monai.transforms.Compose.") - first_random = self.transform.get_index_of_first( lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) ) @@ -346,9 +337,6 @@ def _post_transform(self, item_transformed): the transformed element through the random transforms """ - if not isinstance(self.transform, Compose): - raise ValueError("transform must be an instance of monai.transforms.Compose.") - first_random = self.transform.get_index_of_first( lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) ) @@ -501,9 +489,6 @@ def _pre_transform(self, item_transformed): Returns: the transformed element up to the N transform object """ - if not isinstance(self.transform, Compose): - raise ValueError("transform must be an instance of monai.transforms.Compose.") - item_transformed = self.transform(item_transformed, end=self.cache_n_trans, threading=True) reset_ops_id(item_transformed) @@ -519,9 +504,6 @@ def _post_transform(self, item_transformed): Returns: the final transformed result """ - if not isinstance(self.transform, Compose): - raise ValueError("transform must be an instance of monai.transforms.Compose.") - return self.transform(item_transformed, start=self.cache_n_trans) @@ -809,8 +791,6 @@ def __init__( Not following these recommendations may lead to runtime errors or duplicated cache across processes. """ - if not isinstance(transform, Compose): - transform = Compose(transform) super().__init__(data=data, transform=transform) self.set_num = cache_num # tracking the user-provided `cache_num` option self.set_rate = cache_rate # tracking the user-provided `cache_rate` option @@ -1282,8 +1262,10 @@ def to_list(x): data = [] for dataset in self.data: data.extend(to_list(dataset[index])) + if self.transform is not None: - data = apply_transform(self.transform, data, map_items=False) # transform the list data + self.transform.map_items = False # Compose object map_items to false so transform is applied to list + data = self.transform(data) # use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists return tuple(data) @@ -1432,15 +1414,11 @@ def __len__(self): def _transform(self, index: int): data = {k: v[index] for k, v in self.arrays.items()} - - if not self.transform: - return data - - result = apply_transform(self.transform, data) + result = self.transform(data) if self.transform is not None else data if isinstance(result, dict) or (isinstance(result, list) and isinstance(result[0], dict)): return result - raise AssertionError("With a dict supplied to apply_transform, should return a dict or a list of dicts.") + raise AssertionError("With a dict supplied to Compose, should return a dict or a list of dicts.") class CSVDataset(Dataset): diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py index efc014a267..b61b3c139c 100644 --- a/tests/test_arraydataset.py +++ b/tests/test_arraydataset.py @@ -41,7 +41,7 @@ class TestCompose(Compose): - def __call__(self, input_, lazy): + def __call__(self, input_, lazy=False): img = self.transforms[0](input_) metadata = img.meta img = self.transforms[1](img) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 1398009c63..0d37ae2efd 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -23,7 +23,7 @@ from parameterized import parameterized from monai.data import Dataset -from monai.transforms import Compose, LoadImaged, SimulateDelayd +from monai.transforms import Compose, Lambda, LoadImage, LoadImaged, SimulateDelay, SimulateDelayd from tests.test_compose import TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES, data_from_keys TEST_CASE_1 = [(128, 128, 128)] @@ -99,6 +99,72 @@ def test_dataset_lazy_on_call(self): data[0, 0:2, 0:2] = 1 +class TestTupleDataset(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1]) + def test_shape(self, expected_shape): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) + with tempfile.TemporaryDirectory() as tempdir: + nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) + test_data = [ + (os.path.join(tempdir, "test_image1.nii.gz"), os.path.join(tempdir, "test_label1.nii.gz")), + (os.path.join(tempdir, "test_image2.nii.gz"), os.path.join(tempdir, "test_label2.nii.gz")), + ] + + test_transform = Compose([LoadImage(), SimulateDelay(delay_time=1e-5)]) + + # Here test_transform is applied element by element for the tuple. + dataset = Dataset(data=test_data, transform=test_transform) + data1 = dataset[0] + data2 = dataset[1] + + # Output is a list/tuple + self.assertTrue(isinstance(data1, (list, tuple))) + self.assertTrue(isinstance(data2, (list, tuple))) + + # Number of elements are 2 + self.assertEqual(len(data1), 2) + self.assertEqual(len(data2), 2) + + # Output shapes are as expected + self.assertTupleEqual(data1[0].shape, expected_shape) + self.assertTupleEqual(data1[1].shape, expected_shape) + self.assertTupleEqual(data2[0].shape, expected_shape) + self.assertTupleEqual(data2[1].shape, expected_shape) + + # Here test_transform is applied to the tuple as a whole. + test_transform = Compose( + [ + # LoadImage creates a channel-stacked image when applied to a tuple + LoadImage(), + # Get the channel-stacked image and the label + Lambda(func=lambda x: (x[0].permute(2, 1, 0), x[1])), + ], + map_items=False, + ) + + dataset = Dataset(data=test_data, transform=test_transform) + data1 = dataset[0] + data2 = dataset[1] + + # Output is a list/tuple + self.assertTrue(isinstance(data1, (list, tuple))) + self.assertTrue(isinstance(data2, (list, tuple))) + + # Number of elements are 2 + self.assertEqual(len(data1), 2) + self.assertEqual(len(data2), 2) + + # Output shapes are as expected + self.assertTupleEqual(data1[0].shape, expected_shape) + self.assertTupleEqual(data1[1].shape, expected_shape) + self.assertTupleEqual(data2[0].shape, expected_shape) + self.assertTupleEqual(data2[1].shape, expected_shape) + + class TestDatsesetWithLazy(unittest.TestCase): LOGGER_NAME = "a_logger_name" diff --git a/tests/test_profiling.py b/tests/test_profiling.py index 6bee7ba262..649d980ebf 100644 --- a/tests/test_profiling.py +++ b/tests/test_profiling.py @@ -35,6 +35,7 @@ def setUp(self): self.scale = mt.ScaleIntensity() self.scale_call_name = "ScaleIntensity.__call__" + self.compose_call_name = "Compose.__call__" self.test_comp = mt.Compose([mt.ScaleIntensity(), mt.RandAxisFlip(0.5)]) self.test_image = torch.rand(1, 16, 16, 16) self.pid = os.getpid() @@ -82,7 +83,7 @@ def test_profile_multithread(self): self.assertSequenceEqual(batch.shape, (4, 1, 16, 16, 16)) results = wp.get_results() - self.assertSequenceEqual(list(results), [self.scale_call_name]) + self.assertSequenceEqual(list(results), [self.scale_call_name, self.compose_call_name]) prs = results[self.scale_call_name] @@ -98,6 +99,7 @@ def test_profile_context(self): self.scale(self.test_image) results = wp.get_results() + self.assertSequenceEqual(set(results), {"ScaleIntensity.__call__", "context"}) prs = results["context"]