diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 637f0873f1..3e45d899ec 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -554,6 +554,12 @@ IO
     :members:
     :special-members: __call__
 
+`WriteFileMapping`
+""""""""""""""""""
+.. autoclass:: WriteFileMapping
+    :members:
+    :special-members: __call__
+
 NVIDIA Tool Extension (NVTX)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -1642,6 +1648,12 @@ IO (Dict)
     :members:
     :special-members: __call__
 
+`WriteFileMappingd`
+"""""""""""""""""""
+.. autoclass:: WriteFileMappingd
+    :members:
+    :special-members: __call__
+
 Post-processing (Dict)
 ^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index 9548443768..f37016e63f 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -238,8 +238,18 @@
 )
 from .inverse import InvertibleTransform, TraceableTransform
 from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
-from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
-from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
+from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping
+from .io.dictionary import (
+    LoadImaged,
+    LoadImageD,
+    LoadImageDict,
+    SaveImaged,
+    SaveImageD,
+    SaveImageDict,
+    WriteFileMappingd,
+    WriteFileMappingD,
+    WriteFileMappingDict,
+)
 from .lazy.array import ApplyPending
 from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict
 from .lazy.functional import apply_pending
diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index 7c0e8f7123..4e71870fc9 100644
--- a/monai/transforms/io/array.py
+++ b/monai/transforms/io/array.py
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
 import inspect
+import json
 import logging
 import sys
 import traceback
@@ -45,11 +46,19 @@
 from monai.transforms.utility.array import EnsureChannelFirst
 from monai.utils import GridSamplePadMode
 from monai.utils import ImageMetaKey as Key
-from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import
+from monai.utils import (
+    MetaKeys,
+    OptionalImportError,
+    convert_to_dst_type,
+    ensure_tuple,
+    look_up_option,
+    optional_import,
+)
 
 nib, _ = optional_import("nibabel")
 Image, _ = optional_import("PIL.Image")
 nrrd, _ = optional_import("nrrd")
+FileLock, has_filelock = optional_import("filelock", name="FileLock")
 
-__all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"]
+__all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS", "WriteFileMapping"]
 
@@ -505,7 +514,7 @@ def __call__(
             else:
                 self._data_index += 1
                 if self.savepath_in_metadict and meta_data is not None:
-                    meta_data["saved_to"] = filename
+                    meta_data[MetaKeys.SAVED_TO] = filename
                 return img
         msg = "\n".join([f"{e}" for e in err])
         raise RuntimeError(
@@ -514,3 +523,68 @@
             "    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
             f"    The current registered writers for {self.output_ext}: {self.writers}.\n{msg}"
         )
+
+
+class WriteFileMapping(Transform):
+    """
+    Writes a JSON file that logs the mapping between input image paths and their corresponding output paths.
+    This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment.
+
+    Args:
+        mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved.
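+
+    Example (a minimal sketch mirroring the unit tests; the directory and file names are
+    illustrative, and ``savepath_in_metadict=True`` is required so that the metadata carries
+    the output path)::
+
+        from monai.transforms import Compose, LoadImage, SaveImage, WriteFileMapping
+
+        transform = Compose(
+            [
+                LoadImage(image_only=True),
+                SaveImage(output_dir="./out", output_ext=".nii.gz", savepath_in_metadict=True),
+                WriteFileMapping(mapping_file_path="mapping.json"),
+            ]
+        )
+        transform("./image.nii.gz")  # appends {"input": ..., "output": ...} to mapping.json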
+ """ + + def __init__(self, mapping_file_path: Path | str = "mapping.json"): + self.mapping_file_path = Path(mapping_file_path) + + def __call__(self, img: NdarrayOrTensor): + """ + Args: + img: The input image with metadata. + """ + if isinstance(img, MetaTensor): + meta_data = img.meta + + if MetaKeys.SAVED_TO not in meta_data: + raise KeyError( + "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True." + ) + + input_path = meta_data[Key.FILENAME_OR_OBJ] + output_path = meta_data[MetaKeys.SAVED_TO] + log_data = {"input": input_path, "output": output_path} + + if has_filelock: + with FileLock(str(self.mapping_file_path) + ".lock"): + self._write_to_file(log_data) + else: + self._write_to_file(log_data) + return img + + def _write_to_file(self, log_data): + try: + with self.mapping_file_path.open("r") as f: + existing_log_data = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + existing_log_data = [] + existing_log_data.append(log_data) + with self.mapping_file_path.open("w") as f: + json.dump(existing_log_data, f, indent=4) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 4da1d422ca..be1e78db8a 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -17,16 +17,17 @@ from __future__ import annotations +from collections.abc import Hashable, Mapping from pathlib import Path from typing import Callable import numpy as np import monai -from monai.config import DtypeLike, KeysCollection +from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor from monai.data import image_writer from monai.data.image_reader import ImageReader -from monai.transforms.io.array import LoadImage, SaveImage +from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping from monai.transforms.transform import MapTransform, Transform from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep from monai.utils.enums import PostFix @@ -320,5 +321,31 @@ def __call__(self, data): return d +class WriteFileMappingd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.WriteFileMapping`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + mapping_file_path: Path to the JSON file where the mappings will be saved. + Defaults to "mapping.json". + allow_missing_keys: don't raise exception if key is missing. 
+ """ + + def __init__( + self, keys: KeysCollection, mapping_file_path: Path | str = "mapping.json", allow_missing_keys: bool = False + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mapping = WriteFileMapping(mapping_file_path) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.mapping(d[key]) + return d + + LoadImageD = LoadImageDict = LoadImaged SaveImageD = SaveImageDict = SaveImaged +WriteFileMappingD = WriteFileMappingDict = WriteFileMappingd diff --git a/monai/utils/enums.py b/monai/utils/enums.py index b786e92151..eba1be18ed 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -543,6 +543,7 @@ class MetaKeys(StrEnum): SPATIAL_SHAPE = "spatial_shape" # optional key for the length in each spatial dimension SPACE = "space" # possible values of space type are defined in `SpaceKeys` ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or float("nan") + SAVED_TO = "saved_to" class ColorOrder(StrEnum): diff --git a/tests/test_mapping_file.py b/tests/test_mapping_file.py new file mode 100644 index 0000000000..97fa4312ed --- /dev/null +++ b/tests/test_mapping_file.py @@ -0,0 +1,117 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.data import DataLoader, Dataset
+from monai.transforms import Compose, LoadImage, SaveImage, WriteFileMapping
+from monai.utils import optional_import
+
+nib, has_nib = optional_import("nibabel")
+
+
+def create_input_file(temp_dir, name):
+    test_image = np.random.rand(128, 128, 128)
+    output_ext = ".nii.gz"
+    input_file = os.path.join(temp_dir, name + output_ext)
+    nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file)
+    return input_file
+
+
+def create_transform(temp_dir, mapping_file_path, savepath_in_metadict=True):
+    return Compose(
+        [
+            LoadImage(image_only=True),
+            SaveImage(output_dir=temp_dir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict),
+            WriteFileMapping(mapping_file_path=mapping_file_path),
+        ]
+    )
+
+
+@unittest.skipUnless(has_nib, "nibabel required")
+class TestWriteFileMapping(unittest.TestCase):
+    def setUp(self):
+        self.temp_dir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.temp_dir)
+
+    @parameterized.expand([(True,), (False,)])
+    def test_mapping_file(self, savepath_in_metadict):
+        mapping_file_path = os.path.join(self.temp_dir, "mapping.json")
+        name = "test_image"
+        input_file = create_input_file(self.temp_dir, name)
+        output_file = os.path.join(self.temp_dir, name, name + "_trans.nii.gz")
+
+        transform = create_transform(self.temp_dir, mapping_file_path, savepath_in_metadict)
+
+        if savepath_in_metadict:
+            transform(input_file)
+            self.assertTrue(os.path.exists(mapping_file_path))
+            with open(mapping_file_path) as f:
+                mapping_data = json.load(f)
+            self.assertEqual(len(mapping_data), 1)
+            self.assertEqual(mapping_data[0]["input"], input_file)
+            self.assertEqual(mapping_data[0]["output"], output_file)
+        else:
+            with self.assertRaises(RuntimeError) as cm:
+                transform(input_file)
+            cause_exception = cm.exception.__cause__
+            self.assertIsInstance(cause_exception, KeyError)
+            self.assertIn(
+                "Missing 'saved_to' key in metadata. Check that SaveImage argument 'savepath_in_metadict' is True.",
+                str(cause_exception),
+            )
+
+    def test_multiprocess_mapping_file(self):
+        num_images = 50
+
+        single_mapping_file = os.path.join(self.temp_dir, "single_mapping.json")
+        multi_mapping_file = os.path.join(self.temp_dir, "multi_mapping.json")
+
+        data = [create_input_file(self.temp_dir, f"test_image_{i}") for i in range(num_images)]
+
+        # single process
+        single_transform = create_transform(self.temp_dir, single_mapping_file)
+        single_dataset = Dataset(data=data, transform=single_transform)
+        single_loader = DataLoader(single_dataset, batch_size=1, num_workers=0, shuffle=True)
+        for _ in single_loader:
+            pass
+
+        # multiple processes
+        multi_transform = create_transform(self.temp_dir, multi_mapping_file)
+        multi_dataset = Dataset(data=data, transform=multi_transform)
+        multi_loader = DataLoader(multi_dataset, batch_size=4, num_workers=3, shuffle=True)
+        for _ in multi_loader:
+            pass
+
+        with open(single_mapping_file) as f:
+            single_mapping_data = json.load(f)
+        with open(multi_mapping_file) as f:
+            multi_mapping_data = json.load(f)
+
+        single_set = {(entry["input"], entry["output"]) for entry in single_mapping_data}
+        multi_set = {(entry["input"], entry["output"]) for entry in multi_mapping_data}
+
+        self.assertEqual(single_set, multi_set)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_mapping_filed.py b/tests/test_mapping_filed.py
new file mode 100644
index 0000000000..d0f8bcf938
--- /dev/null
+++ b/tests/test_mapping_filed.py
@@ -0,0 +1,126 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
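+
+# Tests for WriteFileMappingd: the mapping should be written for every configured key,
+# and a KeyError (wrapped in RuntimeError by Compose) is expected whenever a written
+# key was not previously saved with savepath_in_metadict=True.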
+
+from __future__ import annotations
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.data import DataLoader, Dataset, decollate_batch
+from monai.inferers import sliding_window_inference
+from monai.networks.nets import UNet
+from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, SaveImaged, WriteFileMappingd
+from monai.utils import optional_import
+
+nib, has_nib = optional_import("nibabel")
+
+
+def create_input_file(temp_dir, name):
+    test_image = np.random.rand(128, 128, 128)
+    input_file = os.path.join(temp_dir, name + ".nii.gz")
+    nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file)
+    return input_file
+
+
+# Test cases that should succeed
+SUCCESS_CASES = [(["seg"], ["seg"]), (["image", "seg"], ["seg"])]
+
+# Test cases that should fail
+FAILURE_CASES = [(["seg"], ["image"]), (["image"], ["seg"]), (["seg"], ["image", "seg"])]
+
+
+@unittest.skipUnless(has_nib, "nibabel required")
+class TestWriteFileMappingd(unittest.TestCase):
+    def setUp(self):
+        self.temp_dir = tempfile.mkdtemp()
+        self.output_dir = os.path.join(self.temp_dir, "output")
+        os.makedirs(self.output_dir)
+        self.mapping_file_path = os.path.join(self.temp_dir, "mapping.json")
+
+    def tearDown(self):
+        shutil.rmtree(self.temp_dir)
+        if os.path.exists(self.mapping_file_path):
+            os.remove(self.mapping_file_path)
+
+    def run_test(self, save_keys, write_keys):
+        name = "test_image"
+        input_file = create_input_file(self.temp_dir, name)
+        output_file = os.path.join(self.output_dir, name, name + "_seg.nii.gz")
+        data = [{"image": input_file}]
+
+        test_transforms = Compose([LoadImaged(keys=["image"]), EnsureChannelFirstd(keys=["image"])])
+
+        post_transforms = Compose(
+            [
+                SaveImaged(
+                    keys=save_keys,
+                    meta_keys="image_meta_dict",
+                    output_dir=self.output_dir,
+                    output_postfix="seg",
+                    savepath_in_metadict=True,
+                ),
+                WriteFileMappingd(keys=write_keys, mapping_file_path=self.mapping_file_path),
+            ]
+        )
+
+        dataset = Dataset(data=data, transform=test_transforms)
+        dataloader = DataLoader(dataset, batch_size=1)
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model = UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(16, 32), strides=(2,)).to(device)
+        model.eval()
+
+        with torch.no_grad():
+            for batch_data in dataloader:
+                test_inputs = batch_data["image"].to(device)
+                roi_size = (64, 64, 64)
+                sw_batch_size = 2
+                batch_data["seg"] = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model)
+                batch_data = [post_transforms(i) for i in decollate_batch(batch_data)]
+
+        return input_file, output_file
+
+    @parameterized.expand(SUCCESS_CASES)
+    def test_successful_mapping_filed(self, save_keys, write_keys):
+        input_file, output_file = self.run_test(save_keys, write_keys)
+        self.assertTrue(os.path.exists(self.mapping_file_path))
+        with open(self.mapping_file_path) as f:
+            mapping_data = json.load(f)
+        self.assertEqual(len(mapping_data), len(write_keys))
+        for entry in mapping_data:
+            self.assertEqual(entry["input"], input_file)
+            self.assertEqual(entry["output"], output_file)
+
+    @parameterized.expand(FAILURE_CASES)
+    def test_failure_mapping_filed(self, save_keys, write_keys):
+        with self.assertRaises(RuntimeError) as cm:
+            self.run_test(save_keys, write_keys)
+
+        cause_exception = cm.exception.__cause__
+        self.assertIsInstance(cause_exception, KeyError)
+        self.assertIn(
+            "Missing 'saved_to' key in metadata. Check that SaveImage argument 'savepath_in_metadict' is True.",
+            str(cause_exception),
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()