diff --git a/monai/data/__init__.py b/monai/data/__init__.py
index 9339897d7a..340c5eb8fa 100644
--- a/monai/data/__init__.py
+++ b/monai/data/__init__.py
@@ -33,6 +33,7 @@
     CSVDataset,
     Dataset,
     DatasetFunc,
+    GDSDataset,
     LMDBDataset,
     NPZDictItemDataset,
     PersistentDataset,
diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index 912576bdcc..6aebe47ed7 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -34,6 +34,7 @@
 from torch.utils.data import Dataset as _TorchDataset
 from torch.utils.data import Subset
 
+from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing
 from monai.transforms import (
     Compose,
@@ -44,7 +45,7 @@
     convert_to_contiguous,
     reset_ops_id,
 )
-from monai.utils import MAX_SEED, get_seed, look_up_option, min_version, optional_import
+from monai.utils import MAX_SEED, convert_to_tensor, get_seed, look_up_option, min_version, optional_import
 from monai.utils.misc import first
 
 if TYPE_CHECKING:
@@ -54,8 +55,10 @@
 else:
     tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")
 
+cp, _ = optional_import("cupy")
 lmdb, _ = optional_import("lmdb")
 pd, _ = optional_import("pandas")
+kvikio_numpy, _ = optional_import("kvikio.numpy")
 
 
 class Dataset(_TorchDataset):
@@ -326,7 +329,6 @@ def _pre_transform(self, item_transformed):
         first_random = self.transform.get_index_of_first(
             lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
         )
-
         item_transformed = self.transform(item_transformed, end=first_random, threading=True)
 
         if self.reset_ops_id:
@@ -1510,3 +1512,175 @@ def __init__(
             dfs=dfs, row_indices=row_indices, col_names=col_names, col_types=col_types, col_groups=col_groups, **kwargs
         )
         super().__init__(data=data, transform=transform)
+
+
+class GDSDataset(PersistentDataset):
+    """
+    An extension of the PersistentDataset using a direct memory access (DMA) data path between
+    GPU memory and storage, thus avoiding a bounce buffer through the CPU. This direct path can increase system
+    bandwidth while decreasing latency and utilization load on the CPU and GPU.
+
+    A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/main/modules/GDS_dataset.ipynb.
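+
+    A minimal usage sketch (illustrative only; ``train_data`` and ``train_transforms`` are
+    placeholders for the user's own data list and preprocessing transforms)::
+
+        gds_ds = GDSDataset(
+            data=train_data,
+            transform=train_transforms,
+            cache_dir="./gds_cache_dir",
+            device=0,  # index of the GPU that receives the cached tensors
+        )
+        first_item = gds_ds[0]  # transformed and cached on first access, read back via GDS afterwards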
+    """
+
+    def __init__(
+        self,
+        data: Sequence,
+        transform: Sequence[Callable] | Callable,
+        cache_dir: Path | str | None,
+        device: int,
+        hash_func: Callable[..., bytes] = pickle_hashing,
+        hash_transform: Callable[..., bytes] | None = None,
+        reset_ops_id: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Args:
+            data: input data file paths to load and transform to generate dataset for model.
+                `GDSDataset` expects input data to be a list of serializable items,
+                which are hashed as cache keys using `hash_func`.
+            transform: transforms to execute operations on input data.
+            cache_dir: If specified, this is the location for gpu direct storage
+                of pre-computed transformed data tensors. The cache_dir is computed once, and
+                persists on disk until explicitly removed. Different runs, programs, experiments
+                may share a common cache dir provided that the transforms pre-processing is consistent.
+                If `cache_dir` doesn't exist, it will be created automatically.
+                If `cache_dir` is `None`, there is effectively no caching.
+            device: target device to put the output Tensor data on. Note that only an int (the GPU index)
+                can be used to specify the device.
+            hash_func: a callable to compute hash from data items to be cached.
+                defaults to `monai.data.utils.pickle_hashing`.
+            hash_transform: a callable to compute hash from the transform information when caching.
+                This may reduce errors due to transforms changing during experiments. Defaults to None (no hash).
+                Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
+            reset_ops_id: whether to set `TraceKeys.ID` to ``TraceKeys.NONE``, defaults to ``True``.
+                When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
+                This is useful for skipping the transform instance checks when inverting applied operations
+                using the cached content and with re-created transform instances.
+
+        """
+        super().__init__(
+            data=data,
+            transform=transform,
+            cache_dir=cache_dir,
+            hash_func=hash_func,
+            hash_transform=hash_transform,
+            reset_ops_id=reset_ops_id,
+            **kwargs,
+        )
+        self.device = device
+        self._meta_cache: dict[Any, dict[Any, Any]] = {}
+
+    def _cachecheck(self, item_transformed):
+        """
+        This method is overridden so that cached content is loaded from the hashfile directly into GPU memory.
+        Note that it always returns `torch.Tensor` when loading data from the cache.
+
+        Args:
+            item_transformed: The current data element to be mutated into transformed representation
+
+        Returns:
+            The transformed data_element, either from cache, or explicitly computing it.
+
+        Warning:
+            The current implementation does not encode transform information as part of the
+            hashing mechanism used for generating cache names when `hash_transform` is None.
+            If the transforms applied are changed in any way, the objects in the cache dir will be invalid.
+
+        """
+        hashfile = None
+        # compute a cache id
+        if self.cache_dir is not None:
+            data_item_md5 = self.hash_func(item_transformed).decode("utf-8")
+            data_item_md5 += self.transform_hash
+            hashfile = self.cache_dir / f"{data_item_md5}.pt"
+
+        if hashfile is not None and hashfile.is_file():  # cache hit
+            with cp.cuda.Device(self.device):
+                if isinstance(item_transformed, dict):
+                    item: dict[Any, Any] = {}  # type:ignore
+                    for k in item_transformed:
+                        # each key was cached as "<hashfile>-<key>" plus a "<hashfile name>-<key>-meta" metadata file
+                        meta_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta")
+                        item[k] = kvikio_numpy.fromfile(f"{hashfile}-{k}", dtype=meta_k["dtype"], like=cp.empty(()))
+                        item[k] = convert_to_tensor(item[k].reshape(meta_k["shape"]), device=f"cuda:{self.device}")
+                        item[f"{k}_meta_dict"] = meta_k
+                    return item
+                elif isinstance(item_transformed, (np.ndarray, torch.Tensor)):
+                    _meta = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-meta")
+                    _data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta.pop("dtype"), like=cp.empty(()))
+                    _data = convert_to_tensor(_data.reshape(_meta.pop("shape")), device=f"cuda:{self.device}")
+                    if bool(_meta):
+                        return (_data, _meta)
+                    return _data
+                else:
+                    item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))]  # type:ignore
+                    for i, _item in enumerate(item_transformed):
+                        for k in _item:
+                            meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}")
+                            item_k = kvikio_numpy.fromfile(f"{hashfile}-{k}-{i}", dtype=meta_i_k["dtype"], like=cp.empty(()))
+                            item_k = convert_to_tensor(item_k.reshape(meta_i_k["shape"]), device=f"cuda:{self.device}")
+                            item[i].update({k: item_k, f"{k}_meta_dict": meta_i_k})
+                    return item
+
+        # create new cache
+        _item_transformed = self._pre_transform(deepcopy(item_transformed))  # keep the original hashed
+        if hashfile is None:
+            return _item_transformed
+        if isinstance(_item_transformed, dict):
+            for k in _item_transformed:
+                data_hashfile = f"{hashfile}-{k}"
                meta_hash_file_name = f"{hashfile.name}-{k}-meta"
+                if isinstance(_item_transformed[k], (np.ndarray, torch.Tensor)):
+                    self._create_new_cache(_item_transformed[k], data_hashfile, meta_hash_file_name)
+                else:
+                    return _item_transformed
+        elif isinstance(_item_transformed, (np.ndarray, torch.Tensor)):
+            data_hashfile = f"{hashfile}"
+            meta_hash_file_name = f"{hashfile.name}-meta"
+            self._create_new_cache(_item_transformed, data_hashfile, meta_hash_file_name)
+        else:
+            for i, _item in enumerate(_item_transformed):
+                for k in _item:
+                    data_hashfile = f"{hashfile}-{k}-{i}"
+                    meta_hash_file_name = f"{hashfile.name}-{k}-meta-{i}"
+                    self._create_new_cache(_item, data_hashfile, meta_hash_file_name)
+        open(hashfile, "a").close()  # store cacheid
+        return _item_transformed
+
+    def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
+        # write the raw array to `data_hashfile` with kvikio (GPUDirect Storage) and keep its
+        # shape/dtype plus any MetaTensor meta dict in a separate metadata file
+        self._meta_cache[meta_hash_file_name] = copy(data.meta) if isinstance(data, MetaTensor) else {}
+        _item_transformed_data = data.array if isinstance(data, MetaTensor) else data
+        if isinstance(_item_transformed_data, torch.Tensor):
+            _item_transformed_data = _item_transformed_data.numpy()
+        self._meta_cache[meta_hash_file_name]["shape"] = _item_transformed_data.shape
+        self._meta_cache[meta_hash_file_name]["dtype"] = _item_transformed_data.dtype
+        kvikio_numpy.tofile(_item_transformed_data, data_hashfile)
+        try:
+            # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation
+            # to make the cache more robust to manual killing of parent process
+            # which may leave partially written cache files in an incomplete state
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                meta_hash_file = self.cache_dir / meta_hash_file_name
+                temp_hash_file = Path(tmpdirname) / meta_hash_file_name
+                torch.save(
+                    obj=self._meta_cache[meta_hash_file_name],
+                    f=temp_hash_file,
+                    pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
+                    pickle_protocol=self.pickle_protocol,
+                )
+                if temp_hash_file.is_file() and not meta_hash_file.is_file():
+                    # On Unix, if target exists and is a file, it will be replaced silently if the
+                    # user has permission.
+                    # for more details: https://docs.python.org/3/library/shutil.html#shutil.move.
+                    try:
+                        shutil.move(str(temp_hash_file), meta_hash_file)
+                    except FileExistsError:
+                        pass
+        except PermissionError:  # project-monai/monai issue #3613
+            pass
+
+    def _load_meta_cache(self, meta_hash_file_name):
+        if meta_hash_file_name in self._meta_cache:
+            return self._meta_cache[meta_hash_file_name]
+        else:
+            return torch.load(self.cache_dir / meta_hash_file_name)  # type:ignore
diff --git a/tests/test_gdsdataset.py b/tests/test_gdsdataset.py
new file mode 100644
index 0000000000..2971b34fe7
--- /dev/null
+++ b/tests/test_gdsdataset.py
@@ -0,0 +1,196 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+import pickle
+import tempfile
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.data import GDSDataset, json_hashing
+from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform
+from monai.utils import optional_import
+from tests.utils import TEST_NDARRAYS, assert_allclose
+
+_, has_cp = optional_import("cupy")
+nib, has_nib = optional_import("nibabel")
+_, has_kvikio_numpy = optional_import("kvikio.numpy")
+
+TEST_CASE_1 = [
+    Compose(
+        [
+            LoadImaged(keys=["image", "label", "extra"], image_only=True),
+            SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]),
+        ]
+    ),
+    (128, 128, 128),
+]
+
+TEST_CASE_2 = [
+    [
+        LoadImaged(keys=["image", "label", "extra"], image_only=True),
+        SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]),
+    ],
+    (128, 128, 128),
+]
+
+TEST_CASE_3 = [None, (128, 128, 128)]
+
+
+class _InplaceXform(Transform):
+    def __call__(self, data):
+        data[0] = data[0] + 1
+        return data
+
+
+@unittest.skipUnless(has_cp, "Requires CuPy library.")
+@unittest.skipUnless(has_nib, "Requires nibabel package.")
+@unittest.skipUnless(has_kvikio_numpy, "Requires kvikio library.")
+class TestDataset(unittest.TestCase):
+    def test_cache(self):
+        """testing no inplace change to the hashed item"""
+        for p in TEST_NDARRAYS[:2]:
+            shape = (1, 10, 9, 8)
+            items = [p(np.arange(0, np.prod(shape)).reshape(shape))]
+
+            with tempfile.TemporaryDirectory() as tempdir:
+                ds = GDSDataset(
+                    data=items,
+                    transform=_InplaceXform(),
+                    cache_dir=tempdir,
+                    device=0,
+                    pickle_module="pickle",
+                    pickle_protocol=pickle.HIGHEST_PROTOCOL,
+                )
+                assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape)))
+                ds1 = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
+                assert_allclose(ds[0], ds1[0], type_test=False)
+                assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape)))
+
+                ds = GDSDataset(
+                    items, transform=_InplaceXform(), cache_dir=tempdir, hash_transform=json_hashing, device=0
+                )
+                assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape)))
+                ds1 = GDSDataset(
+                    items, transform=_InplaceXform(), cache_dir=tempdir, hash_transform=json_hashing, device=0
+                )
+                assert_allclose(ds[0], ds1[0], type_test=False)
+                assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape)))
+
+    def test_metatensor(self):
+        shape = (1, 10, 9, 8)
+        items = [TEST_NDARRAYS[-1](np.arange(0, np.prod(shape)).reshape(shape))]
+        with tempfile.TemporaryDirectory() as tempdir:
+            ds = GDSDataset(
+                data=items,
+                transform=_InplaceXform(),
+                cache_dir=tempdir,
+                device=0,
+                pickle_module="pickle",
+                pickle_protocol=pickle.HIGHEST_PROTOCOL,
+            )
+            assert_allclose(ds[0], ds[0][0], type_test=False)
+
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
+    def test_shape(self, transform, expected_shape):
+        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
+        with tempfile.TemporaryDirectory() as tempdir:
+            nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz"))
+            nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz"))
+            nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz"))
+            nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz"))
+            nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz"))
+            nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz"))
+            test_data = [
+                {
+                    "image": os.path.join(tempdir, "test_image1.nii.gz"),
+                    "label": os.path.join(tempdir, "test_label1.nii.gz"),
+                    "extra": os.path.join(tempdir, "test_extra1.nii.gz"),
+                },
+                {
+                    "image": os.path.join(tempdir, "test_image2.nii.gz"),
+                    "label": os.path.join(tempdir, "test_label2.nii.gz"),
+                    "extra": os.path.join(tempdir, "test_extra2.nii.gz"),
+                },
+            ]
+
+            cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data")
+            dataset_precached = GDSDataset(data=test_data, transform=transform, cache_dir=cache_dir, device=0)
+            data1_precached = dataset_precached[0]
+            data2_precached = dataset_precached[1]
+
+            dataset_postcached = GDSDataset(data=test_data, transform=transform, cache_dir=cache_dir, device=0)
+            data1_postcached = dataset_postcached[0]
+            data2_postcached = dataset_postcached[1]
+            data3_postcached = dataset_postcached[0:2]
+
+            if transform is None:
+                self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz"))
+                self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz"))
+                self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz"))
+                self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz"))
+            else:
+                self.assertTupleEqual(data1_precached["image"].shape, expected_shape)
+                self.assertTupleEqual(data1_precached["label"].shape, expected_shape)
+                self.assertTupleEqual(data1_precached["extra"].shape, expected_shape)
+                self.assertTupleEqual(data2_precached["image"].shape, expected_shape)
+                self.assertTupleEqual(data2_precached["label"].shape, expected_shape)
+                self.assertTupleEqual(data2_precached["extra"].shape, expected_shape)
+
+                self.assertTupleEqual(data1_postcached["image"].shape, expected_shape)
+                self.assertTupleEqual(data1_postcached["label"].shape, expected_shape)
+                self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape)
+                self.assertTupleEqual(data2_postcached["image"].shape, expected_shape)
+                self.assertTupleEqual(data2_postcached["label"].shape, expected_shape)
+                self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape)
+                for d in data3_postcached:
+                    self.assertTupleEqual(d["image"].shape, expected_shape)
+
+            # update the data to cache
+            test_data_new = [
+                {
+                    "image": os.path.join(tempdir, "test_image1_new.nii.gz"),
+                    "label": os.path.join(tempdir, "test_label1_new.nii.gz"),
+                    "extra": os.path.join(tempdir, "test_extra1_new.nii.gz"),
+                },
+                {
+                    "image": os.path.join(tempdir, "test_image2_new.nii.gz"),
+                    "label": os.path.join(tempdir, "test_label2_new.nii.gz"),
+                    "extra": os.path.join(tempdir, "test_extra2_new.nii.gz"),
+                },
+            ]
+            dataset_postcached.set_data(data=test_data_new)
+            # test new exchanged cache content
+            if transform is None:
+                self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz"))
+                self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz"))
+                self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz"))
+
+    def test_different_transforms(self):
+        """
+        Different instances of `GDSDataset` with the same cache_dir,
+        same input data, but different transforms should give different results.
+        """
+        shape = (1, 10, 9, 8)
+        im = np.arange(0, np.prod(shape)).reshape(shape)
+        with tempfile.TemporaryDirectory() as path:
+            im1 = GDSDataset([im], Identity(), cache_dir=path, hash_transform=json_hashing, device=0)[0]
+            im2 = GDSDataset([im], Flip(1), cache_dir=path, hash_transform=json_hashing, device=0)[0]
+            l2 = ((im1 - im2) ** 2).sum() ** 0.5
+            self.assertTrue(l2 > 1)
+
+
+if __name__ == "__main__":
+    unittest.main()
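
A minimal usage sketch (illustrative only; `train_data` and `train_transforms` are placeholder names) of opting
into transform-aware cache naming via the `hash_transform` argument, which addresses the staleness warning in
`_cachecheck` above:

    from monai.data import GDSDataset, json_hashing

    gds_ds = GDSDataset(
        data=train_data,
        transform=train_transforms,
        cache_dir="./gds_cache_dir",
        device=0,
        hash_transform=json_hashing,  # transform info is folded into the cache file names
    )

With this option, changing `train_transforms` between runs produces new cache entries instead of silently reusing
stale ones, which is the behaviour exercised by `test_different_transforms`.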