2 changes: 1 addition & 1 deletion .github/workflows/docker.yml
@@ -89,7 +89,7 @@ jobs:
steps:
- name: Import
run: |
- export CUDA_VISIBLE_DEVICES= # cpu-only
+ export OMP_NUM_THREADS=4 MKL_NUM_THREADS=4 CUDA_VISIBLE_DEVICES= # cpu-only
python -c 'import monai; monai.config.print_debug_info()'
cd /opt/monai
ls -al
7 changes: 7 additions & 0 deletions docs/source/data.rst
@@ -45,6 +45,13 @@ Generic Interfaces
:members:
:special-members: __getitem__

`GDSDataset`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: GDSDataset
:members:
:special-members: __getitem__


`CacheNTransDataset`
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: CacheNTransDataset
15 changes: 10 additions & 5 deletions monai/data/dataset.py
@@ -1521,6 +1521,8 @@ class GDSDataset(PersistentDataset):
bandwidth while decreasing latency and utilization load on the CPU and GPU.

A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/main/modules/GDS_dataset.ipynb.

See also: https://github.com/rapidsai/kvikio
"""

def __init__(
@@ -1607,17 +1609,20 @@ def _cachecheck(self, item_transformed):
return item
elif isinstance(item_transformed, (np.ndarray, torch.Tensor)):
_meta = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-meta")
- _data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta.pop("dtype"), like=cp.empty(()))
- _data = convert_to_tensor(_data.reshape(_meta.pop("shape")), device=f"cuda:{self.device}")
- if bool(_meta):
+ _data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta["dtype"], like=cp.empty(()))
+ _data = convert_to_tensor(_data.reshape(_meta["shape"]), device=f"cuda:{self.device}")
+ filtered_keys = list(filter(lambda key: key not in ["dtype", "shape"], _meta.keys()))
+ if bool(filtered_keys):
return (_data, _meta)
return _data
else:
item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type:ignore
for i, _item in enumerate(item_transformed):
for k in _item:
meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}")
- item_k = kvikio_numpy.fromfile(f"{hashfile}-{k}-{i}", dtype=np.float32, like=cp.empty(()))
+ item_k = kvikio_numpy.fromfile(
+     f"{hashfile}-{k}-{i}", dtype=meta_i_k["dtype"], like=cp.empty(())
+ )
item_k = convert_to_tensor(item_k.reshape(meta_i_k["shape"]), device=f"cuda:{self.device}")
item[i].update({k: item_k, f"{k}_meta_dict": meta_i_k})
return item
@@ -1653,7 +1658,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
if isinstance(_item_transformed_data, torch.Tensor):
_item_transformed_data = _item_transformed_data.numpy()
self._meta_cache[meta_hash_file_name]["shape"] = _item_transformed_data.shape
- self._meta_cache[meta_hash_file_name]["dtype"] = _item_transformed_data.dtype
+ self._meta_cache[meta_hash_file_name]["dtype"] = str(_item_transformed_data.dtype)
kvikio_numpy.tofile(_item_transformed_data, data_hashfile)
try:
# NOTE: Writing to a temporary directory and then using a nearly atomic rename operation
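The substance of the dataset.py change: the cached "-meta" file now stores the array dtype as a string (via str(dtype)), and the read path rebuilds the dtype from that string instead of assuming np.float32. Below is a minimal round-trip sketch of the same idea, using plain numpy file I/O in place of kvikio; the array values and file path are illustrative and not part of the PR.

import os
import tempfile

import numpy as np

# Cache-write side (mirrors _create_new_cache): record the shape and a string dtype.
arr = np.arange(24, dtype=np.int16).reshape(2, 3, 4)
meta = {"shape": arr.shape, "dtype": str(arr.dtype)}  # {"shape": (2, 3, 4), "dtype": "int16"}
cache_file = os.path.join(tempfile.gettempdir(), "gds_meta_roundtrip.bin")  # illustrative path
arr.tofile(cache_file)  # stands in for kvikio_numpy.tofile(...)

# Cache-read side (mirrors _cachecheck): rebuild the dtype from its string form, then reshape.
restored = np.fromfile(cache_file, dtype=np.dtype(meta["dtype"])).reshape(meta["shape"])
assert restored.dtype == np.int16 and np.array_equal(restored, arr)

Note the read path also stops popping "dtype"/"shape" out of the meta dict; the full metadata survives and is returned alongside the data whenever extra keys are present.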
42 changes: 34 additions & 8 deletions tests/test_gdsdataset.py
@@ -17,6 +17,7 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.data import GDSDataset, json_hashing
@@ -48,6 +49,19 @@

TEST_CASE_3 = [None, (128, 128, 128)]

DTYPES = {
np.dtype(np.uint8): torch.uint8,
np.dtype(np.int8): torch.int8,
np.dtype(np.int16): torch.int16,
np.dtype(np.int32): torch.int32,
np.dtype(np.int64): torch.int64,
np.dtype(np.float16): torch.float16,
np.dtype(np.float32): torch.float32,
np.dtype(np.float64): torch.float64,
np.dtype(np.complex64): torch.complex64,
np.dtype(np.complex128): torch.complex128,
}


class _InplaceXform(Transform):
def __call__(self, data):
Expand Down Expand Up @@ -93,16 +107,28 @@ def test_metatensor(self):
shape = (1, 10, 9, 8)
items = [TEST_NDARRAYS[-1](np.arange(0, np.prod(shape)).reshape(shape))]
with tempfile.TemporaryDirectory() as tempdir:
- ds = GDSDataset(
-     data=items,
-     transform=_InplaceXform(),
-     cache_dir=tempdir,
-     device=0,
-     pickle_module="pickle",
-     pickle_protocol=pickle.HIGHEST_PROTOCOL,
- )
+ ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
assert_allclose(ds[0], ds[0][0], type_test=False)

def test_dtype(self):
shape = (1, 10, 9, 8)
data = np.arange(0, np.prod(shape)).reshape(shape)
for _dtype in DTYPES.keys():
items = [np.array(data).astype(_dtype)]
with tempfile.TemporaryDirectory() as tempdir:
ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
ds1 = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
self.assertEqual(ds[0].dtype, _dtype)
self.assertEqual(ds1[0].dtype, DTYPES[_dtype])

for _dtype in DTYPES.keys():
items = [torch.tensor(data, dtype=DTYPES[_dtype])]
with tempfile.TemporaryDirectory() as tempdir:
ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
ds1 = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
self.assertEqual(ds[0].dtype, DTYPES[_dtype])
self.assertEqual(ds1[0].dtype, DTYPES[_dtype])

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_shape(self, transform, expected_shape):
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
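For reference, a usage sketch matching how test_dtype exercises the class; this is a hedged example, not part of the PR, and it assumes kvikio, cupy and CUDA device 0 are available, with the transform, data and cache directory purely illustrative. As the assertions above suggest, the first access is a cache miss that returns the transformed numpy array with its original dtype, while a second dataset sharing the cache directory reads the cached copy back through GDS as a CUDA tensor with the mapped torch dtype.

import tempfile

import numpy as np
import torch

from monai.data import GDSDataset
from monai.transforms import Transform


class _Double(Transform):
    # Illustrative transform: doubles the values without changing dtype or shape.
    def __call__(self, data):
        return data * 2


data = [np.arange(720, dtype=np.float32).reshape(1, 10, 9, 8)]
with tempfile.TemporaryDirectory() as cache_dir:
    ds = GDSDataset(data=data, transform=_Double(), cache_dir=cache_dir, device=0)
    first = ds[0]    # cache miss: numpy array, dtype float32
    ds1 = GDSDataset(data=data, transform=_Double(), cache_dir=cache_dir, device=0)
    second = ds1[0]  # cache hit: torch.float32 tensor on cuda:0
    assert first.dtype == np.float32 and second.dtype == torch.float32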