diff --git a/modules/GDS_dataset.ipynb b/modules/GDS_dataset.ipynb
index 3116567eb6..01ab0cc958 100644
--- a/modules/GDS_dataset.ipynb
+++ b/modules/GDS_dataset.ipynb
@@ -39,10 +39,6 @@
 " \n",
 " (Replace X with the major version of the CUDA toolkit, and Y with the minor version.)\n",
 "\n",
- "- `GDSDataset` inherited from `PersistentDataset`.\n",
- "\n",
- " In this tutorial, we have implemented a `GDSDataset` that inherits from `PersistentDataset`. We have re-implemented the `_cachecheck` method to create and save cache using GDS.\n",
- "\n",
 "- A simple demo comparing the time taken with and without GDS.\n",
 "\n",
 " In this tutorial, we are creating a conda environment to install `kvikio`, which provides a Python API for GDS. To install `kvikio` using other methods, refer to https://github.com/rapidsai/kvikio#install.\n",
@@ -79,28 +75,21 @@
 },
 {
 "cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
 "import os\n",
 "import time\n",
- "import cupy\n",
 "import torch\n",
 "import shutil\n",
 "import tempfile\n",
- "import numpy as np\n",
- "from typing import Any\n",
- "from pathlib import Path\n",
- "from copy import deepcopy\n",
- "from collections.abc import Callable, Sequence\n",
- "from kvikio.numpy import fromfile, tofile\n",
 "\n",
 "import monai\n",
 "import monai.transforms as mt\n",
 "from monai.config import print_config\n",
- "from monai.data.utils import pickle_hashing, SUPPORTED_PICKLE_MOD\n",
- "from monai.utils import convert_to_tensor, set_determinism, look_up_option\n",
+ "from monai.data.dataset import GDSDataset\n",
+ "from monai.utils import set_determinism\n",
 "\n",
 "print_config()"
 ]
 },
@@ -135,100 +124,6 @@
 "print(root_dir)"
 ]
 },
-{
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## GDSDataset"
- ]
-},
-{
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "class GDSDataset(monai.data.PersistentDataset):\n",
- "    def __init__(\n",
- "        self,\n",
- "        data: Sequence,\n",
- "        transform: Sequence[Callable] | Callable,\n",
- "        cache_dir: Path | str | None,\n",
- "        hash_func: Callable[..., bytes] = pickle_hashing,\n",
- "        hash_transform: Callable[..., bytes] | None = None,\n",
- "        reset_ops_id: bool = True,\n",
- "        device: int = None,\n",
- "        **kwargs: Any,\n",
- "    ) -> None:\n",
- "        super().__init__(\n",
- "            data=data,\n",
- "            transform=transform,\n",
- "            cache_dir=cache_dir,\n",
- "            hash_func=hash_func,\n",
- "            hash_transform=hash_transform,\n",
- "            reset_ops_id=reset_ops_id,\n",
- "            **kwargs,\n",
- "        )\n",
- "        self.device = device\n",
- "\n",
- "    def _cachecheck(self, item_transformed):\n",
- "        \"\"\"given the input dictionary ``item_transformed``, return a transformed version of it\"\"\"\n",
- "        hashfile = None\n",
- "        # compute a cache id\n",
- "        if self.cache_dir is not None:\n",
- "            data_item_md5 = self.hash_func(item_transformed).decode(\"utf-8\")\n",
- "            data_item_md5 += self.transform_hash\n",
- "            hashfile = self.cache_dir / f\"{data_item_md5}.pt\"\n",
- "\n",
- "        if hashfile is not None and hashfile.is_file(): # cache hit\n",
- "            with cupy.cuda.Device(self.device):\n",
- "                item = {}\n",
- "                for k in item_transformed:\n",
- "                    meta_k = torch.load(self.cache_dir / f\"{hashfile.name}-{k}-meta\")\n",
- "                    item[k] = fromfile(f\"{hashfile}-{k}\", dtype=np.float32, like=cupy.empty(()))\n",
- "                    item[k] = convert_to_tensor(item[k].reshape(meta_k[\"shape\"]), device=f\"cuda:{self.device}\")\n",
- "                    item[f\"{k}_meta_dict\"] = meta_k\n",
- "                return item\n",
- "\n",
- "        # create new cache\n",
- "        _item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed\n",
- "        if hashfile is None:\n",
- "            return _item_transformed\n",
- "\n",
- "        for k in _item_transformed: # {'image': ..., 'label': ...}\n",
- "            _item_transformed_meta = _item_transformed[k].meta\n",
- "            _item_transformed_data = _item_transformed[k].array\n",
- "            _item_transformed_meta[\"shape\"] = _item_transformed_data.shape\n",
- "            tofile(_item_transformed_data, f\"{hashfile}-{k}\")\n",
- "            try:\n",
- "                # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation\n",
- "                # to make the cache more robust to manual killing of parent process\n",
- "                # which may leave partially written cache files in an incomplete state\n",
- "                with tempfile.TemporaryDirectory() as tmpdirname:\n",
- "                    meta_hash_file_name = f\"{hashfile.name}-{k}-meta\"\n",
- "                    meta_hash_file = self.cache_dir / meta_hash_file_name\n",
- "                    temp_hash_file = Path(tmpdirname) / meta_hash_file_name\n",
- "                    torch.save(\n",
- "                        obj=_item_transformed_meta,\n",
- "                        f=temp_hash_file,\n",
- "                        pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),\n",
- "                        pickle_protocol=self.pickle_protocol,\n",
- "                    )\n",
- "                    if temp_hash_file.is_file() and not meta_hash_file.is_file():\n",
- "                        # On Unix, if target exists and is a file, it will be replaced silently if the\n",
- "                        # user has permission.\n",
- "                        # for more details: https://docs.python.org/3/library/shutil.html#shutil.move.\n",
- "                        try:\n",
- "                            shutil.move(str(temp_hash_file), meta_hash_file)\n",
- "                        except FileExistsError:\n",
- "                            pass\n",
- "            except PermissionError: # project-monai/monai issue #3613\n",
- "                pass\n",
- "        open(hashfile, \"a\").close() # store cacheid\n",
- "\n",
- "        return _item_transformed"
- ]
-},
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
@@ -245,16 +140,16 @@
 },
 {
 "cell_type": "code",
- "execution_count": 4,
+ "execution_count": 3,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
- "2023-07-12 09:26:17,878 - INFO - Expected md5 is None, skip md5 check for file samples.zip.\n",
- "2023-07-12 09:26:17,878 - INFO - File exists: samples.zip, skipped downloading.\n",
- "2023-07-12 09:26:17,879 - INFO - Writing into directory: /raid/yliu/test_tutorial.\n"
+ "2023-07-27 07:59:12,054 - INFO - Expected md5 is None, skip md5 check for file samples.zip.\n",
+ "2023-07-27 07:59:12,055 - INFO - File exists: samples.zip, skipped downloading.\n",
+ "2023-07-27 07:59:12,056 - INFO - Writing into directory: /raid/yliu/test_tutorial.\n"
 ]
 }
 ],
@@ -283,7 +178,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -299,7 +194,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 6,
+ "execution_count": 5,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -332,19 +227,19 @@
 },
 {
 "cell_type": "code",
- "execution_count": 7,
+ "execution_count": 6,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
- "epoch0 time 19.746733903884888\n",
- "epoch1 time 0.9976603984832764\n",
- "epoch2 time 0.982248067855835\n",
- "epoch3 time 0.9838874340057373\n",
- "epoch4 time 0.9793403148651123\n",
- "total time 23.69102692604065\n"
+ "epoch0 time 20.148560762405396\n",
+ "epoch1 time 0.9835140705108643\n",
+ "epoch2 time 0.9708101749420166\n",
+ "epoch3 time 0.9711742401123047\n",
+ "epoch4 time 0.9711296558380127\n",
+ "total time 24.04619812965393\n"
 ]
 }
 ],
@@ -372,19 +267,19 @@
 },
 {
 "cell_type": "code",
- "execution_count": 8,
+ "execution_count": 7,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
- "epoch0 time 21.206729650497437\n",
- "epoch1 time 1.510526180267334\n",
- "epoch2 time 1.588256597518921\n",
- "epoch3 time 1.4431262016296387\n",
- "epoch4 time 1.4594802856445312\n",
- "total time 27.20927882194519\n"
+ "epoch0 time 21.170511722564697\n",
+ "epoch1 time 1.482978105545044\n",
+ "epoch2 time 1.5378782749176025\n",
+ "epoch3 time 1.4499244689941406\n",
+ "epoch4 time 1.4379286766052246\n",
+ "total time 27.08065962791443\n"
 ]
 }
 ],