From a4b9cb4d6dbbe923ba3fb88ffea148da36c3fef5 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Mon, 20 May 2024 22:11:44 +0000 Subject: [PATCH 1/9] Support kwargs on Dataset for apply_transform fn Signed-off-by: Suraj Pai --- monai/data/dataset.py | 10 ++++++++-- tests/test_dataset.py | 16 +++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 79e066303e..72201958c4 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -77,15 +77,17 @@ class Dataset(_TorchDataset): }, }, }] """ - def __init__(self, data: Sequence, transform: Callable | None = None) -> None: + def __init__(self, data: Sequence, transform: Callable | None = None, **kwargs) -> None: """ Args: data: input data to load and transform to generate dataset for model. transform: a callable data transform on input data. + kwargs: other arguments for the `apply_transform` function called to apply provided transforms. For ex. `map_items=False` """ self.data = data self.transform: Any = transform + self.apply_transform_kwargs = kwargs def __len__(self) -> int: return len(self.data) @@ -95,7 +97,11 @@ def _transform(self, index: int): Fetch single data item from `self.data`. """ data_i = self.data[index] - return apply_transform(self.transform, data_i) if self.transform is not None else data_i + return ( + apply_transform(self.transform, data_i, **self.apply_transform_kwargs) + if self.transform is not None + else data_i + ) def __getitem__(self, index: int | slice | Sequence[int]): """ diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 1398009c63..4b40a17c67 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -23,7 +23,7 @@ from parameterized import parameterized from monai.data import Dataset -from monai.transforms import Compose, LoadImaged, SimulateDelayd +from monai.transforms import Compose, Lambda, LoadImaged, SimulateDelayd, ToTensor from tests.test_compose import TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES, data_from_keys TEST_CASE_1 = [(128, 128, 128)] @@ -98,6 +98,20 @@ def test_dataset_lazy_on_call(self): data = np.zeros((1, 5, 5)) data[0, 0:2, 0:2] = 1 + def test_dataset_transform_kwargs(self): + test_array = np.random.randint(0, 2, size=[128, 128, 128]).astype(float) + test_data_tuple = (test_array, 4) # Simulate input, target tuple common in torchvision datasets + + # Convert both elements in test_data_tuple to tensors + test_transform = ToTensor() + dataset = Dataset(data=[test_data_tuple], transform=test_transform) + self.assertTrue(isinstance(dataset[0], (tuple, list))) + + # Transpose the first element in test_data_tuple and keep the second element as is + test_transform = Lambda(func=lambda x: (x[0].transpose(2, 1, 0), x[1])) + dataset = Dataset(data=[test_data_tuple], transform=test_transform, map_items=False) + self.assertTrue(isinstance(dataset[0], (tuple, list))) + class TestDatsesetWithLazy(unittest.TestCase): LOGGER_NAME = "a_logger_name" From 99a564b2509b140008495dd14b23c80797df53de Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Tue, 21 May 2024 02:36:52 +0000 Subject: [PATCH 2/9] Fix flake8 line length error Signed-off-by: Suraj Pai --- monai/data/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 72201958c4..8d98a8ca70 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -82,7 +82,8 @@ def __init__(self, data: Sequence, transform: Callable | None = None, **kwargs) Args: data: input data to load and transform to generate dataset for model. transform: a callable data transform on input data. - kwargs: other arguments for the `apply_transform` function called to apply provided transforms. For ex. `map_items=False` + kwargs: other arguments for `apply_transform` fn called to apply provided transforms. + For ex. `map_items=False` """ self.data = data From fb5edc88b60c71076038e79a2451faa4a8f6f219 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Thu, 23 May 2024 01:00:20 +0000 Subject: [PATCH 3/9] Force dataset transform to Compose Signed-off-by: Suraj Pai --- monai/data/dataset.py | 15 ++------- tests/test_dataset.py | 76 ++++++++++++++++++++++++++++++++++++------- 2 files changed, 67 insertions(+), 24 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 8d98a8ca70..2b2d1e68eb 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -77,7 +77,7 @@ class Dataset(_TorchDataset): }, }, }] """ - def __init__(self, data: Sequence, transform: Callable | None = None, **kwargs) -> None: + def __init__(self, data: Sequence, transform: Callable | None = None) -> None: """ Args: data: input data to load and transform to generate dataset for model. @@ -87,8 +87,7 @@ def __init__(self, data: Sequence, transform: Callable | None = None, **kwargs) """ self.data = data - self.transform: Any = transform - self.apply_transform_kwargs = kwargs + self.transform = Compose(transform) if not isinstance(transform, Compose) else transform def __len__(self) -> int: return len(self.data) @@ -98,11 +97,7 @@ def _transform(self, index: int): Fetch single data item from `self.data`. """ data_i = self.data[index] - return ( - apply_transform(self.transform, data_i, **self.apply_transform_kwargs) - if self.transform is not None - else data_i - ) + return self.transform(data_i) if self.transform is not None else data_i def __getitem__(self, index: int | slice | Sequence[int]): """ @@ -271,8 +266,6 @@ def __init__( using the cached content and with re-created transform instances. """ - if not isinstance(transform, Compose): - transform = Compose(transform) super().__init__(data=data, transform=transform) self.cache_dir = Path(cache_dir) if cache_dir is not None else None self.hash_func = hash_func @@ -816,8 +809,6 @@ def __init__( Not following these recommendations may lead to runtime errors or duplicated cache across processes. """ - if not isinstance(transform, Compose): - transform = Compose(transform) super().__init__(data=data, transform=transform) self.set_num = cache_num # tracking the user-provided `cache_num` option self.set_rate = cache_rate # tracking the user-provided `cache_rate` option diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 4b40a17c67..b7349e4200 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -23,7 +23,7 @@ from parameterized import parameterized from monai.data import Dataset -from monai.transforms import Compose, Lambda, LoadImaged, SimulateDelayd, ToTensor +from monai.transforms import Compose, Lambda, LoadImage, LoadImaged, SimulateDelay, SimulateDelayd, ToTensor from tests.test_compose import TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES, data_from_keys TEST_CASE_1 = [(128, 128, 128)] @@ -98,19 +98,71 @@ def test_dataset_lazy_on_call(self): data = np.zeros((1, 5, 5)) data[0, 0:2, 0:2] = 1 - def test_dataset_transform_kwargs(self): - test_array = np.random.randint(0, 2, size=[128, 128, 128]).astype(float) - test_data_tuple = (test_array, 4) # Simulate input, target tuple common in torchvision datasets - # Convert both elements in test_data_tuple to tensors - test_transform = ToTensor() - dataset = Dataset(data=[test_data_tuple], transform=test_transform) - self.assertTrue(isinstance(dataset[0], (tuple, list))) +class TestTupleDataset(unittest.TestCase): - # Transpose the first element in test_data_tuple and keep the second element as is - test_transform = Lambda(func=lambda x: (x[0].transpose(2, 1, 0), x[1])) - dataset = Dataset(data=[test_data_tuple], transform=test_transform, map_items=False) - self.assertTrue(isinstance(dataset[0], (tuple, list))) + @parameterized.expand([TEST_CASE_1]) + def test_shape(self, expected_shape): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) + with tempfile.TemporaryDirectory() as tempdir: + nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) + test_data = [ + (os.path.join(tempdir, "test_image1.nii.gz"), os.path.join(tempdir, "test_label1.nii.gz")), + (os.path.join(tempdir, "test_image2.nii.gz"), os.path.join(tempdir, "test_label2.nii.gz")), + ] + + test_transform = Compose([LoadImage(), SimulateDelay(delay_time=1e-5)]) + + # Here test_transform is applied element by element for the tuple. + dataset = Dataset(data=test_data, transform=test_transform) + data1 = dataset[0] + data2 = dataset[1] + + # Output is a list/tuple + self.assertTrue(isinstance(data1, (list, tuple))) + self.assertTrue(isinstance(data2, (list, tuple))) + + # Number of elements are 2 + self.assertEqual(len(data1), 2) + self.assertEqual(len(data2), 2) + + # Output shapes are as expected + self.assertTupleEqual(data1[0].shape, expected_shape) + self.assertTupleEqual(data1[1].shape, expected_shape) + self.assertTupleEqual(data2[0].shape, expected_shape) + self.assertTupleEqual(data2[1].shape, expected_shape) + + # Here test_transform is applied to the tuple as a whole. + test_transform = Compose( + [ + # LoadImage creates a channel-stacked image when applied to a tuple + LoadImage(), + # Get the channel-stacked image and the label + Lambda(func=lambda x: (x[0].permute(2, 1, 0), x[1])), + ], + map_items=False, + ) + + dataset = Dataset(data=test_data, transform=test_transform) + data1 = dataset[0] + data2 = dataset[1] + + # Output is a list/tuple + self.assertTrue(isinstance(data1, (list, tuple))) + self.assertTrue(isinstance(data2, (list, tuple))) + + # Number of elements are 2 + self.assertEqual(len(data1), 2) + self.assertEqual(len(data2), 2) + + # Output shapes are as expected + self.assertTupleEqual(data1[0].shape, expected_shape) + self.assertTupleEqual(data1[1].shape, expected_shape) + self.assertTupleEqual(data2[0].shape, expected_shape) + self.assertTupleEqual(data2[1].shape, expected_shape) class TestDatsesetWithLazy(unittest.TestCase): From 8eb61a9e0c59a6bbd5fa1e64bc82cba3974f0f33 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 May 2024 01:00:46 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index b7349e4200..0d37ae2efd 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -23,7 +23,7 @@ from parameterized import parameterized from monai.data import Dataset -from monai.transforms import Compose, Lambda, LoadImage, LoadImaged, SimulateDelay, SimulateDelayd, ToTensor +from monai.transforms import Compose, Lambda, LoadImage, LoadImaged, SimulateDelay, SimulateDelayd from tests.test_compose import TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES, data_from_keys TEST_CASE_1 = [(128, 128, 128)] From 52f0981e3e4a8aafee0b909300929e071e45b1aa Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Thu, 23 May 2024 05:34:04 +0000 Subject: [PATCH 5/9] Tests fix + refactor all apply_transform to Compose Signed-off-by: Suraj Pai --- monai/data/dataset.py | 49 ++++++++++++++--------------------------- tests/test_profiling.py | 4 +++- 2 files changed, 19 insertions(+), 34 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 2b2d1e68eb..be990f03f1 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -36,15 +36,7 @@ from monai.data.meta_tensor import MetaTensor from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing -from monai.transforms import ( - Compose, - Randomizable, - RandomizableTrait, - Transform, - apply_transform, - convert_to_contiguous, - reset_ops_id, -) +from monai.transforms import Compose, Randomizable, RandomizableTrait, Transform, convert_to_contiguous, reset_ops_id from monai.utils import MAX_SEED, convert_to_tensor, get_seed, look_up_option, min_version, optional_import from monai.utils.misc import first @@ -77,17 +69,20 @@ class Dataset(_TorchDataset): }, }, }] """ - def __init__(self, data: Sequence, transform: Callable | None = None) -> None: + def __init__(self, data: Sequence, transform: Sequence[Callable] | Callable | None = None) -> None: """ Args: data: input data to load and transform to generate dataset for model. transform: a callable data transform on input data. - kwargs: other arguments for `apply_transform` fn called to apply provided transforms. - For ex. `map_items=False` - """ self.data = data - self.transform = Compose(transform) if not isinstance(transform, Compose) else transform + if transform is None: + self.transform = None + else: + try: + self.transform = Compose(transform) if not isinstance(transform, Compose) else transform + except Exception as e: + raise ValueError("`transform` must be a callable or a list of callables that is Composable") from e def __len__(self) -> int: return len(self.data) @@ -267,6 +262,8 @@ def __init__( """ super().__init__(data=data, transform=transform) + if self.transform is None: + raise ValueError("transform must not be None when provided to a PersistentDataset") self.cache_dir = Path(cache_dir) if cache_dir is not None else None self.hash_func = hash_func self.pickle_module = pickle_module @@ -323,9 +320,6 @@ def _pre_transform(self, item_transformed): random transform object """ - if not isinstance(self.transform, Compose): - raise ValueError("transform must be an instance of monai.transforms.Compose.") - first_random = self.transform.get_index_of_first( lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) ) @@ -346,9 +340,6 @@ def _post_transform(self, item_transformed): the transformed element through the random transforms """ - if not isinstance(self.transform, Compose): - raise ValueError("transform must be an instance of monai.transforms.Compose.") - first_random = self.transform.get_index_of_first( lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) ) @@ -501,9 +492,6 @@ def _pre_transform(self, item_transformed): Returns: the transformed element up to the N transform object """ - if not isinstance(self.transform, Compose): - raise ValueError("transform must be an instance of monai.transforms.Compose.") - item_transformed = self.transform(item_transformed, end=self.cache_n_trans, threading=True) reset_ops_id(item_transformed) @@ -519,9 +507,6 @@ def _post_transform(self, item_transformed): Returns: the final transformed result """ - if not isinstance(self.transform, Compose): - raise ValueError("transform must be an instance of monai.transforms.Compose.") - return self.transform(item_transformed, start=self.cache_n_trans) @@ -1280,8 +1265,10 @@ def to_list(x): data = [] for dataset in self.data: data.extend(to_list(dataset[index])) + if self.transform is not None: - data = apply_transform(self.transform, data, map_items=False) # transform the list data + self.transform.map_items = False # Compose object map_items to false so transform is applied to list + data = self.transform(data) # use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists return tuple(data) @@ -1430,15 +1417,11 @@ def __len__(self): def _transform(self, index: int): data = {k: v[index] for k, v in self.arrays.items()} - - if not self.transform: - return data - - result = apply_transform(self.transform, data) + result = self.transform(data) if self.transform is not None else data if isinstance(result, dict) or (isinstance(result, list) and isinstance(result[0], dict)): return result - raise AssertionError("With a dict supplied to apply_transform, should return a dict or a list of dicts.") + raise AssertionError("With a dict supplied to Compose, should return a dict or a list of dicts.") class CSVDataset(Dataset): diff --git a/tests/test_profiling.py b/tests/test_profiling.py index 6bee7ba262..649d980ebf 100644 --- a/tests/test_profiling.py +++ b/tests/test_profiling.py @@ -35,6 +35,7 @@ def setUp(self): self.scale = mt.ScaleIntensity() self.scale_call_name = "ScaleIntensity.__call__" + self.compose_call_name = "Compose.__call__" self.test_comp = mt.Compose([mt.ScaleIntensity(), mt.RandAxisFlip(0.5)]) self.test_image = torch.rand(1, 16, 16, 16) self.pid = os.getpid() @@ -82,7 +83,7 @@ def test_profile_multithread(self): self.assertSequenceEqual(batch.shape, (4, 1, 16, 16, 16)) results = wp.get_results() - self.assertSequenceEqual(list(results), [self.scale_call_name]) + self.assertSequenceEqual(list(results), [self.scale_call_name, self.compose_call_name]) prs = results[self.scale_call_name] @@ -98,6 +99,7 @@ def test_profile_context(self): self.scale(self.test_image) results = wp.get_results() + self.assertSequenceEqual(set(results), {"ScaleIntensity.__call__", "context"}) prs = results["context"] From 4a20fd0bcf6c8606973526ff5cc11533548ae361 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Thu, 23 May 2024 20:00:51 +0000 Subject: [PATCH 6/9] Delegate handling to Compose Signed-off-by: Suraj Pai --- monai/data/dataset.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index be990f03f1..f08d5ec5ae 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -76,13 +76,10 @@ def __init__(self, data: Sequence, transform: Sequence[Callable] | Callable | No transform: a callable data transform on input data. """ self.data = data - if transform is None: - self.transform = None - else: - try: - self.transform = Compose(transform) if not isinstance(transform, Compose) else transform - except Exception as e: - raise ValueError("`transform` must be a callable or a list of callables that is Composable") from e + try: + self.transform = Compose(transform) if not isinstance(transform, Compose) else transform + except Exception as e: + raise ValueError("`transform` must be a callable or a list of callables that is Composable") from e def __len__(self) -> int: return len(self.data) @@ -92,7 +89,7 @@ def _transform(self, index: int): Fetch single data item from `self.data`. """ data_i = self.data[index] - return self.transform(data_i) if self.transform is not None else data_i + return self.transform(data_i) def __getitem__(self, index: int | slice | Sequence[int]): """ @@ -262,8 +259,6 @@ def __init__( """ super().__init__(data=data, transform=transform) - if self.transform is None: - raise ValueError("transform must not be None when provided to a PersistentDataset") self.cache_dir = Path(cache_dir) if cache_dir is not None else None self.hash_func = hash_func self.pickle_module = pickle_module From 877cb25f7f2329b8c9ab0f16a018a07b56ea6ea9 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Sat, 25 May 2024 03:20:42 +0000 Subject: [PATCH 7/9] Update arraydataset test for Compose transform Signed-off-by: Suraj Pai --- tests/test_arraydataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py index efc014a267..b61b3c139c 100644 --- a/tests/test_arraydataset.py +++ b/tests/test_arraydataset.py @@ -41,7 +41,7 @@ class TestCompose(Compose): - def __call__(self, input_, lazy): + def __call__(self, input_, lazy=False): img = self.transforms[0](input_) metadata = img.meta img = self.transforms[1](img) From 10f93658a84e69867432d0fa241c1043e50e302d Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Wed, 29 May 2024 15:59:12 +0000 Subject: [PATCH 8/9] Update docstring for Dataset transform Signed-off-by: Suraj Pai --- monai/data/dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index f08d5ec5ae..ed7c418b7e 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -73,7 +73,9 @@ def __init__(self, data: Sequence, transform: Sequence[Callable] | Callable | No """ Args: data: input data to load and transform to generate dataset for model. - transform: a callable data transform on input data. + transform: a callable, sequence of callables or None. The transform will be passed + into a `Compose` wrapper by default. Sequence of callables are applied in order and + if `None` is passed, the data is returned as is. """ self.data = data try: From f697630008dfa61389cd9660fd16c4aa4a9b5871 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Thu, 30 May 2024 13:04:45 -0400 Subject: [PATCH 9/9] Update monai/data/dataset.py Co-authored-by: Ben Murray Signed-off-by: Suraj Pai --- monai/data/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index ed7c418b7e..871b523289 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -73,9 +73,9 @@ def __init__(self, data: Sequence, transform: Sequence[Callable] | Callable | No """ Args: data: input data to load and transform to generate dataset for model. - transform: a callable, sequence of callables or None. The transform will be passed - into a `Compose` wrapper by default. Sequence of callables are applied in order and - if `None` is passed, the data is returned as is. + transform: a callable, sequence of callables or None. If transform is not + a `Compose` instance, it will be wrapped in a `Compose` instance. Sequences + of callables are applied in order and if `None` is passed, the data is returned as is. """ self.data = data try: