From 77eecffaaffd2ffb918b801914bb43d1ccdf8eee Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 19 Jul 2023 11:24:30 +0800 Subject: [PATCH 01/13] add `GDSDataset` Signed-off-by: KumoLiu --- monai/data/dataset.py | 85 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 84 insertions(+), 1 deletion(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 912576bdcc..7557b01422 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -44,7 +44,7 @@ convert_to_contiguous, reset_ops_id, ) -from monai.utils import MAX_SEED, get_seed, look_up_option, min_version, optional_import +from monai.utils import MAX_SEED, get_seed, look_up_option, min_version, optional_import, convert_to_tensor from monai.utils.misc import first if TYPE_CHECKING: @@ -54,8 +54,10 @@ else: tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm") +cp, _ = optional_import("cupy") lmdb, _ = optional_import("lmdb") pd, _ = optional_import("pandas") +kvikio_numpy, _ = optional_import("kvikio.numpy") class Dataset(_TorchDataset): @@ -1510,3 +1512,84 @@ def __init__( dfs=dfs, row_indices=row_indices, col_names=col_names, col_types=col_types, col_groups=col_groups, **kwargs ) super().__init__(data=data, transform=transform) + + +class GDSDataset(PersistentDataset): + def __init__( + self, + data: Sequence, + transform: Sequence[Callable] | Callable, + cache_dir: Path | str | None, + hash_func: Callable[..., bytes] = pickle_hashing, + hash_transform: Callable[..., bytes] | None = None, + reset_ops_id: bool = True, + device: int = None, + **kwargs: Any, + ) -> None: + super().__init__( + data=data, + transform=transform, + cache_dir=cache_dir, + hash_func=hash_func, + hash_transform=hash_transform, + reset_ops_id=reset_ops_id, + **kwargs, + ) + self.device = device + + def _cachecheck(self, item_transformed): + """given the input dictionary ``item_transformed``, return a transformed version of it""" + hashfile = None + # compute a cache id + if self.cache_dir is not None: + data_item_md5 = self.hash_func(item_transformed).decode("utf-8") + data_item_md5 += self.transform_hash + hashfile = self.cache_dir / f"{data_item_md5}.pt" + + if hashfile is not None and hashfile.is_file(): # cache hit + with cp.cuda.Device(self.device): + item = {} + for k in item_transformed: + meta_k = torch.load(self.cache_dir / f"{hashfile.name}-{k}-meta") + item[k] = kvikio_numpy.fromfile(f"{hashfile}-{k}", dtype=np.float32, like=cp.empty(())) + item[k] = convert_to_tensor(item[k].reshape(meta_k["shape"]), device=f"cuda:{self.device}") + item[f"{k}_meta_dict"] = meta_k + return item + + # create new cache + _item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed + if hashfile is None: + return _item_transformed + + for k in _item_transformed: # {'image': ..., 'label': ...} + _item_transformed_meta = _item_transformed[k].meta + _item_transformed_data = _item_transformed[k].array + _item_transformed_meta["shape"] = _item_transformed_data.shape + kvikio_numpy.tofile(_item_transformed_data, f"{hashfile}-{k}") + try: + # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation + # to make the cache more robust to manual killing of parent process + # which may leave partially written cache files in an incomplete state + with tempfile.TemporaryDirectory() as tmpdirname: + meta_hash_file_name = f"{hashfile.name}-{k}-meta" + meta_hash_file = self.cache_dir / meta_hash_file_name + temp_hash_file = Path(tmpdirname) / meta_hash_file_name + torch.save( + 
obj=_item_transformed_meta, + f=temp_hash_file, + pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), + pickle_protocol=self.pickle_protocol, + ) + if temp_hash_file.is_file() and not meta_hash_file.is_file(): + # On Unix, if target exists and is a file, it will be replaced silently if the + # user has permission. + # for more details: https://docs.python.org/3/library/shutil.html#shutil.move. + try: + shutil.move(str(temp_hash_file), meta_hash_file) + except FileExistsError: + pass + except PermissionError: # project-monai/monai issue #3613 + pass + open(hashfile, "a").close() # store cacheid + + return _item_transformed From 98aa1b66de1e85f81f948fb5305c41fa0211be8e Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 25 Jul 2023 00:42:29 +0800 Subject: [PATCH 02/13] add unittest Signed-off-by: KumoLiu --- monai/data/__init__.py | 1 + monai/data/dataset.py | 121 ++++++++++++++-------- tests/test_gdsdataset.py | 171 ++++++++++++++++++++++++++++++++ tests/test_persistentdataset.py | 147 ++++++++++++++------------- 4 files changed, 327 insertions(+), 113 deletions(-) create mode 100644 tests/test_gdsdataset.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 0e9759aaf1..c16caec4d7 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -38,6 +38,7 @@ PersistentDataset, SmartCacheDataset, ZipDataset, + GDSDataset, ) from .dataset_summary import DatasetSummary from .decathlon_datalist import ( diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 7557b01422..caa580fd4b 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -34,6 +34,7 @@ from torch.utils.data import Dataset as _TorchDataset from torch.utils.data import Subset +from monai.data.meta_tensor import MetaTensor from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing from monai.transforms import ( Compose, @@ -46,6 +47,7 @@ ) from monai.utils import MAX_SEED, get_seed, look_up_option, min_version, optional_import, convert_to_tensor from monai.utils.misc import first +from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor if TYPE_CHECKING: from tqdm import tqdm @@ -328,7 +330,6 @@ def _pre_transform(self, item_transformed): first_random = self.transform.get_index_of_first( lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) ) - item_transformed = self.transform(item_transformed, end=first_random, threading=True) if self.reset_ops_id: @@ -377,6 +378,8 @@ def _cachecheck(self, item_transformed): """ hashfile = None if self.cache_dir is not None: + if isinstance(item_transformed, np.ndarray): + print('*** Attention ****', item_transformed.dtype, item_transformed.shape) data_item_md5 = self.hash_func(item_transformed).decode("utf-8") data_item_md5 += self.transform_hash hashfile = self.cache_dir / f"{data_item_md5}.pt" @@ -1520,10 +1523,10 @@ def __init__( data: Sequence, transform: Sequence[Callable] | Callable, cache_dir: Path | str | None, + device: int, hash_func: Callable[..., bytes] = pickle_hashing, hash_transform: Callable[..., bytes] | None = None, reset_ops_id: bool = True, - device: int = None, **kwargs: Any, ) -> None: super().__init__( @@ -1546,50 +1549,90 @@ def _cachecheck(self, item_transformed): data_item_md5 += self.transform_hash hashfile = self.cache_dir / f"{data_item_md5}.pt" + # print('cache ', self.cache_dir, hashfile, type(item_transformed), isinstance(item_transformed, (np.ndarray, torch.Tensor))) if hashfile is not None and hashfile.is_file(): # cache hit with 
cp.cuda.Device(self.device): - item = {} - for k in item_transformed: - meta_k = torch.load(self.cache_dir / f"{hashfile.name}-{k}-meta") - item[k] = kvikio_numpy.fromfile(f"{hashfile}-{k}", dtype=np.float32, like=cp.empty(())) - item[k] = convert_to_tensor(item[k].reshape(meta_k["shape"]), device=f"cuda:{self.device}") - item[f"{k}_meta_dict"] = meta_k - return item + if isinstance(item_transformed, dict): + item = {} + for k in item_transformed: + meta_k = torch.load(self.cache_dir / f"{hashfile.name}-{k}-meta") + item[k] = kvikio_numpy.fromfile(f"{hashfile}-{k}", dtype=meta_k["dtype"], like=cp.empty(())) + item[k] = convert_to_tensor(item[k].reshape(meta_k["shape"]), device=f"cuda:{self.device}") + item[f"{k}_meta_dict"] = meta_k + return item + elif isinstance(item_transformed, (np.ndarray, torch.Tensor)): + _meta = torch.load(self.cache_dir / f"{hashfile.name}-meta") + _data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta.pop("dtype"), like=cp.empty(())) + _data = convert_to_tensor(_data.reshape(_meta.pop("shape")), device=f"cuda:{self.device}") + if bool(_meta): + return (_data, _meta) + return _data + else: + item = [] + for i, _item in enumerate(item_transformed): + for k in _item: + meta_i_k = torch.load(self.cache_dir / f"{hashfile.name}-{k}-meta-{i}") + item_k = kvikio_numpy.fromfile(f"{hashfile}-{k}-{i}", dtype=np.float32, like=cp.empty(())) + item_k = convert_to_tensor(item[i].reshape(meta_i_k["shape"]), device=f"cuda:{self.device}") + item[i] = {f"{k}": item_k, f"{k}_meta_dict": meta_k} + return item # create new cache _item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed if hashfile is None: return _item_transformed - - for k in _item_transformed: # {'image': ..., 'label': ...} - _item_transformed_meta = _item_transformed[k].meta - _item_transformed_data = _item_transformed[k].array - _item_transformed_meta["shape"] = _item_transformed_data.shape - kvikio_numpy.tofile(_item_transformed_data, f"{hashfile}-{k}") - try: - # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation - # to make the cache more robust to manual killing of parent process - # which may leave partially written cache files in an incomplete state - with tempfile.TemporaryDirectory() as tmpdirname: - meta_hash_file_name = f"{hashfile.name}-{k}-meta" - meta_hash_file = self.cache_dir / meta_hash_file_name - temp_hash_file = Path(tmpdirname) / meta_hash_file_name - torch.save( - obj=_item_transformed_meta, - f=temp_hash_file, - pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), - pickle_protocol=self.pickle_protocol, - ) - if temp_hash_file.is_file() and not meta_hash_file.is_file(): - # On Unix, if target exists and is a file, it will be replaced silently if the - # user has permission. - # for more details: https://docs.python.org/3/library/shutil.html#shutil.move. 
- try: - shutil.move(str(temp_hash_file), meta_hash_file) - except FileExistsError: - pass - except PermissionError: # project-monai/monai issue #3613 - pass + if isinstance(_item_transformed, dict): # {"image": ,"label": } + print("*********") + for k in _item_transformed: + data_hashfile = f"{hashfile}-{k}" + meta_hash_file_name = f"{hashfile.name}-{k}-meta" + if isinstance(_item_transformed[k], (np.ndarray, torch.Tensor)): + self._create_new_cache(_item_transformed[k], data_hashfile, meta_hash_file_name) + else: + return _item_transformed + elif isinstance(_item_transformed, (np.ndarray, torch.Tensor)): # [array, {}] + data_hashfile = f"{hashfile}" + meta_hash_file_name = f"{hashfile.name}-meta" + self._create_new_cache(_item_transformed, data_hashfile, meta_hash_file_name) + else: # [{"image": metatensor,"label": metatensor}, {"image": ,"label": }, "image_meta_dict"], [metatensor, metatensor, ] + for i, _item in enumerate(_item_transformed): + for k in _item: + data_hashfile = f"{hashfile}-{k}-{i}" + meta_hash_file_name = f"{hashfile.name}-{k}-meta-{i}" + self._create_new_cache(_item, data_hashfile, meta_hash_file_name) open(hashfile, "a").close() # store cacheid return _item_transformed + + def _create_new_cache(self, data, data_hashfile, meta_hash_file_name): + _item_transformed_meta = data.meta if isinstance(data, MetaTensor) else {} + _item_transformed_data = data.array if isinstance(data, MetaTensor) else data + print(type(_item_transformed_data)) + if isinstance(_item_transformed_data, torch.Tensor): + _item_transformed_data = _item_transformed_data.numpy() + _item_transformed_meta["shape"] = _item_transformed_data.shape + _item_transformed_meta["dtype"] = _item_transformed_data.dtype + kvikio_numpy.tofile(_item_transformed_data, data_hashfile) + try: + # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation + # to make the cache more robust to manual killing of parent process + # which may leave partially written cache files in an incomplete state + with tempfile.TemporaryDirectory() as tmpdirname: + meta_hash_file = self.cache_dir / meta_hash_file_name + temp_hash_file = Path(tmpdirname) / meta_hash_file_name + torch.save( + obj=_item_transformed_meta, + f=temp_hash_file, + pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), + pickle_protocol=self.pickle_protocol, + ) + if temp_hash_file.is_file() and not meta_hash_file.is_file(): + # On Unix, if target exists and is a file, it will be replaced silently if the + # user has permission. + # for more details: https://docs.python.org/3/library/shutil.html#shutil.move. + try: + shutil.move(str(temp_hash_file), meta_hash_file) + except FileExistsError: + pass + except PermissionError: # project-monai/monai issue #3613 + pass diff --git a/tests/test_gdsdataset.py b/tests/test_gdsdataset.py new file mode 100644 index 0000000000..9a1a946626 --- /dev/null +++ b/tests/test_gdsdataset.py @@ -0,0 +1,171 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import os +import pickle +import tempfile +import unittest + +import nibabel as nib +import numpy as np +from parameterized import parameterized + +from monai.data import GDSDataset, json_hashing, PersistentDataset +from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_CASE_1 = [ + Compose( + [ + LoadImaged(keys=["image", "label", "extra"], image_only=True), + SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]), + ] + ), + (128, 128, 128), +] + +TEST_CASE_2 = [ + [ + LoadImaged(keys=["image", "label", "extra"], image_only=True), + SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]), + ], + (128, 128, 128), +] + +TEST_CASE_3 = [None, (128, 128, 128)] + + +class _InplaceXform(Transform): + def __call__(self, data): + data[0] = data[0] + 1 + return data + + +class TestDataset(unittest.TestCase): + def test_cache(self): + """testing no inplace change to the hashed item""" + print(TEST_NDARRAYS[:2] + TEST_NDARRAYS[3:]) + for p in TEST_NDARRAYS: + shape = (1, 10, 9, 8) + items = [p(np.arange(0, np.prod(shape)).reshape(shape))] + + print("--- type:", type(items[0])) + with tempfile.TemporaryDirectory() as tempdir: + ds = GDSDataset( + data=items, + transform=_InplaceXform(), + cache_dir=tempdir, + device=0, + pickle_module="pickle", + pickle_protocol=pickle.HIGHEST_PROTOCOL, + ) + ds1 = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, device=0) + assert_allclose(ds[0], ds1[0], type_test=False) + assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) + + ds2 = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, hash_transform=json_hashing, device=0) + assert_allclose(ds[0], ds2[0], type_test=False) + assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) + + # @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + # def test_shape(self, transform, expected_shape): + # import torch + # test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) + # with tempfile.TemporaryDirectory() as tempdir: + # nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + # nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + # nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) + # nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) + # nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) + # nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) + # test_data = [ + # { + # "image": os.path.join(tempdir, "test_image1.nii.gz"), + # "label": os.path.join(tempdir, "test_label1.nii.gz"), + # "extra": os.path.join(tempdir, "test_extra1.nii.gz"), + # }, + # { + # "image": os.path.join(tempdir, "test_image2.nii.gz"), + # "label": os.path.join(tempdir, "test_label2.nii.gz"), + # "extra": os.path.join(tempdir, "test_extra2.nii.gz"), + # }, + # ] + + # cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") + # dataset_precached = GDSDataset(data=test_data, transform=transform, cache_dir=cache_dir, device=0) + # data1_precached = dataset_precached[0] + # data2_precached = dataset_precached[1] + + # dataset_postcached = GDSDataset(data=test_data, transform=transform, cache_dir=cache_dir, device=0) + # data1_postcached = dataset_postcached[0] + # data2_postcached = dataset_postcached[1] + # data3_postcached = dataset_postcached[0:2] 
+ + # if transform is None: + # self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz")) + # self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz")) + # self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz")) + # self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz")) + # else: + # self.assertTupleEqual(data1_precached["image"].shape, expected_shape) + # self.assertTupleEqual(data1_precached["label"].shape, expected_shape) + # self.assertTupleEqual(data1_precached["extra"].shape, expected_shape) + # self.assertTupleEqual(data2_precached["image"].shape, expected_shape) + # self.assertTupleEqual(data2_precached["label"].shape, expected_shape) + # self.assertTupleEqual(data2_precached["extra"].shape, expected_shape) + + # self.assertTupleEqual(data1_postcached["image"].shape, expected_shape) + # self.assertTupleEqual(data1_postcached["label"].shape, expected_shape) + # self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape) + # self.assertTupleEqual(data2_postcached["image"].shape, expected_shape) + # self.assertTupleEqual(data2_postcached["label"].shape, expected_shape) + # self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape) + # for d in data3_postcached: + # self.assertTupleEqual(d["image"].shape, expected_shape) + + # # # update the data to cache + # # test_data_new = [ + # # { + # # "image": os.path.join(tempdir, "test_image1_new.nii.gz"), + # # "label": os.path.join(tempdir, "test_label1_new.nii.gz"), + # # "extra": os.path.join(tempdir, "test_extra1_new.nii.gz"), + # # }, + # # { + # # "image": os.path.join(tempdir, "test_image2_new.nii.gz"), + # # "label": os.path.join(tempdir, "test_label2_new.nii.gz"), + # # "extra": os.path.join(tempdir, "test_extra2_new.nii.gz"), + # # }, + # # ] + # # dataset_postcached.set_data(data=test_data_new) + # # # test new exchanged cache content + # # if transform is None: + # # self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz")) + # # self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz")) + # # self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz")) + + # def test_different_transforms(self): + # """ + # Different instances of `GDSDataset` with the same cache_dir, + # same input data, but different transforms should give different results. 
+ # """ + # shape = (1, 10, 9, 8) + # im = np.arange(0, np.prod(shape)).reshape(shape) + # with tempfile.TemporaryDirectory() as path: + # im1 = GDSDataset([im], Identity(), cache_dir=path, hash_transform=json_hashing, device=0)[0] + # im2 = GDSDataset([im], Flip(1), cache_dir=path, hash_transform=json_hashing, device=0)[0] + # l2 = ((im1 - im2) ** 2).sum() ** 0.5 + # self.assertTrue(l2 > 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_persistentdataset.py b/tests/test_persistentdataset.py index 1b8245e318..5939eeef87 100644 --- a/tests/test_persistentdataset.py +++ b/tests/test_persistentdataset.py @@ -77,80 +77,79 @@ def test_cache(self): self.assertEqual(list(ds1), list(ds)) self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, transform, expected_shape): - test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) - with tempfile.TemporaryDirectory() as tempdir: - nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) - test_data = [ - { - "image": os.path.join(tempdir, "test_image1.nii.gz"), - "label": os.path.join(tempdir, "test_label1.nii.gz"), - "extra": os.path.join(tempdir, "test_extra1.nii.gz"), - }, - { - "image": os.path.join(tempdir, "test_image2.nii.gz"), - "label": os.path.join(tempdir, "test_label2.nii.gz"), - "extra": os.path.join(tempdir, "test_extra2.nii.gz"), - }, - ] - - cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") - dataset_precached = PersistentDataset(data=test_data, transform=transform, cache_dir=cache_dir) - data1_precached = dataset_precached[0] - data2_precached = dataset_precached[1] - - dataset_postcached = PersistentDataset(data=test_data, transform=transform, cache_dir=cache_dir) - data1_postcached = dataset_postcached[0] - data2_postcached = dataset_postcached[1] - data3_postcached = dataset_postcached[0:2] - - if transform is None: - self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz")) - self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz")) - self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz")) - self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz")) - else: - self.assertTupleEqual(data1_precached["image"].shape, expected_shape) - self.assertTupleEqual(data1_precached["label"].shape, expected_shape) - self.assertTupleEqual(data1_precached["extra"].shape, expected_shape) - self.assertTupleEqual(data2_precached["image"].shape, expected_shape) - self.assertTupleEqual(data2_precached["label"].shape, expected_shape) - self.assertTupleEqual(data2_precached["extra"].shape, expected_shape) - - self.assertTupleEqual(data1_postcached["image"].shape, expected_shape) - self.assertTupleEqual(data1_postcached["label"].shape, expected_shape) - self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape) - self.assertTupleEqual(data2_postcached["image"].shape, expected_shape) - self.assertTupleEqual(data2_postcached["label"].shape, expected_shape) - 
self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape) - for d in data3_postcached: - self.assertTupleEqual(d["image"].shape, expected_shape) - - # update the data to cache - test_data_new = [ - { - "image": os.path.join(tempdir, "test_image1_new.nii.gz"), - "label": os.path.join(tempdir, "test_label1_new.nii.gz"), - "extra": os.path.join(tempdir, "test_extra1_new.nii.gz"), - }, - { - "image": os.path.join(tempdir, "test_image2_new.nii.gz"), - "label": os.path.join(tempdir, "test_label2_new.nii.gz"), - "extra": os.path.join(tempdir, "test_extra2_new.nii.gz"), - }, - ] - dataset_postcached.set_data(data=test_data_new) - # test new exchanged cache content - if transform is None: - self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz")) - self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz")) - self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz")) + # @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + # def test_shape(self, transform, expected_shape): + # test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) + # with tempfile.TemporaryDirectory() as tempdir: + # nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + # nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + # nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) + # nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) + # nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) + # nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) + # test_data = [ + # { + # "image": os.path.join(tempdir, "test_image1.nii.gz"), + # "label": os.path.join(tempdir, "test_label1.nii.gz"), + # "extra": os.path.join(tempdir, "test_extra1.nii.gz"), + # }, + # { + # "image": os.path.join(tempdir, "test_image2.nii.gz"), + # "label": os.path.join(tempdir, "test_label2.nii.gz"), + # "extra": os.path.join(tempdir, "test_extra2.nii.gz"), + # }, + # ] + + # cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") + # dataset_precached = PersistentDataset(data=test_data, transform=transform, cache_dir=cache_dir) + # data1_precached = dataset_precached[0] + # data2_precached = dataset_precached[1] + # dataset_postcached = PersistentDataset(data=test_data, transform=transform, cache_dir=cache_dir) + # data1_postcached = dataset_postcached[0] + # data2_postcached = dataset_postcached[1] + # data3_postcached = dataset_postcached[0:2] + + # if transform is None: + # self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz")) + # self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz")) + # self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz")) + # self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz")) + # else: + # self.assertTupleEqual(data1_precached["image"].shape, expected_shape) + # self.assertTupleEqual(data1_precached["label"].shape, expected_shape) + # self.assertTupleEqual(data1_precached["extra"].shape, expected_shape) + # self.assertTupleEqual(data2_precached["image"].shape, expected_shape) + # self.assertTupleEqual(data2_precached["label"].shape, expected_shape) + # self.assertTupleEqual(data2_precached["extra"].shape, expected_shape) + + # self.assertTupleEqual(data1_postcached["image"].shape, expected_shape) + # 
self.assertTupleEqual(data1_postcached["label"].shape, expected_shape) + # self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape) + # self.assertTupleEqual(data2_postcached["image"].shape, expected_shape) + # self.assertTupleEqual(data2_postcached["label"].shape, expected_shape) + # self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape) + # for d in data3_postcached: + # self.assertTupleEqual(d["image"].shape, expected_shape) + + # # update the data to cache + # test_data_new = [ + # { + # "image": os.path.join(tempdir, "test_image1_new.nii.gz"), + # "label": os.path.join(tempdir, "test_label1_new.nii.gz"), + # "extra": os.path.join(tempdir, "test_extra1_new.nii.gz"), + # }, + # { + # "image": os.path.join(tempdir, "test_image2_new.nii.gz"), + # "label": os.path.join(tempdir, "test_label2_new.nii.gz"), + # "extra": os.path.join(tempdir, "test_extra2_new.nii.gz"), + # }, + # ] + # dataset_postcached.set_data(data=test_data_new) + # # test new exchanged cache content + # if transform is None: + # self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz")) + # self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz")) + # self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz")) def test_different_transforms(self): """ From 0ccac952c4a87e35381ece90ae21a1f4c35e368b Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 26 Jul 2023 15:45:39 +0800 Subject: [PATCH 03/13] update unittests Signed-off-by: KumoLiu --- monai/data/dataset.py | 4 - tests/test_gdsdataset.py | 200 +++++++++++++++++--------------- tests/test_persistentdataset.py | 147 +++++++++++------------ 3 files changed, 181 insertions(+), 170 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index caa580fd4b..962de97fa5 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -1549,7 +1549,6 @@ def _cachecheck(self, item_transformed): data_item_md5 += self.transform_hash hashfile = self.cache_dir / f"{data_item_md5}.pt" - # print('cache ', self.cache_dir, hashfile, type(item_transformed), isinstance(item_transformed, (np.ndarray, torch.Tensor))) if hashfile is not None and hashfile.is_file(): # cache hit with cp.cuda.Device(self.device): if isinstance(item_transformed, dict): @@ -1582,7 +1581,6 @@ def _cachecheck(self, item_transformed): if hashfile is None: return _item_transformed if isinstance(_item_transformed, dict): # {"image": ,"label": } - print("*********") for k in _item_transformed: data_hashfile = f"{hashfile}-{k}" meta_hash_file_name = f"{hashfile.name}-{k}-meta" @@ -1601,13 +1599,11 @@ def _cachecheck(self, item_transformed): meta_hash_file_name = f"{hashfile.name}-{k}-meta-{i}" self._create_new_cache(_item, data_hashfile, meta_hash_file_name) open(hashfile, "a").close() # store cacheid - return _item_transformed def _create_new_cache(self, data, data_hashfile, meta_hash_file_name): _item_transformed_meta = data.meta if isinstance(data, MetaTensor) else {} _item_transformed_data = data.array if isinstance(data, MetaTensor) else data - print(type(_item_transformed_data)) if isinstance(_item_transformed_data, torch.Tensor): _item_transformed_data = _item_transformed_data.numpy() _item_transformed_meta["shape"] = _item_transformed_data.shape diff --git a/tests/test_gdsdataset.py b/tests/test_gdsdataset.py index 9a1a946626..c643b238e3 100644 --- a/tests/test_gdsdataset.py +++ b/tests/test_gdsdataset.py @@ -54,12 +54,10 @@ def __call__(self, data): class 
TestDataset(unittest.TestCase): def test_cache(self): """testing no inplace change to the hashed item""" - print(TEST_NDARRAYS[:2] + TEST_NDARRAYS[3:]) - for p in TEST_NDARRAYS: + for p in TEST_NDARRAYS[:2]: shape = (1, 10, 9, 8) items = [p(np.arange(0, np.prod(shape)).reshape(shape))] - print("--- type:", type(items[0])) with tempfile.TemporaryDirectory() as tempdir: ds = GDSDataset( data=items, @@ -69,102 +67,118 @@ def test_cache(self): pickle_module="pickle", pickle_protocol=pickle.HIGHEST_PROTOCOL, ) + assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) ds1 = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, device=0) assert_allclose(ds[0], ds1[0], type_test=False) assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) - ds2 = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, hash_transform=json_hashing, device=0) - assert_allclose(ds[0], ds2[0], type_test=False) + ds = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, hash_transform=json_hashing, device=0) + assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) + ds1 = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, hash_transform=json_hashing, device=0) + assert_allclose(ds[0], ds1[0], type_test=False) assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) - # @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - # def test_shape(self, transform, expected_shape): - # import torch - # test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) - # with tempfile.TemporaryDirectory() as tempdir: - # nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) - # nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) - # nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) - # nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) - # nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) - # nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) - # test_data = [ - # { - # "image": os.path.join(tempdir, "test_image1.nii.gz"), - # "label": os.path.join(tempdir, "test_label1.nii.gz"), - # "extra": os.path.join(tempdir, "test_extra1.nii.gz"), - # }, - # { - # "image": os.path.join(tempdir, "test_image2.nii.gz"), - # "label": os.path.join(tempdir, "test_label2.nii.gz"), - # "extra": os.path.join(tempdir, "test_extra2.nii.gz"), - # }, - # ] - - # cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") - # dataset_precached = GDSDataset(data=test_data, transform=transform, cache_dir=cache_dir, device=0) - # data1_precached = dataset_precached[0] - # data2_precached = dataset_precached[1] - - # dataset_postcached = GDSDataset(data=test_data, transform=transform, cache_dir=cache_dir, device=0) - # data1_postcached = dataset_postcached[0] - # data2_postcached = dataset_postcached[1] - # data3_postcached = dataset_postcached[0:2] - - # if transform is None: - # self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz")) - # self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz")) - # self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz")) - # self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz")) - # else: - # self.assertTupleEqual(data1_precached["image"].shape, expected_shape) - # self.assertTupleEqual(data1_precached["label"].shape, expected_shape) - # 
self.assertTupleEqual(data1_precached["extra"].shape, expected_shape) - # self.assertTupleEqual(data2_precached["image"].shape, expected_shape) - # self.assertTupleEqual(data2_precached["label"].shape, expected_shape) - # self.assertTupleEqual(data2_precached["extra"].shape, expected_shape) - - # self.assertTupleEqual(data1_postcached["image"].shape, expected_shape) - # self.assertTupleEqual(data1_postcached["label"].shape, expected_shape) - # self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape) - # self.assertTupleEqual(data2_postcached["image"].shape, expected_shape) - # self.assertTupleEqual(data2_postcached["label"].shape, expected_shape) - # self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape) - # for d in data3_postcached: - # self.assertTupleEqual(d["image"].shape, expected_shape) - - # # # update the data to cache - # # test_data_new = [ - # # { - # # "image": os.path.join(tempdir, "test_image1_new.nii.gz"), - # # "label": os.path.join(tempdir, "test_label1_new.nii.gz"), - # # "extra": os.path.join(tempdir, "test_extra1_new.nii.gz"), - # # }, - # # { - # # "image": os.path.join(tempdir, "test_image2_new.nii.gz"), - # # "label": os.path.join(tempdir, "test_label2_new.nii.gz"), - # # "extra": os.path.join(tempdir, "test_extra2_new.nii.gz"), - # # }, - # # ] - # # dataset_postcached.set_data(data=test_data_new) - # # # test new exchanged cache content - # # if transform is None: - # # self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz")) - # # self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz")) - # # self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz")) - - # def test_different_transforms(self): - # """ - # Different instances of `GDSDataset` with the same cache_dir, - # same input data, but different transforms should give different results. 
- # """ - # shape = (1, 10, 9, 8) - # im = np.arange(0, np.prod(shape)).reshape(shape) - # with tempfile.TemporaryDirectory() as path: - # im1 = GDSDataset([im], Identity(), cache_dir=path, hash_transform=json_hashing, device=0)[0] - # im2 = GDSDataset([im], Flip(1), cache_dir=path, hash_transform=json_hashing, device=0)[0] - # l2 = ((im1 - im2) ** 2).sum() ** 0.5 - # self.assertTrue(l2 > 1) + def test_metatensor(self): + shape = (1, 10, 9, 8) + items = [TEST_NDARRAYS[-1](np.arange(0, np.prod(shape)).reshape(shape))] + with tempfile.TemporaryDirectory() as tempdir: + ds = GDSDataset( + data=items, + transform=_InplaceXform(), + cache_dir=tempdir, + device=0, + pickle_module="pickle", + pickle_protocol=pickle.HIGHEST_PROTOCOL, + ) + assert_allclose(ds[0], ds[0][0], type_test=False) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, transform, expected_shape): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) + with tempfile.TemporaryDirectory() as tempdir: + nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) + test_data = [ + { + "image": os.path.join(tempdir, "test_image1.nii.gz"), + "label": os.path.join(tempdir, "test_label1.nii.gz"), + "extra": os.path.join(tempdir, "test_extra1.nii.gz"), + }, + { + "image": os.path.join(tempdir, "test_image2.nii.gz"), + "label": os.path.join(tempdir, "test_label2.nii.gz"), + "extra": os.path.join(tempdir, "test_extra2.nii.gz"), + }, + ] + + cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") + dataset_precached = GDSDataset(data=test_data, transform=transform, cache_dir=cache_dir, device=0) + data1_precached = dataset_precached[0] + data2_precached = dataset_precached[1] + + dataset_postcached = GDSDataset(data=test_data, transform=transform, cache_dir=cache_dir, device=0) + data1_postcached = dataset_postcached[0] + data2_postcached = dataset_postcached[1] + data3_postcached = dataset_postcached[0:2] + + if transform is None: + self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz")) + self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz")) + self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz")) + self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz")) + else: + self.assertTupleEqual(data1_precached["image"].shape, expected_shape) + self.assertTupleEqual(data1_precached["label"].shape, expected_shape) + self.assertTupleEqual(data1_precached["extra"].shape, expected_shape) + self.assertTupleEqual(data2_precached["image"].shape, expected_shape) + self.assertTupleEqual(data2_precached["label"].shape, expected_shape) + self.assertTupleEqual(data2_precached["extra"].shape, expected_shape) + + self.assertTupleEqual(data1_postcached["image"].shape, expected_shape) + self.assertTupleEqual(data1_postcached["label"].shape, expected_shape) + self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape) + self.assertTupleEqual(data2_postcached["image"].shape, expected_shape) + self.assertTupleEqual(data2_postcached["label"].shape, expected_shape) + 
self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape) + for d in data3_postcached: + self.assertTupleEqual(d["image"].shape, expected_shape) + + # update the data to cache + test_data_new = [ + { + "image": os.path.join(tempdir, "test_image1_new.nii.gz"), + "label": os.path.join(tempdir, "test_label1_new.nii.gz"), + "extra": os.path.join(tempdir, "test_extra1_new.nii.gz"), + }, + { + "image": os.path.join(tempdir, "test_image2_new.nii.gz"), + "label": os.path.join(tempdir, "test_label2_new.nii.gz"), + "extra": os.path.join(tempdir, "test_extra2_new.nii.gz"), + }, + ] + dataset_postcached.set_data(data=test_data_new) + # test new exchanged cache content + if transform is None: + self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz")) + self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz")) + self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz")) + + def test_different_transforms(self): + """ + Different instances of `GDSDataset` with the same cache_dir, + same input data, but different transforms should give different results. + """ + shape = (1, 10, 9, 8) + im = np.arange(0, np.prod(shape)).reshape(shape) + with tempfile.TemporaryDirectory() as path: + im1 = GDSDataset([im], Identity(), cache_dir=path, hash_transform=json_hashing, device=0)[0] + im2 = GDSDataset([im], Flip(1), cache_dir=path, hash_transform=json_hashing, device=0)[0] + l2 = ((im1 - im2) ** 2).sum() ** 0.5 + self.assertTrue(l2 > 1) if __name__ == "__main__": diff --git a/tests/test_persistentdataset.py b/tests/test_persistentdataset.py index 5939eeef87..1b8245e318 100644 --- a/tests/test_persistentdataset.py +++ b/tests/test_persistentdataset.py @@ -77,79 +77,80 @@ def test_cache(self): self.assertEqual(list(ds1), list(ds)) self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]) - # @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - # def test_shape(self, transform, expected_shape): - # test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) - # with tempfile.TemporaryDirectory() as tempdir: - # nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) - # nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) - # nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) - # nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) - # nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) - # nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) - # test_data = [ - # { - # "image": os.path.join(tempdir, "test_image1.nii.gz"), - # "label": os.path.join(tempdir, "test_label1.nii.gz"), - # "extra": os.path.join(tempdir, "test_extra1.nii.gz"), - # }, - # { - # "image": os.path.join(tempdir, "test_image2.nii.gz"), - # "label": os.path.join(tempdir, "test_label2.nii.gz"), - # "extra": os.path.join(tempdir, "test_extra2.nii.gz"), - # }, - # ] - - # cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") - # dataset_precached = PersistentDataset(data=test_data, transform=transform, cache_dir=cache_dir) - # data1_precached = dataset_precached[0] - # data2_precached = dataset_precached[1] - # dataset_postcached = PersistentDataset(data=test_data, transform=transform, cache_dir=cache_dir) - # data1_postcached = dataset_postcached[0] - # data2_postcached = dataset_postcached[1] - # data3_postcached = dataset_postcached[0:2] - - # if transform is 
None: - # self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz")) - # self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz")) - # self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz")) - # self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz")) - # else: - # self.assertTupleEqual(data1_precached["image"].shape, expected_shape) - # self.assertTupleEqual(data1_precached["label"].shape, expected_shape) - # self.assertTupleEqual(data1_precached["extra"].shape, expected_shape) - # self.assertTupleEqual(data2_precached["image"].shape, expected_shape) - # self.assertTupleEqual(data2_precached["label"].shape, expected_shape) - # self.assertTupleEqual(data2_precached["extra"].shape, expected_shape) - - # self.assertTupleEqual(data1_postcached["image"].shape, expected_shape) - # self.assertTupleEqual(data1_postcached["label"].shape, expected_shape) - # self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape) - # self.assertTupleEqual(data2_postcached["image"].shape, expected_shape) - # self.assertTupleEqual(data2_postcached["label"].shape, expected_shape) - # self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape) - # for d in data3_postcached: - # self.assertTupleEqual(d["image"].shape, expected_shape) - - # # update the data to cache - # test_data_new = [ - # { - # "image": os.path.join(tempdir, "test_image1_new.nii.gz"), - # "label": os.path.join(tempdir, "test_label1_new.nii.gz"), - # "extra": os.path.join(tempdir, "test_extra1_new.nii.gz"), - # }, - # { - # "image": os.path.join(tempdir, "test_image2_new.nii.gz"), - # "label": os.path.join(tempdir, "test_label2_new.nii.gz"), - # "extra": os.path.join(tempdir, "test_extra2_new.nii.gz"), - # }, - # ] - # dataset_postcached.set_data(data=test_data_new) - # # test new exchanged cache content - # if transform is None: - # self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz")) - # self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz")) - # self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz")) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, transform, expected_shape): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) + with tempfile.TemporaryDirectory() as tempdir: + nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) + test_data = [ + { + "image": os.path.join(tempdir, "test_image1.nii.gz"), + "label": os.path.join(tempdir, "test_label1.nii.gz"), + "extra": os.path.join(tempdir, "test_extra1.nii.gz"), + }, + { + "image": os.path.join(tempdir, "test_image2.nii.gz"), + "label": os.path.join(tempdir, "test_label2.nii.gz"), + "extra": os.path.join(tempdir, "test_extra2.nii.gz"), + }, + ] + + cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") + dataset_precached = PersistentDataset(data=test_data, transform=transform, cache_dir=cache_dir) + data1_precached = dataset_precached[0] + data2_precached = dataset_precached[1] 
+ + dataset_postcached = PersistentDataset(data=test_data, transform=transform, cache_dir=cache_dir) + data1_postcached = dataset_postcached[0] + data2_postcached = dataset_postcached[1] + data3_postcached = dataset_postcached[0:2] + + if transform is None: + self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz")) + self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz")) + self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz")) + self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz")) + else: + self.assertTupleEqual(data1_precached["image"].shape, expected_shape) + self.assertTupleEqual(data1_precached["label"].shape, expected_shape) + self.assertTupleEqual(data1_precached["extra"].shape, expected_shape) + self.assertTupleEqual(data2_precached["image"].shape, expected_shape) + self.assertTupleEqual(data2_precached["label"].shape, expected_shape) + self.assertTupleEqual(data2_precached["extra"].shape, expected_shape) + + self.assertTupleEqual(data1_postcached["image"].shape, expected_shape) + self.assertTupleEqual(data1_postcached["label"].shape, expected_shape) + self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape) + self.assertTupleEqual(data2_postcached["image"].shape, expected_shape) + self.assertTupleEqual(data2_postcached["label"].shape, expected_shape) + self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape) + for d in data3_postcached: + self.assertTupleEqual(d["image"].shape, expected_shape) + + # update the data to cache + test_data_new = [ + { + "image": os.path.join(tempdir, "test_image1_new.nii.gz"), + "label": os.path.join(tempdir, "test_label1_new.nii.gz"), + "extra": os.path.join(tempdir, "test_extra1_new.nii.gz"), + }, + { + "image": os.path.join(tempdir, "test_image2_new.nii.gz"), + "label": os.path.join(tempdir, "test_label2_new.nii.gz"), + "extra": os.path.join(tempdir, "test_extra2_new.nii.gz"), + }, + ] + dataset_postcached.set_data(data=test_data_new) + # test new exchanged cache content + if transform is None: + self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz")) + self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz")) + self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz")) def test_different_transforms(self): """ From e6837f96e4f8e9d858a220f57f43aad872628fa9 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 26 Jul 2023 15:49:03 +0800 Subject: [PATCH 04/13] minor fix Signed-off-by: KumoLiu --- monai/data/dataset.py | 1 - tests/test_gdsdataset.py | 6 +++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 962de97fa5..27b5930443 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -47,7 +47,6 @@ ) from monai.utils import MAX_SEED, get_seed, look_up_option, min_version, optional_import, convert_to_tensor from monai.utils.misc import first -from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor if TYPE_CHECKING: from tqdm import tqdm diff --git a/tests/test_gdsdataset.py b/tests/test_gdsdataset.py index c643b238e3..b9827698f3 100644 --- a/tests/test_gdsdataset.py +++ b/tests/test_gdsdataset.py @@ -20,10 +20,13 @@ import numpy as np from parameterized import parameterized -from monai.data import GDSDataset, json_hashing, PersistentDataset +from monai.data import GDSDataset, json_hashing 
+from monai.utils import optional_import from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform from tests.utils import TEST_NDARRAYS, assert_allclose +_, has_kvikio_numpy = optional_import("kvikio.numpy") + TEST_CASE_1 = [ Compose( [ @@ -51,6 +54,7 @@ def __call__(self, data): return data +@unittest.skipUnless(has_kvikio_numpy, "Requires scikit-image library.") class TestDataset(unittest.TestCase): def test_cache(self): """testing no inplace change to the hashed item""" From 203f4f8eec033c7561760d652dc2d42c080e4d34 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 26 Jul 2023 16:08:05 +0800 Subject: [PATCH 05/13] add docstring Signed-off-by: KumoLiu --- monai/data/dataset.py | 64 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 4 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 27b5930443..e4794cfd17 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -1517,6 +1517,13 @@ def __init__( class GDSDataset(PersistentDataset): + """ + Re-implementation of the PersistentDataset. GDSDataset enables a direct direct memory access(DMA) data path between + GPU memory and storage, thus avoiding a bounce buffer through the CPU. This direct path can increase system + bandwidth while decreasing latency and utilization load on the CPU and GPU. + + A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/main/modules/GDS_dataset.ipynb. + """ def __init__( self, data: Sequence, @@ -1528,6 +1535,40 @@ def __init__( reset_ops_id: bool = True, **kwargs: Any, ) -> None: + """ + Args: + data: input data file paths to load and transform to generate dataset for model. + `GDSDataset` expects input data to be a list of serializable + and hashes them as cache keys using `hash_func`. + transform: transforms to execute operations on input data. + cache_dir: If specified, this is the location for gpu direct storage + of pre-computed transformed data tensors. The cache_dir is computed once, and + persists on disk until explicitly removed. Different runs, programs, experiments + may share a common cache dir provided that the transforms pre-processing is consistent. + If `cache_dir` doesn't exist, will automatically create it. + If `cache_dir` is `None`, there is effectively no caching. + device: target device to put the output Tensor data. Note that only int can be used to + specify the gpu to be used. + hash_func: a callable to compute hash from data items to be cached. + defaults to `monai.data.utils.pickle_hashing`. + pickle_module: string representing the module used for pickling metadata and objects, + default to `"pickle"`. due to the pickle limitation in multi-processing of Dataloader, + we can't use `pickle` as arg directly, so here we use a string name instead. + if want to use other pickle module at runtime, just register like: + >>> from monai.data import utils + >>> utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, + and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. + hash_transform: a callable to compute hash from the transform information when caching. + This may reduce errors due to transforms changing during experiments. Default to None (no hash). + Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`. + reset_ops_id: whether to set `TraceKeys.ID` to ``Tracekys.NONE``, defaults to ``True``. 
+ When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors. + This is useful for skipping the transform instance checks when inverting applied operations + using the cached content and with re-created transform instances. + + """ super().__init__( data=data, transform=transform, @@ -1540,7 +1581,22 @@ def __init__( self.device = device def _cachecheck(self, item_transformed): - """given the input dictionary ``item_transformed``, return a transformed version of it""" + """ + In order to enable direct storage to the GPU when loading the hashfile, rewritten this function. + Note that in this function, it will always return `torch.Tensor` when load data from cache. + + Args: + item_transformed: The current data element to be mutated into transformed representation + + Returns: + The transformed data_element, either from cache, or explicitly computing it. + + Warning: + The current implementation does not encode transform information as part of the + hashing mechanism used for generating cache names when `hash_transform` is None. + If the transforms applied are changed in any way, the objects in the cache dir will be invalid. + + """ hashfile = None # compute a cache id if self.cache_dir is not None: @@ -1579,7 +1635,7 @@ def _cachecheck(self, item_transformed): _item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed if hashfile is None: return _item_transformed - if isinstance(_item_transformed, dict): # {"image": ,"label": } + if isinstance(_item_transformed, dict): for k in _item_transformed: data_hashfile = f"{hashfile}-{k}" meta_hash_file_name = f"{hashfile.name}-{k}-meta" @@ -1587,11 +1643,11 @@ def _cachecheck(self, item_transformed): self._create_new_cache(_item_transformed[k], data_hashfile, meta_hash_file_name) else: return _item_transformed - elif isinstance(_item_transformed, (np.ndarray, torch.Tensor)): # [array, {}] + elif isinstance(_item_transformed, (np.ndarray, torch.Tensor)): data_hashfile = f"{hashfile}" meta_hash_file_name = f"{hashfile.name}-meta" self._create_new_cache(_item_transformed, data_hashfile, meta_hash_file_name) - else: # [{"image": metatensor,"label": metatensor}, {"image": ,"label": }, "image_meta_dict"], [metatensor, metatensor, ] + else: for i, _item in enumerate(_item_transformed): for k in _item: data_hashfile = f"{hashfile}-{k}-{i}" From fa50f5c7d0cf0c2a4e367f384dbcd20e8137a84d Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 26 Jul 2023 16:10:03 +0800 Subject: [PATCH 06/13] fix flake8 Signed-off-by: KumoLiu --- monai/data/__init__.py | 2 +- monai/data/dataset.py | 5 +++-- tests/test_gdsdataset.py | 10 +++++++--- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index a98cbdafe5..340c5eb8fa 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -33,12 +33,12 @@ CSVDataset, Dataset, DatasetFunc, + GDSDataset, LMDBDataset, NPZDictItemDataset, PersistentDataset, SmartCacheDataset, ZipDataset, - GDSDataset, ) from .dataset_summary import DatasetSummary from .decathlon_datalist import ( diff --git a/monai/data/dataset.py b/monai/data/dataset.py index e4794cfd17..0a1ef06dc2 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -45,7 +45,7 @@ convert_to_contiguous, reset_ops_id, ) -from monai.utils import MAX_SEED, get_seed, look_up_option, min_version, optional_import, convert_to_tensor +from monai.utils import MAX_SEED, convert_to_tensor, get_seed, look_up_option, min_version, optional_import 
from monai.utils.misc import first if TYPE_CHECKING: @@ -378,7 +378,7 @@ def _cachecheck(self, item_transformed): hashfile = None if self.cache_dir is not None: if isinstance(item_transformed, np.ndarray): - print('*** Attention ****', item_transformed.dtype, item_transformed.shape) + print("*** Attention ****", item_transformed.dtype, item_transformed.shape) data_item_md5 = self.hash_func(item_transformed).decode("utf-8") data_item_md5 += self.transform_hash hashfile = self.cache_dir / f"{data_item_md5}.pt" @@ -1524,6 +1524,7 @@ class GDSDataset(PersistentDataset): A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/main/modules/GDS_dataset.ipynb. """ + def __init__( self, data: Sequence, diff --git a/tests/test_gdsdataset.py b/tests/test_gdsdataset.py index b9827698f3..f1c6a97453 100644 --- a/tests/test_gdsdataset.py +++ b/tests/test_gdsdataset.py @@ -21,8 +21,8 @@ from parameterized import parameterized from monai.data import GDSDataset, json_hashing -from monai.utils import optional_import from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform +from monai.utils import optional_import from tests.utils import TEST_NDARRAYS, assert_allclose _, has_kvikio_numpy = optional_import("kvikio.numpy") @@ -76,9 +76,13 @@ def test_cache(self): assert_allclose(ds[0], ds1[0], type_test=False) assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) - ds = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, hash_transform=json_hashing, device=0) + ds = GDSDataset( + items, transform=_InplaceXform(), cache_dir=tempdir, hash_transform=json_hashing, device=0 + ) assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) - ds1 = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, hash_transform=json_hashing, device=0) + ds1 = GDSDataset( + items, transform=_InplaceXform(), cache_dir=tempdir, hash_transform=json_hashing, device=0 + ) assert_allclose(ds[0], ds1[0], type_test=False) assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) From 96758c0ddee309aab450ffab4e2728fe64eb52ef Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 26 Jul 2023 16:27:07 +0800 Subject: [PATCH 07/13] fix mypy Signed-off-by: KumoLiu --- monai/data/dataset.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 0a1ef06dc2..b48ebbefa8 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -377,8 +377,6 @@ def _cachecheck(self, item_transformed): """ hashfile = None if self.cache_dir is not None: - if isinstance(item_transformed, np.ndarray): - print("*** Attention ****", item_transformed.dtype, item_transformed.shape) data_item_md5 = self.hash_func(item_transformed).decode("utf-8") data_item_md5 += self.transform_hash hashfile = self.cache_dir / f"{data_item_md5}.pt" @@ -1608,28 +1606,28 @@ def _cachecheck(self, item_transformed): if hashfile is not None and hashfile.is_file(): # cache hit with cp.cuda.Device(self.device): if isinstance(item_transformed, dict): - item = {} + item: dict[Any, Any] = {} # type:ignore for k in item_transformed: - meta_k = torch.load(self.cache_dir / f"{hashfile.name}-{k}-meta") + meta_k = torch.load(self.cache_dir / f"{hashfile.name}-{k}-meta") # type:ignore item[k] = kvikio_numpy.fromfile(f"{hashfile}-{k}", dtype=meta_k["dtype"], like=cp.empty(())) item[k] = convert_to_tensor(item[k].reshape(meta_k["shape"]), device=f"cuda:{self.device}") item[f"{k}_meta_dict"] = meta_k return item 
elif isinstance(item_transformed, (np.ndarray, torch.Tensor)): - _meta = torch.load(self.cache_dir / f"{hashfile.name}-meta") + _meta = torch.load(self.cache_dir / f"{hashfile.name}-meta") # type:ignore _data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta.pop("dtype"), like=cp.empty(())) _data = convert_to_tensor(_data.reshape(_meta.pop("shape")), device=f"cuda:{self.device}") if bool(_meta): return (_data, _meta) return _data else: - item = [] + item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type:ignore for i, _item in enumerate(item_transformed): for k in _item: - meta_i_k = torch.load(self.cache_dir / f"{hashfile.name}-{k}-meta-{i}") + meta_i_k = torch.load(self.cache_dir / f"{hashfile.name}-{k}-meta-{i}") # type:ignore item_k = kvikio_numpy.fromfile(f"{hashfile}-{k}-{i}", dtype=np.float32, like=cp.empty(())) item_k = convert_to_tensor(item[i].reshape(meta_i_k["shape"]), device=f"cuda:{self.device}") - item[i] = {f"{k}": item_k, f"{k}_meta_dict": meta_k} + item[i].update({k: item_k, f"{k}_meta_dict": meta_i_k}) return item # create new cache From 09614533ee76243319f1166598f5dd65d65ef6ce Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 26 Jul 2023 16:28:16 +0800 Subject: [PATCH 08/13] update unittest Signed-off-by: KumoLiu --- tests/test_gdsdataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_gdsdataset.py b/tests/test_gdsdataset.py index f1c6a97453..9c63dfed44 100644 --- a/tests/test_gdsdataset.py +++ b/tests/test_gdsdataset.py @@ -25,6 +25,8 @@ from monai.utils import optional_import from tests.utils import TEST_NDARRAYS, assert_allclose + +_, has_cp = optional_import("cupy") _, has_kvikio_numpy = optional_import("kvikio.numpy") TEST_CASE_1 = [ @@ -54,6 +56,7 @@ def __call__(self, data): return data +@unittest.skipUnless(has_cp, "Requires cupy library.") @unittest.skipUnless(has_kvikio_numpy, "Requires scikit-image library.") class TestDataset(unittest.TestCase): def test_cache(self): From 2209709ddc5815e5d7be513213e1366126e26daa Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 26 Jul 2023 16:37:32 +0800 Subject: [PATCH 09/13] fix flake8 Signed-off-by: KumoLiu --- tests/test_gdsdataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_gdsdataset.py b/tests/test_gdsdataset.py index 9c63dfed44..b8ed71f20b 100644 --- a/tests/test_gdsdataset.py +++ b/tests/test_gdsdataset.py @@ -25,7 +25,6 @@ from monai.utils import optional_import from tests.utils import TEST_NDARRAYS, assert_allclose - _, has_cp = optional_import("cupy") _, has_kvikio_numpy = optional_import("kvikio.numpy") From 98a62c6c972bdb75f128aa7741973a43a8742038 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 26 Jul 2023 16:53:31 +0800 Subject: [PATCH 10/13] add skip test Signed-off-by: KumoLiu --- tests/test_gdsdataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_gdsdataset.py b/tests/test_gdsdataset.py index b8ed71f20b..2971b34fe7 100644 --- a/tests/test_gdsdataset.py +++ b/tests/test_gdsdataset.py @@ -16,7 +16,6 @@ import tempfile import unittest -import nibabel as nib import numpy as np from parameterized import parameterized @@ -26,6 +25,7 @@ from tests.utils import TEST_NDARRAYS, assert_allclose _, has_cp = optional_import("cupy") +nib, has_nib = optional_import("nibabel") _, has_kvikio_numpy = optional_import("kvikio.numpy") TEST_CASE_1 = [ @@ -55,7 +55,8 @@ def __call__(self, data): return data -@unittest.skipUnless(has_cp, "Requires cupy library.") +@unittest.skipUnless(has_cp, "Requires CuPy library.") 
+@unittest.skipUnless(has_nib, "Requires nibabel package.") @unittest.skipUnless(has_kvikio_numpy, "Requires scikit-image library.") class TestDataset(unittest.TestCase): def test_cache(self): From d0058c8de8c83d97c365c72ff804aeb0d9ee77cf Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 27 Jul 2023 11:20:07 +0800 Subject: [PATCH 11/13] Update docstring Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index b48ebbefa8..6d9f4b7751 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -1516,7 +1516,7 @@ def __init__( class GDSDataset(PersistentDataset): """ - Re-implementation of the PersistentDataset. GDSDataset enables a direct direct memory access(DMA) data path between + An extension of the PersistentDataset using direct memory access(DMA) data path between GPU memory and storage, thus avoiding a bounce buffer through the CPU. This direct path can increase system bandwidth while decreasing latency and utilization load on the CPU and GPU. From 304efee0ecca24e96df9a8e0d29845c691457c34 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 27 Jul 2023 11:38:42 +0800 Subject: [PATCH 12/13] update docstring Signed-off-by: KumoLiu --- monai/data/dataset.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 6d9f4b7751..acf9f476ab 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -1550,15 +1550,6 @@ def __init__( specify the gpu to be used. hash_func: a callable to compute hash from data items to be cached. defaults to `monai.data.utils.pickle_hashing`. - pickle_module: string representing the module used for pickling metadata and objects, - default to `"pickle"`. due to the pickle limitation in multi-processing of Dataloader, - we can't use `pickle` as arg directly, so here we use a string name instead. - if want to use other pickle module at runtime, just register like: - >>> from monai.data import utils - >>> utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle - this arg is used by `torch.save`, for more details, please check: - https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, - and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. hash_transform: a callable to compute hash from the transform information when caching. This may reduce errors due to transforms changing during experiments. Default to None (no hash). Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`. 
From d3c240e33f66b61d1c63f4d79ded43b80df1753e Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 27 Jul 2023 11:58:36 +0800 Subject: [PATCH 13/13] cache meta in `self._meta_cache` Signed-off-by: KumoLiu --- monai/data/dataset.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index acf9f476ab..6aebe47ed7 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -1569,6 +1569,7 @@ def __init__( **kwargs, ) self.device = device + self._meta_cache: dict[Any, dict[Any, Any]] = {} def _cachecheck(self, item_transformed): """ @@ -1599,13 +1600,13 @@ def _cachecheck(self, item_transformed): if isinstance(item_transformed, dict): item: dict[Any, Any] = {} # type:ignore for k in item_transformed: - meta_k = torch.load(self.cache_dir / f"{hashfile.name}-{k}-meta") # type:ignore + meta_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta") item[k] = kvikio_numpy.fromfile(f"{hashfile}-{k}", dtype=meta_k["dtype"], like=cp.empty(())) item[k] = convert_to_tensor(item[k].reshape(meta_k["shape"]), device=f"cuda:{self.device}") item[f"{k}_meta_dict"] = meta_k return item elif isinstance(item_transformed, (np.ndarray, torch.Tensor)): - _meta = torch.load(self.cache_dir / f"{hashfile.name}-meta") # type:ignore + _meta = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-meta") _data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta.pop("dtype"), like=cp.empty(())) _data = convert_to_tensor(_data.reshape(_meta.pop("shape")), device=f"cuda:{self.device}") if bool(_meta): @@ -1615,7 +1616,7 @@ def _cachecheck(self, item_transformed): item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type:ignore for i, _item in enumerate(item_transformed): for k in _item: - meta_i_k = torch.load(self.cache_dir / f"{hashfile.name}-{k}-meta-{i}") # type:ignore + meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}") item_k = kvikio_numpy.fromfile(f"{hashfile}-{k}-{i}", dtype=np.float32, like=cp.empty(())) item_k = convert_to_tensor(item[i].reshape(meta_i_k["shape"]), device=f"cuda:{self.device}") item[i].update({k: item_k, f"{k}_meta_dict": meta_i_k}) @@ -1647,12 +1648,12 @@ def _cachecheck(self, item_transformed): return _item_transformed def _create_new_cache(self, data, data_hashfile, meta_hash_file_name): - _item_transformed_meta = data.meta if isinstance(data, MetaTensor) else {} + self._meta_cache[meta_hash_file_name] = copy(data.meta) if isinstance(data, MetaTensor) else {} _item_transformed_data = data.array if isinstance(data, MetaTensor) else data if isinstance(_item_transformed_data, torch.Tensor): _item_transformed_data = _item_transformed_data.numpy() - _item_transformed_meta["shape"] = _item_transformed_data.shape - _item_transformed_meta["dtype"] = _item_transformed_data.dtype + self._meta_cache[meta_hash_file_name]["shape"] = _item_transformed_data.shape + self._meta_cache[meta_hash_file_name]["dtype"] = _item_transformed_data.dtype kvikio_numpy.tofile(_item_transformed_data, data_hashfile) try: # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation @@ -1662,7 +1663,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name): meta_hash_file = self.cache_dir / meta_hash_file_name temp_hash_file = Path(tmpdirname) / meta_hash_file_name torch.save( - obj=_item_transformed_meta, + obj=self._meta_cache[meta_hash_file_name], f=temp_hash_file, pickle_module=look_up_option(self.pickle_module, 
SUPPORTED_PICKLE_MOD), pickle_protocol=self.pickle_protocol, @@ -1677,3 +1678,9 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name): pass except PermissionError: # project-monai/monai issue #3613 pass + + def _load_meta_cache(self, meta_hash_file_name): + if meta_hash_file_name in self._meta_cache: + return self._meta_cache[meta_hash_file_name] + else: + return torch.load(self.cache_dir / meta_hash_file_name) # type:ignore
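To illustrate the on-disk layout this series ends up with: for a dictionary sample, each key produces one raw array file written with kvikio plus one small torch-serialized meta file holding the MetaTensor metadata together with the saved shape and dtype, and the `.pt` hash file itself is only an empty cache-id marker. The sketch below (not part of the patch) reads one cached key back by hand using the same calls as the cache-hit branch of `_cachecheck`; the cache directory and hash file name are hypothetical and would come from `hash_func`/`hash_transform` in practice:

    from pathlib import Path

    import torch

    from monai.utils import convert_to_tensor, optional_import

    cp, _ = optional_import("cupy")
    kvikio_numpy, _ = optional_import("kvikio.numpy")

    cache_dir = Path("./cache")                   # hypothetical cache_dir passed to GDSDataset
    hashfile = cache_dir / "d41d8cd98f00b204.pt"  # hypothetical "<data-hash><transform-hash>.pt" marker
    key = "image"

    # metadata written by _create_new_cache: original MetaTensor meta plus "shape" and "dtype"
    meta = torch.load(cache_dir / f"{hashfile.name}-{key}-meta")
    # raw array written by kvikio_numpy.tofile; read back directly into GPU memory
    arr = kvikio_numpy.fromfile(f"{hashfile}-{key}", dtype=meta["dtype"], like=cp.empty(()))
    image = convert_to_tensor(arr.reshape(meta["shape"]), device="cuda:0")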