From afae503ea4f5d89bdcf4d7f98797ae0f43022411 Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 14 May 2024 16:47:47 +0800 Subject: [PATCH 01/49] Fixes #7557 Add a function to create a JSON file that maps input and output paths. Signed-off-by: staydelight --- monai/data/image_reader.py | 98 ++++++++++++++++++++++++++++---------- monai/data/image_writer.py | 28 ++++++++++- 2 files changed, 99 insertions(+), 27 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index f5e199e2a3..fa2c63b2e3 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -11,9 +11,11 @@ from __future__ import annotations +import json +import logging +import sys import glob import os -import re import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Sequence @@ -21,6 +23,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +from monai.apps.utils import get_logger import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern @@ -51,6 +54,16 @@ pydicom, has_pydicom = optional_import("pydicom") nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) +DEFAULT_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s" + +logger = get_logger(module_name=__name__, fmt=DEFAULT_FMT) +logger = logging.getLogger(__name__) +handler = logging.StreamHandler(sys.stdout) +handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s')) +logger.addHandler(handler) +logger.setLevel(logging.DEBUG) + + __all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] @@ -98,8 +111,10 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | kwargs: additional args for actual `read` API of 3rd party libs. """ + #self.update_json(input_file=data) raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + @abstractmethod def get_data(self, img) -> tuple[np.ndarray, dict]: """ @@ -147,6 +162,24 @@ def _stack_images(image_list: list, meta_dict: dict): meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 return np.stack(image_list, axis=0) +def update_json(input_file=None, output_file=None): + record_path = "img-label.json" + + if not os.path.exists(record_path) or os.stat(record_path).st_size == 0: + with open(record_path, 'w') as f: + json.dump([], f) + + with open(record_path, 'r+') as f: + records = json.load(f) + if input_file: + new_record = {"image": input_file, "label": []} + records.append(new_record) + elif output_file and records: + records[-1]["label"].append(output_file) + + f.seek(0) + json.dump(records, f, indent=4) + @require_pkg(pkg_name="itk") class ITKReader(ImageReader): @@ -168,8 +201,8 @@ class ITKReader(ImageReader): series_name: the name of the DICOM series if there are multiple ones. used when loading DICOM series. reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array. - If ``False``, the spatial indexing convention is reversed to be compatible with ITK; - otherwise, the spatial indexing follows the numpy convention. Default is ``False``. + If ``False``, the spatial indexing follows the numpy convention; + otherwise, the spatial indexing convention is reversed to be compatible with ITK. Default is ``False``. This option does not affect the metadata. series_meta: whether to load the metadata of the DICOM series (using the metadata from the first slice). 
            This flag is checked only when loading DICOM series. Default is ``False``.
@@ -225,6 +258,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
         img_ = []
 
         filenames: Sequence[PathLike] = ensure_tuple(data)
+        update_json(input_file=filenames)
         kwargs_ = self.kwargs.copy()
         kwargs_.update(kwargs)
         for name in filenames:
@@ -332,6 +366,25 @@ def _get_affine(self, img, lps_to_ras: bool = True):
         affine[:sr, -1] = origin[:sr]
         if lps_to_ras:
             affine = orientation_ras_lps(affine)
+            logger.debug("lps is changed to ras")
+
+        # use the logger to output the computed values
+
+        logger.info("\nOrigin[:sr]:")
+        logger.info(", ".join(f"{x:.10f}" for x in origin[:sr]))
+
+        logger.info("\nDirection[:sr, :sr]:")
+        for row in direction[:sr, :sr]:
+            logger.info(", ".join(f"{x:.15f}" for x in row))
+
+        logger.info("\nSpacing[:sr]:")
+        logger.info(", ".join(f"{x:.15f}" for x in spacing[:sr]))
+
+
+        # affine = numpy.round(affine, decimals=5)
+
+        logger.debug(f"Affine matrix:\n{affine}")
+
         return affine
 
     def _get_spatial_shape(self, img):
@@ -404,12 +457,8 @@ class PydicomReader(ImageReader):
         label_dict: label of the dicom data. If provided, it will be used when loading segmentation data.
             Keys of the dict are the classes, and values are the corresponding class number. For example:
             for TCIA collection "C4KC-KiTS", it can be: {"Kidney": 0, "Renal Tumor": 1}.
-        fname_regex: a regular expression to match the file names when the input is a folder.
-            If provided, only the matched files will be included. For example, to include the file name
-            "image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`.
-            Set it to `None` to use `pydicom.misc.is_dicom` to match valid files.
         kwargs: additional args for `pydicom.dcmread` API. more details about available args:
-            https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html
+            https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html#pydicom.filereader.dcmread
            If the `get_data` function will be called
            (for example, when using this reader with `monai.transforms.LoadImage`), please ensure that the argument
            `stop_before_pixels` is `True`, and `specific_tags` covers all necessary tags, such as `PixelSpacing`,
@@ -423,7 +472,6 @@ def __init__(
         swap_ij: bool = True,
         prune_metadata: bool = True,
         label_dict: dict | None = None,
-        fname_regex: str = "",
         **kwargs,
     ):
         super().__init__()
@@ -433,7 +481,6 @@ def __init__(
         self.swap_ij = swap_ij
         self.prune_metadata = prune_metadata
         self.label_dict = label_dict
-        self.fname_regex = fname_regex
 
     def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
         """
@@ -465,6 +512,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
         img_ = []
 
         filenames: Sequence[PathLike] = ensure_tuple(data)
+        update_json(input_file=filenames)
         kwargs_ = self.kwargs.copy()
         kwargs_.update(kwargs)
 
@@ -474,16 +522,9 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
             name = f"{name}"
             if Path(name).is_dir():
                 # read DICOM series
-                if self.fname_regex is not None:
-                    series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if re.match(self.fname_regex, slc)]
-                else:
-                    series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)]
-                slices = []
-                for slc in series_slcs:
-                    try:
-                        slices.append(pydicom.dcmread(fp=slc, **kwargs_))
-                    except pydicom.errors.InvalidDicomError as e:
-                        warnings.warn(f"Failed to read {slc} with exception: \n{e}.", stacklevel=2)
+                series_slcs = glob.glob(os.path.join(name, "*"))
+                series_slcs =
[slc for slc in series_slcs if "LICENSE" not in slc] + slices = [pydicom.dcmread(fp=slc, **kwargs_) for slc in series_slcs] img_.append(slices if len(slices) > 1 else slices[0]) if len(slices) > 1: self.has_series = True @@ -913,9 +954,11 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py """ + logger.info(f"Reading NIfTI data from: {data}") img_: list[Nifti1Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) + update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -1076,13 +1119,14 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_: list[Nifti1Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) + update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: img = np.load(name, allow_pickle=True, **kwargs_) if Path(name).name.endswith(".npz"): # load expected items from NPZ file - npz_keys = list(img.keys()) if self.npz_keys is None else self.npz_keys + npz_keys = [f"arr_{i}" for i in range(len(img))] if self.npz_keys is None else self.npz_keys for k in npz_keys: img_.append(img[k]) else: @@ -1173,6 +1217,7 @@ def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs): img_: list[PILImage.Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) + update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -1297,10 +1342,11 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | """ img_: list = [] filenames: Sequence[PathLike] = ensure_tuple(data) + update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: - nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, **kwargs_)) + nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, *kwargs_)) img_.append(nrrd_image) return img_ if len(filenames) > 1 else img_[0] @@ -1323,7 +1369,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: header = dict(i.header) if self.index_order == "C": header = self._convert_f_to_c_order(header) - header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(header) + header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i) if self.affine_lps_to_ras: header = self._switch_lps_ras(header) @@ -1344,7 +1390,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: return _stack_images(img_array, compatible_meta), compatible_meta - def _get_affine(self, header: dict) -> np.ndarray: + def _get_affine(self, img: NrrdImage) -> np.ndarray: """ Get the affine matrix of the image, it can be used to correct spacing, orientation or execute spatial transforms. 
@@ -1353,8 +1399,8 @@ def _get_affine(self, header: dict) -> np.ndarray: img: A `NrrdImage` loaded from image file """ - direction = header["space directions"] - origin = header["space origin"] + direction = img.header["space directions"] + origin = img.header["space origin"] x, y = direction.shape affine_diam = min(x, y) + 1 diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index b9e8b9e68e..06209c664a 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -15,6 +15,8 @@ from typing import TYPE_CHECKING, Any, cast import numpy as np +import os +import json from monai.apps.utils import get_logger from monai.config import DtypeLike, NdarrayOrTensor, PathLike @@ -196,6 +198,25 @@ def write(self, filename: PathLike, verbose: bool = True, **kwargs): if verbose: logger.info(f"writing: {filename}") + def update_json(self, input_file=None, output_file=None): + record_path = "img-label.json" + + if not os.path.exists(record_path) or os.stat(record_path).st_size == 0: + with open(record_path, 'w') as f: + json.dump([], f) + + with open(record_path, 'r+') as f: + records = json.load(f) + if input_file: + new_record = {"image": input_file, "label": []} + records.append(new_record) + elif output_file and records: + records[-1]["label"].append(output_file) + + f.seek(0) + json.dump(records, f, indent=4) + + @classmethod def create_backend_obj(cls, data_array: NdarrayOrTensor, **kwargs) -> np.ndarray: """ @@ -276,7 +297,7 @@ def resample_if_needed( # convert back at the end if isinstance(output_array, MetaTensor): output_array.applied_operations = [] - data_array, *_ = convert_data_type(output_array, output_type=orig_type) + data_array, *_ = convert_data_type(output_array, output_type=orig_type) # type: ignore affine, *_ = convert_data_type(output_array.affine, output_type=orig_type) # type: ignore return data_array[0], affine @@ -462,7 +483,9 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs): - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L809 """ + logger.info(f"ITKWriter is processing the file: {filename}") super().write(filename, verbose=verbose) + super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( cast(NdarrayOrTensor, self.data_obj), channel_dim=self.channel_dim, @@ -625,7 +648,9 @@ def write(self, filename: PathLike, verbose: bool = False, **obj_kwargs): - https://nipy.org/nibabel/reference/nibabel.nifti1.html#nibabel.nifti1.save """ + logger.info(f"NibabelWriter is processing the file: {filename}") super().write(filename, verbose=verbose) + super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( cast(NdarrayOrTensor, self.data_obj), affine=self.affine, dtype=self.output_dtype, **obj_kwargs ) @@ -771,6 +796,7 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs): - https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.save """ super().write(filename, verbose=verbose) + super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( data_array=self.data_obj, dtype=self.output_dtype, From 542a77d5cee53c15dc6c17fef7961ec528e5a299 Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 14 May 2024 17:33:04 +0800 Subject: [PATCH 02/49] Fixes #7557 Remove changes unrelated to this issue. 
Signed-off-by: staydelight --- monai/data/image_reader.py | 1476 ++++++++++++++++++++++++++++++++++-- 1 file changed, 1426 insertions(+), 50 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index fa2c63b2e3..257bebc831 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -11,11 +11,10 @@ from __future__ import annotations -import json -import logging -import sys import glob +import json import os +import re import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Sequence @@ -23,7 +22,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -from monai.apps.utils import get_logger import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern @@ -54,16 +52,6 @@ pydicom, has_pydicom = optional_import("pydicom") nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) -DEFAULT_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s" - -logger = get_logger(module_name=__name__, fmt=DEFAULT_FMT) -logger = logging.getLogger(__name__) -handler = logging.StreamHandler(sys.stdout) -handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s')) -logger.addHandler(handler) -logger.setLevel(logging.DEBUG) - - __all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] @@ -111,10 +99,8 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | kwargs: additional args for actual `read` API of 3rd party libs. """ - #self.update_json(input_file=data) raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - @abstractmethod def get_data(self, img) -> tuple[np.ndarray, dict]: """ @@ -161,7 +147,8 @@ def _stack_images(image_list: list, meta_dict: dict): # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 return np.stack(image_list, axis=0) - + + def update_json(input_file=None, output_file=None): record_path = "img-label.json" @@ -201,8 +188,8 @@ class ITKReader(ImageReader): series_name: the name of the DICOM series if there are multiple ones. used when loading DICOM series. reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array. - If ``False``, the spatial indexing follows the numpy convention; - otherwise, the spatial indexing convention is reversed to be compatible with ITK. Default is ``False``. + If ``False``, the spatial indexing convention is reversed to be compatible with ITK; + otherwise, the spatial indexing follows the numpy convention. Default is ``False``. This option does not affect the metadata. series_meta: whether to load the metadata of the DICOM series (using the metadata from the first slice). This flag is checked only when loading DICOM series. Default is ``False``. 
@@ -366,25 +353,6 @@ def _get_affine(self, img, lps_to_ras: bool = True):
         affine[:sr, -1] = origin[:sr]
         if lps_to_ras:
             affine = orientation_ras_lps(affine)
-            logger.debug("lps is changed to ras")
-
-        # use the logger to output the computed values
-
-        logger.info("\nOrigin[:sr]:")
-        logger.info(", ".join(f"{x:.10f}" for x in origin[:sr]))
-
-        logger.info("\nDirection[:sr, :sr]:")
-        for row in direction[:sr, :sr]:
-            logger.info(", ".join(f"{x:.15f}" for x in row))
-
-        logger.info("\nSpacing[:sr]:")
-        logger.info(", ".join(f"{x:.15f}" for x in spacing[:sr]))
-
-
-        # affine = numpy.round(affine, decimals=5)
-
-        logger.debug(f"Affine matrix:\n{affine}")
-
         return affine
 
     def _get_spatial_shape(self, img):
@@ -425,8 +425,12 @@ class PydicomReader(ImageReader):
         label_dict: label of the dicom data. If provided, it will be used when loading segmentation data.
             Keys of the dict are the classes, and values are the corresponding class number. For example:
             for TCIA collection "C4KC-KiTS", it can be: {"Kidney": 0, "Renal Tumor": 1}.
+        fname_regex: a regular expression to match the file names when the input is a folder.
+            If provided, only the matched files will be included. For example, to include the file name
+            "image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`.
+            Set it to `None` to use `pydicom.misc.is_dicom` to match valid files.
         kwargs: additional args for `pydicom.dcmread` API. more details about available args:
-            https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html#pydicom.filereader.dcmread
+            https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html
            If the `get_data` function will be called
            (for example, when using this reader with `monai.transforms.LoadImage`), please ensure that the argument
            `stop_before_pixels` is `True`, and `specific_tags` covers all necessary tags, such as `PixelSpacing`,
@@ -444,6 +444,7 @@ def __init__(
         swap_ij: bool = True,
         prune_metadata: bool = True,
         label_dict: dict | None = None,
+        fname_regex: str = "",
         **kwargs,
     ):
         super().__init__()
@@ -454,6 +454,7 @@ def __init__(
         self.swap_ij = swap_ij
         self.prune_metadata = prune_metadata
         self.label_dict = label_dict
+        self.fname_regex = fname_regex
 
     def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
         """
@@ -496,9 +496,16 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
             name = f"{name}"
             if Path(name).is_dir():
                 # read DICOM series
-                series_slcs = glob.glob(os.path.join(name, "*"))
-                series_slcs = [slc for slc in series_slcs if "LICENSE" not in slc]
-                slices = [pydicom.dcmread(fp=slc, **kwargs_) for slc in series_slcs]
+                if self.fname_regex is not None:
+                    series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if re.match(self.fname_regex, slc)]
+                else:
+                    series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)]
+                slices = []
+                for slc in series_slcs:
+                    try:
+                        slices.append(pydicom.dcmread(fp=slc, **kwargs_))
+                    except pydicom.errors.InvalidDicomError as e:
+                        warnings.warn(f"Failed to read {slc} with exception: \n{e}.", stacklevel=2)
                 img_.append(slices if len(slices) > 1 else slices[0])
                 if len(slices) > 1:
                     self.has_series = True
@@ -954,7 +935,6 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
             https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py
 
         """
-        logger.info(f"Reading NIfTI data from: {data}")
         img_: list[Nifti1Image] = []
 
         filenames: Sequence[PathLike] = ensure_tuple(data)
@@ -1126,7 +1106,7 @@ def read(self, data: Sequence[PathLike]
| PathLike, **kwargs): img = np.load(name, allow_pickle=True, **kwargs_) if Path(name).name.endswith(".npz"): # load expected items from NPZ file - npz_keys = [f"arr_{i}" for i in range(len(img))] if self.npz_keys is None else self.npz_keys + npz_keys = list(img.keys()) if self.npz_keys is None else self.npz_keys for k in npz_keys: img_.append(img[k]) else: @@ -1346,7 +1326,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: - nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, *kwargs_)) + nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, **kwargs_)) img_.append(nrrd_image) return img_ if len(filenames) > 1 else img_[0] @@ -1369,7 +1349,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: header = dict(i.header) if self.index_order == "C": header = self._convert_f_to_c_order(header) - header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i) + header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(header) if self.affine_lps_to_ras: header = self._switch_lps_ras(header) @@ -1390,7 +1370,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: return _stack_images(img_array, compatible_meta), compatible_meta - def _get_affine(self, img: NrrdImage) -> np.ndarray: + def _get_affine(self, header: dict) -> np.ndarray: """ Get the affine matrix of the image, it can be used to correct spacing, orientation or execute spatial transforms. @@ -1399,8 +1379,8 @@ def _get_affine(self, img: NrrdImage) -> np.ndarray: img: A `NrrdImage` loaded from image file """ - direction = img.header["space directions"] - origin = img.header["space origin"] + direction = header["space directions"] + origin = header["space origin"] x, y = direction.shape affine_diam = min(x, y) + 1 @@ -1440,4 +1420,1400 @@ def _convert_f_to_c_order(self, header: dict) -> dict: header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) header["space origin"] = header["space origin"][::-1] header["sizes"] = header["sizes"][::-1] - return header + return header# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import glob +import os +import re +import warnings +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Iterator, Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np +from torch.utils.data._utils.collate import np_str_obj_array_pattern + +from monai.config import KeysCollection, PathLike +from monai.data.utils import ( + affine_to_spacing, + correct_nifti_header_if_necessary, + is_no_channel, + is_supported_format, + orientation_ras_lps, +) +from monai.utils import MetaKeys, SpaceKeys, TraceKeys, ensure_tuple, optional_import, require_pkg + +if TYPE_CHECKING: + import itk + import nibabel as nib + import nrrd + import pydicom + from nibabel.nifti1 import Nifti1Image + from PIL import Image as PILImage + + has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True +else: + itk, has_itk = optional_import("itk", allow_namespace_pkg=True) + nib, has_nib = optional_import("nibabel") + Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image") + PILImage, has_pil = optional_import("PIL.Image") + pydicom, has_pydicom = optional_import("pydicom") + nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) + +__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] + + +class ImageReader(ABC): + """ + An abstract class defines APIs to load image files. + + Typical usage of an implementation of this class is: + + .. code-block:: python + + image_reader = MyImageReader() + img_obj = image_reader.read(path_to_image) + img_data, meta_data = image_reader.get_data(img_obj) + + - The `read` call converts image filenames into image objects, + - The `get_data` call fetches the image data, as well as metadata. + - A reader should implement `verify_suffix` with the logic of checking the input filename + by the filename extensions. + + """ + + @abstractmethod + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: + """ + Verify whether the specified `filename` is supported by the current reader. + This method should return True if the reader is able to read the format suggested by the + `filename`. + + Args: + filename: file name or a list of file names to read. + if a list of files, verify all the suffixes. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | Any: + """ + Read image data from specified file or files. + Note that it returns a data object or a sequence of data objects. + + Args: + data: file name or a list of file names to read. + kwargs: additional args for actual `read` API of 3rd party libs. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def get_data(self, img) -> tuple[np.ndarray, dict]: + """ + Extract data array and metadata from loaded image and return them. + This function must return two objects, the first is a numpy array of image data, + the second is a dictionary of metadata. + + Args: + img: an image object loaded from an image file or a list of image objects. 
+ + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + +def _copy_compatible_dict(from_dict: dict, to_dict: dict): + if not isinstance(to_dict, dict): + raise ValueError(f"to_dict must be a Dict, got {type(to_dict)}.") + if not to_dict: + for key in from_dict: + datum = from_dict[key] + if isinstance(datum, np.ndarray) and np_str_obj_array_pattern.search(datum.dtype.str) is not None: + continue + to_dict[key] = str(TraceKeys.NONE) if datum is None else datum # NoneType to string for default_collate + else: + affine_key, shape_key = MetaKeys.AFFINE, MetaKeys.SPATIAL_SHAPE + if affine_key in from_dict and not np.allclose(from_dict[affine_key], to_dict[affine_key]): + raise RuntimeError( + "affine matrix of all images should be the same for channel-wise concatenation. " + f"Got {from_dict[affine_key]} and {to_dict[affine_key]}." + ) + if shape_key in from_dict and not np.allclose(from_dict[shape_key], to_dict[shape_key]): + raise RuntimeError( + "spatial_shape of all images should be the same for channel-wise concatenation. " + f"Got {from_dict[shape_key]} and {to_dict[shape_key]}." + ) + + +def _stack_images(image_list: list, meta_dict: dict): + if len(image_list) <= 1: + return image_list[0] + if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): + channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) + return np.concatenate(image_list, axis=channel_dim) + # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified + meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 + return np.stack(image_list, axis=0) + + +@require_pkg(pkg_name="itk") +class ITKReader(ImageReader): + """ + Load medical images based on ITK library. + All the supported image formats can be found at: + https://github.com/InsightSoftwareConsortium/ITK/tree/master/Modules/IO + The loaded data array will be in C order, for example, a 3D image NumPy + array index order will be `CDWH`. + + Args: + channel_dim: the channel dimension of the input image, default is None. + This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. + If None, `original_channel_dim` will be either `no_channel` or `-1`. + + - Nifti file is usually "channel last", so there is no need to specify this argument. + - PNG file usually has `GetNumberOfComponentsPerPixel()==3`, so there is no need to specify this argument. + + series_name: the name of the DICOM series if there are multiple ones. + used when loading DICOM series. + reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array. + If ``False``, the spatial indexing convention is reversed to be compatible with ITK; + otherwise, the spatial indexing follows the numpy convention. Default is ``False``. + This option does not affect the metadata. + series_meta: whether to load the metadata of the DICOM series (using the metadata from the first slice). + This flag is checked only when loading DICOM series. Default is ``False``. + affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``. + Set to ``True`` to be consistent with ``NibabelReader``, otherwise the affine matrix remains in the ITK convention. + kwargs: additional args for `itk.imread` API. 
more details about available args: + https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py + + """ + + def __init__( + self, + channel_dim: str | int | None = None, + series_name: str = "", + reverse_indexing: bool = False, + series_meta: bool = False, + affine_lps_to_ras: bool = True, + **kwargs, + ): + super().__init__() + self.kwargs = kwargs + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim + self.series_name = series_name + self.reverse_indexing = reverse_indexing + self.series_meta = series_meta + self.affine_lps_to_ras = affine_lps_to_ras + + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: + """ + Verify whether the specified file or files format is supported by ITK reader. + + Args: + filename: file name or a list of file names to read. + if a list of files, verify all the suffixes. + + """ + return has_itk + + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): + """ + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. + If passing directory path instead of file path, will treat it as DICOM images series and read. + Note that the returned object is ITK image object or list of ITK image objects. + + Args: + data: file name or a list of file names to read, + kwargs: additional args for `itk.imread` API, will override `self.kwargs` for existing keys. + More details about available args: + https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py + + """ + img_ = [] + + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + for name in filenames: + name = f"{name}" + if Path(name).is_dir(): + # read DICOM series + # https://examples.itk.org/src/io/gdcm/readdicomseriesandwrite3dimage/documentation + names_generator = itk.GDCMSeriesFileNames.New() + names_generator.SetUseSeriesDetails(True) + names_generator.AddSeriesRestriction("0008|0021") # Series Date + names_generator.SetDirectory(name) + series_uid = names_generator.GetSeriesUIDs() + + if len(series_uid) < 1: + raise FileNotFoundError(f"no DICOMs in: {name}.") + if len(series_uid) > 1: + warnings.warn(f"the directory: {name} contains more than one DICOM series.") + series_identifier = series_uid[0] if not self.series_name else self.series_name + name = names_generator.GetFileNames(series_identifier) + + name = name[0] if len(name) == 1 else name # type: ignore + _obj = itk.imread(name, **kwargs_) + if self.series_meta: + _reader = itk.ImageSeriesReader.New(FileNames=name) + _reader.Update() + _meta = _reader.GetMetaDataDictionaryArray() + if len(_meta) > 0: + # TODO: using the first slice's meta. this could be improved to filter unnecessary tags. + _obj.SetMetaDataDictionary(_meta[0]) + img_.append(_obj) + else: + img_.append(itk.imread(name, **kwargs_)) + return img_ if len(filenames) > 1 else img_[0] + + def get_data(self, img) -> tuple[np.ndarray, dict]: + """ + Extract data array and metadata from loaded image and return them. + This function returns two objects, first is numpy array of image data, second is dict of metadata. + It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. + When loading a list of files, they are stacked together at a new dimension as the first dimension, + and the metadata of the first image is used to represent the output metadata. 
+ + Args: + img: an ITK image object loaded from an image file or a list of ITK image objects. + + """ + img_array: list[np.ndarray] = [] + compatible_meta: dict = {} + + for i in ensure_tuple(img): + data = self._get_array_data(i) + img_array.append(data) + header = self._get_meta_dict(i) + header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i, self.affine_lps_to_ras) + header[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS + header[MetaKeys.AFFINE] = header[MetaKeys.ORIGINAL_AFFINE].copy() + header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) + if self.channel_dim is None: # default to "no_channel" or -1 + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 + ) + else: + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim + _copy_compatible_dict(header, compatible_meta) + + return _stack_images(img_array, compatible_meta), compatible_meta + + def _get_meta_dict(self, img) -> dict: + """ + Get all the metadata of the image and convert to dict type. + + Args: + img: an ITK image object loaded from an image file. + + """ + img_meta_dict = img.GetMetaDataDictionary() + meta_dict = {} + for key in img_meta_dict.GetKeys(): + if key.startswith("ITK_"): + continue + val = img_meta_dict[key] + meta_dict[key] = np.asarray(val) if type(val).__name__.startswith("itk") else val + + meta_dict["spacing"] = np.asarray(img.GetSpacing()) + return meta_dict + + def _get_affine(self, img, lps_to_ras: bool = True): + """ + Get or construct the affine matrix of the image, it can be used to correct + spacing, orientation or execute spatial transforms. + + Args: + img: an ITK image object loaded from an image file. + lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to True. + + """ + direction = itk.array_from_matrix(img.GetDirection()) + spacing = np.asarray(img.GetSpacing()) + origin = np.asarray(img.GetOrigin()) + + direction = np.asarray(direction) + sr = min(max(direction.shape[0], 1), 3) + affine: np.ndarray = np.eye(sr + 1) + affine[:sr, :sr] = direction[:sr, :sr] @ np.diag(spacing[:sr]) + affine[:sr, -1] = origin[:sr] + if lps_to_ras: + affine = orientation_ras_lps(affine) + return affine + + def _get_spatial_shape(self, img): + """ + Get the spatial shape of `img`. + + Args: + img: an ITK image object loaded from an image file. + + """ + sr = itk.array_from_matrix(img.GetDirection()).shape[0] + sr = max(min(sr, 3), 1) + _size = list(itk.size(img)) + if isinstance(self.channel_dim, int): + _size.pop(self.channel_dim) + return np.asarray(_size[:sr]) + + def _get_array_data(self, img): + """ + Get the raw array data of the image, converted to Numpy array. + + Following PyTorch conventions, the returned array data has contiguous channels, + e.g. for an RGB image, all red channel image pixels are contiguous in memory. + The last axis of the returned array is the channel axis. + + See also: + + - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Modules/Bridge/NumPy/wrapping/PyBuffer.i.in + + Args: + img: an ITK image object loaded from an image file. 
+ + """ + np_img = itk.array_view_from_image(img, keep_axes=False) + if img.GetNumberOfComponentsPerPixel() == 1: # handling spatial images + return np_img if self.reverse_indexing else np_img.T + # handling multi-channel images + return np_img if self.reverse_indexing else np.moveaxis(np_img.T, 0, -1) + + +@require_pkg(pkg_name="pydicom") +class PydicomReader(ImageReader): + """ + Load medical images based on Pydicom library. + All the supported image formats can be found at: + https://dicom.nema.org/medical/dicom/current/output/chtml/part10/chapter_7.html + + PydicomReader is also able to load segmentations, if a dicom file contains tag: `SegmentSequence`, the reader + will consider it as segmentation data, and to load it successfully, `PerFrameFunctionalGroupsSequence` is required + for dicom file, and for each frame of dicom file, `SegmentIdentificationSequence` is required. + This method refers to the Highdicom library. + + This class refers to: + https://nipy.org/nibabel/dicom/dicom_orientation.html#dicom-affine-formula + https://github.com/pydicom/contrib-pydicom/blob/master/input-output/pydicom_series.py + https://highdicom.readthedocs.io/en/latest/usage.html#parsing-segmentation-seg-images + + Args: + channel_dim: the channel dimension of the input image, default is None. + This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. + If None, `original_channel_dim` will be either `no_channel` or `-1`. + affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``. + Set to ``True`` to be consistent with ``NibabelReader``, + otherwise the affine matrix remains in the Dicom convention. + swap_ij: whether to swap the first two spatial axes. Default to ``True``, so that the outputs + are consistent with the other readers. + prune_metadata: whether to prune the saved information in metadata. This argument is used for + `get_data` function. If True, only items that are related to the affine matrix will be saved. + Default to ``True``. + label_dict: label of the dicom data. If provided, it will be used when loading segmentation data. + Keys of the dict are the classes, and values are the corresponding class number. For example: + for TCIA collection "C4KC-KiTS", it can be: {"Kidney": 0, "Renal Tumor": 1}. + fname_regex: a regular expression to match the file names when the input is a folder. + If provided, only the matched files will be included. For example, to include the file name + "image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`. + Set it to `None` to use `pydicom.misc.is_dicom` to match valid files. + kwargs: additional args for `pydicom.dcmread` API. more details about available args: + https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html + If the `get_data` function will be called + (for example, when using this reader with `monai.transforms.LoadImage`), please ensure that the argument + `stop_before_pixels` is `True`, and `specific_tags` covers all necessary tags, such as `PixelSpacing`, + `ImagePositionPatient`, `ImageOrientationPatient` and all `pixel_array` related tags. 
+ """ + + def __init__( + self, + channel_dim: str | int | None = None, + affine_lps_to_ras: bool = True, + swap_ij: bool = True, + prune_metadata: bool = True, + label_dict: dict | None = None, + fname_regex: str = "", + **kwargs, + ): + super().__init__() + self.kwargs = kwargs + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim + self.affine_lps_to_ras = affine_lps_to_ras + self.swap_ij = swap_ij + self.prune_metadata = prune_metadata + self.label_dict = label_dict + self.fname_regex = fname_regex + + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: + """ + Verify whether the specified file or files format is supported by Pydicom reader. + + Args: + filename: file name or a list of file names to read. + if a list of files, verify all the suffixes. + + """ + return has_pydicom + + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): + """ + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. + If passing directory path instead of file path, will treat it as DICOM images series and read. + + Args: + data: file name or a list of file names to read, + kwargs: additional args for `pydicom.dcmread` API, will override `self.kwargs` for existing keys. + + Returns: + If `data` represents a filename: return a pydicom dataset object. + If `data` represents a list of filenames or a directory: return a list of pydicom dataset object. + If `data` represents a list of directories: return a list of list of pydicom dataset object. + + """ + img_ = [] + + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + + self.has_series = False + + for name in filenames: + name = f"{name}" + if Path(name).is_dir(): + # read DICOM series + if self.fname_regex is not None: + series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if re.match(self.fname_regex, slc)] + else: + series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)] + slices = [] + for slc in series_slcs: + try: + slices.append(pydicom.dcmread(fp=slc, **kwargs_)) + except pydicom.errors.InvalidDicomError as e: + warnings.warn(f"Failed to read {slc} with exception: \n{e}.", stacklevel=2) + img_.append(slices if len(slices) > 1 else slices[0]) + if len(slices) > 1: + self.has_series = True + else: + ds = pydicom.dcmread(fp=name, **kwargs_) + img_.append(ds) + return img_ if len(filenames) > 1 else img_[0] + + def _combine_dicom_series(self, data: Iterable): + """ + Combine dicom series (a list of pydicom dataset objects). Their data arrays will be stacked together at a new + dimension as the last dimension. + + The stack order depends on Instance Number. The metadata will be based on the + first slice's metadata, and some new items will be added: + + "spacing": the new spacing of the stacked slices. + "lastImagePositionPatient": `ImagePositionPatient` for the last slice, it will be used to achieve the affine + matrix. + "spatial_shape": the spatial shape of the stacked slices. + + Args: + data: a list of pydicom dataset objects. + Returns: + a tuple that consisted with data array and metadata. 
+ """ + slices: list = [] + # for a dicom series + for slc_ds in data: + if hasattr(slc_ds, "InstanceNumber"): + slices.append(slc_ds) + else: + warnings.warn(f"slice: {slc_ds.filename} does not have InstanceNumber tag, skip it.") + slices = sorted(slices, key=lambda s: s.InstanceNumber) + + if len(slices) == 0: + raise ValueError("the input does not have valid slices.") + + first_slice = slices[0] + average_distance = 0.0 + first_array = self._get_array_data(first_slice) + shape = first_array.shape + spacing = getattr(first_slice, "PixelSpacing", [1.0, 1.0, 1.0]) + prev_pos = getattr(first_slice, "ImagePositionPatient", (0.0, 0.0, 0.0))[2] + stack_array = [first_array] + for idx in range(1, len(slices)): + slc_array = self._get_array_data(slices[idx]) + slc_shape = slc_array.shape + slc_spacing = getattr(slices[idx], "PixelSpacing", (1.0, 1.0, 1.0)) + slc_pos = getattr(slices[idx], "ImagePositionPatient", (0.0, 0.0, float(idx)))[2] + if not np.allclose(slc_spacing, spacing): + warnings.warn(f"the list contains slices that have different spacings {spacing} and {slc_spacing}.") + if shape != slc_shape: + warnings.warn(f"the list contains slices that have different shapes {shape} and {slc_shape}.") + average_distance += abs(prev_pos - slc_pos) + prev_pos = slc_pos + stack_array.append(slc_array) + + if len(slices) > 1: + average_distance /= len(slices) - 1 + spacing.append(average_distance) + stack_array = np.stack(stack_array, axis=-1) + stack_metadata = self._get_meta_dict(first_slice) + stack_metadata["spacing"] = np.asarray(spacing) + if hasattr(slices[-1], "ImagePositionPatient"): + stack_metadata["lastImagePositionPatient"] = np.asarray(slices[-1].ImagePositionPatient) + stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape + (len(slices),) + else: + stack_array = stack_array[0] + stack_metadata = self._get_meta_dict(first_slice) + stack_metadata["spacing"] = np.asarray(spacing) + stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape + + return stack_array, stack_metadata + + def get_data(self, data) -> tuple[np.ndarray, dict]: + """ + Extract data array and metadata from loaded image and return them. + This function returns two objects, first is numpy array of image data, second is dict of metadata. + It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. + For dicom series within the input, all slices will be stacked first, + When loading a list of files (dicom file, or stacked dicom series), they are stacked together at a new + dimension as the first dimension, and the metadata of the first image is used to represent the output metadata. + + To use this function, all pydicom dataset objects (if not segmentation data) should contain: + `pixel_array`, `PixelSpacing`, `ImagePositionPatient` and `ImageOrientationPatient`. + + For segmentation data, we assume that the input is not a dicom series, and the object should contain + `SegmentSequence` in order to identify it. + In addition, tags (5200, 9229) and (5200, 9230) are required to achieve + `PixelSpacing`, `ImageOrientationPatient` and `ImagePositionPatient`. + + Args: + data: a pydicom dataset object, or a list of pydicom dataset objects, or a list of list of + pydicom dataset objects. 
+ + """ + + dicom_data = [] + # combine dicom series if exists + if self.has_series is True: + # a list, all objects within a list belong to one dicom series + if not isinstance(data[0], list): + dicom_data.append(self._combine_dicom_series(data)) + # a list of list, each inner list represents a dicom series + else: + for series in data: + dicom_data.append(self._combine_dicom_series(series)) + else: + # a single pydicom dataset object + if not isinstance(data, list): + data = [data] + for d in data: + if hasattr(d, "SegmentSequence"): + data_array, metadata = self._get_seg_data(d) + else: + data_array = self._get_array_data(d) + metadata = self._get_meta_dict(d) + metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape + dicom_data.append((data_array, metadata)) + + img_array: list[np.ndarray] = [] + compatible_meta: dict = {} + + for data_array, metadata in ensure_tuple(dicom_data): + img_array.append(np.ascontiguousarray(np.swapaxes(data_array, 0, 1) if self.swap_ij else data_array)) + affine = self._get_affine(metadata, self.affine_lps_to_ras) + metadata[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS + if self.swap_ij: + affine = affine @ np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) + sp_size = list(metadata[MetaKeys.SPATIAL_SHAPE]) + sp_size[0], sp_size[1] = sp_size[1], sp_size[0] + metadata[MetaKeys.SPATIAL_SHAPE] = ensure_tuple(sp_size) + metadata[MetaKeys.ORIGINAL_AFFINE] = affine + metadata[MetaKeys.AFFINE] = affine.copy() + if self.channel_dim is None: # default to "no_channel" or -1 + metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data_array.shape) == len(metadata[MetaKeys.SPATIAL_SHAPE]) else -1 + ) + else: + metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim + metadata["spacing"] = affine_to_spacing( + metadata[MetaKeys.ORIGINAL_AFFINE], r=len(metadata[MetaKeys.SPATIAL_SHAPE]) + ) + + _copy_compatible_dict(metadata, compatible_meta) + + return _stack_images(img_array, compatible_meta), compatible_meta + + def _get_meta_dict(self, img) -> dict: + """ + Get all the metadata of the image and convert to dict type. + + Args: + img: a Pydicom dataset object. + + """ + + metadata = img.to_json_dict(suppress_invalid_tags=True) + + if self.prune_metadata: + prune_metadata = {} + for key in ["00200037", "00200032", "00280030", "52009229", "52009230"]: + if key in metadata.keys(): + prune_metadata[key] = metadata[key] + return prune_metadata + + # always remove Pixel Data "7FE00008" or "7FE00009" or "7FE00010" + # always remove Data Set Trailing Padding "FFFCFFFC" + for key in ["7FE00008", "7FE00009", "7FE00010", "FFFCFFFC"]: + if key in metadata.keys(): + metadata.pop(key) + + return metadata # type: ignore + + def _get_affine(self, metadata: dict, lps_to_ras: bool = True): + """ + Get or construct the affine matrix of the image, it can be used to correct + spacing, orientation or execute spatial transforms. + + Args: + metadata: metadata with dict type. + lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to True. 
+ + """ + affine: np.ndarray = np.eye(4) + if not ("00200037" in metadata and "00200032" in metadata): + return affine + # "00200037" is the tag of `ImageOrientationPatient` + rx, ry, rz, cx, cy, cz = metadata["00200037"]["Value"] + # "00200032" is the tag of `ImagePositionPatient` + sx, sy, sz = metadata["00200032"]["Value"] + # "00280030" is the tag of `PixelSpacing` + spacing = metadata["00280030"]["Value"] if "00280030" in metadata else (1.0, 1.0) + dr, dc = metadata.get("spacing", spacing)[:2] + affine[0, 0] = cx * dr + affine[0, 1] = rx * dc + affine[0, 3] = sx + affine[1, 0] = cy * dr + affine[1, 1] = ry * dc + affine[1, 3] = sy + affine[2, 0] = cz * dr + affine[2, 1] = rz * dc + affine[2, 2] = 1.0 + affine[2, 3] = sz + + # 3d + if "lastImagePositionPatient" in metadata: + t1n, t2n, t3n = metadata["lastImagePositionPatient"] + n = metadata[MetaKeys.SPATIAL_SHAPE][-1] + k1, k2, k3 = (t1n - sx) / (n - 1), (t2n - sy) / (n - 1), (t3n - sz) / (n - 1) + affine[0, 2] = k1 + affine[1, 2] = k2 + affine[2, 2] = k3 + + if lps_to_ras: + affine = orientation_ras_lps(affine) + return affine + + def _get_frame_data(self, img) -> Iterator: + """ + yield frames and description from the segmentation image. + This function is adapted from Highdicom: + https://github.com/herrmannlab/highdicom/blob/v0.18.2/src/highdicom/seg/utils.py + + which has the following license... + + # ========================================================================= + # https://github.com/herrmannlab/highdicom/blob/v0.18.2/LICENSE + # + # Copyright 2020 MGH Computational Pathology + # Permission is hereby granted, free of charge, to any person obtaining a + # copy of this software and associated documentation files (the + # "Software"), to deal in the Software without restriction, including + # without limitation the rights to use, copy, modify, merge, publish, + # distribute, sublicense, and/or sell copies of the Software, and to + # permit persons to whom the Software is furnished to do so, subject to + # the following conditions: + # The above copyright notice and this permission notice shall be included + # in all copies or substantial portions of the Software. + # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + # CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + # ========================================================================= + + (https://github.com/herrmannlab/highdicom/issues/188) + + Args: + img: a Pydicom dataset object that has attribute "SegmentSequence". + + """ + + if not hasattr(img, "PerFrameFunctionalGroupsSequence"): + raise NotImplementedError( + f"To read dicom seg: {img.filename}, 'PerFrameFunctionalGroupsSequence' is required." + ) + + frame_seg_nums = [] + for f in img.PerFrameFunctionalGroupsSequence: + if not hasattr(f, "SegmentIdentificationSequence"): + raise NotImplementedError( + f"To read dicom seg: {img.filename}, 'SegmentIdentificationSequence' is required for each frame." 
+ ) + frame_seg_nums.append(int(f.SegmentIdentificationSequence[0].ReferencedSegmentNumber)) + + frame_seg_nums_arr = np.array(frame_seg_nums) + + seg_descriptions = {int(f.SegmentNumber): f for f in img.SegmentSequence} + + for i in np.unique(frame_seg_nums_arr): + indices = np.where(frame_seg_nums_arr == i)[0] + yield (img.pixel_array[indices, ...], seg_descriptions[i]) + + def _get_seg_data(self, img): + """ + Get the array data and metadata of the segmentation image. + + Aegs: + img: a Pydicom dataset object that has attribute "SegmentSequence". + + """ + + metadata = self._get_meta_dict(img) + n_classes = len(img.SegmentSequence) + spatial_shape = list(img.pixel_array.shape) + spatial_shape[0] = spatial_shape[0] // n_classes + + if self.label_dict is not None: + metadata["labels"] = self.label_dict + all_segs = np.zeros([*spatial_shape, len(self.label_dict)]) + else: + metadata["labels"] = {} + all_segs = np.zeros([*spatial_shape, n_classes]) + + for i, (frames, description) in enumerate(self._get_frame_data(img)): + segment_label = getattr(description, "SegmentLabel", f"label_{i}") + class_name = getattr(description, "SegmentDescription", segment_label) + if class_name not in metadata["labels"].keys(): + metadata["labels"][class_name] = i + class_num = metadata["labels"][class_name] + all_segs[..., class_num] = frames + + all_segs = all_segs.transpose([1, 2, 0, 3]) + metadata[MetaKeys.SPATIAL_SHAPE] = all_segs.shape[:-1] + + if "52009229" in metadata.keys(): + shared_func_group_seq = metadata["52009229"]["Value"][0] + + # get `ImageOrientationPatient` + if "00209116" in shared_func_group_seq.keys(): + plane_orient_seq = shared_func_group_seq["00209116"]["Value"][0] + if "00200037" in plane_orient_seq.keys(): + metadata["00200037"] = plane_orient_seq["00200037"] + + # get `PixelSpacing` + if "00289110" in shared_func_group_seq.keys(): + pixel_measure_seq = shared_func_group_seq["00289110"]["Value"][0] + + if "00280030" in pixel_measure_seq.keys(): + pixel_spacing = pixel_measure_seq["00280030"]["Value"] + metadata["spacing"] = pixel_spacing + if "00180050" in pixel_measure_seq.keys(): + metadata["spacing"] += pixel_measure_seq["00180050"]["Value"] + + if self.prune_metadata: + metadata.pop("52009229") + + # get `ImagePositionPatient` + if "52009230" in metadata.keys(): + first_frame_func_group_seq = metadata["52009230"]["Value"][0] + if "00209113" in first_frame_func_group_seq.keys(): + plane_position_seq = first_frame_func_group_seq["00209113"]["Value"][0] + if "00200032" in plane_position_seq.keys(): + metadata["00200032"] = plane_position_seq["00200032"] + metadata["lastImagePositionPatient"] = metadata["52009230"]["Value"][-1]["00209113"]["Value"][0][ + "00200032" + ]["Value"] + if self.prune_metadata: + metadata.pop("52009230") + + return all_segs, metadata + + def _get_array_data(self, img): + """ + Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data + will be rescaled. The output data has the dtype np.float32 if the rescaling is applied. + + Args: + img: a Pydicom dataset object. 
+ + """ + # process Dicom series + if not hasattr(img, "pixel_array"): + raise ValueError(f"dicom data: {img.filename} does not have pixel_array.") + data = img.pixel_array + + slope, offset = 1.0, 0.0 + rescale_flag = False + if hasattr(img, "RescaleSlope"): + slope = img.RescaleSlope + rescale_flag = True + if hasattr(img, "RescaleIntercept"): + offset = img.RescaleIntercept + rescale_flag = True + if rescale_flag: + data = data.astype(np.float32) * slope + offset + + return data + + +@require_pkg(pkg_name="nibabel") +class NibabelReader(ImageReader): + """ + Load NIfTI format images based on Nibabel library. + + Args: + as_closest_canonical: if True, load the image as closest to canonical axis format. + squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3) + channel_dim: the channel dimension of the input image, default is None. + this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. + if None, `original_channel_dim` will be either `no_channel` or `-1`. + most Nifti files are usually "channel last", no need to specify this argument for them. + kwargs: additional args for `nibabel.load` API. more details about available args: + https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py + + """ + + def __init__( + self, + channel_dim: str | int | None = None, + as_closest_canonical: bool = False, + squeeze_non_spatial_dims: bool = False, + **kwargs, + ): + super().__init__() + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim + self.as_closest_canonical = as_closest_canonical + self.squeeze_non_spatial_dims = squeeze_non_spatial_dims + self.kwargs = kwargs + + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: + """ + Verify whether the specified file or files format is supported by Nibabel reader. + + Args: + filename: file name or a list of file names to read. + if a list of files, verify all the suffixes. + + """ + suffixes: Sequence[str] = ["nii", "nii.gz"] + return has_nib and is_supported_format(filename, suffixes) + + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): + """ + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. + Note that the returned object is Nibabel image object or list of Nibabel image objects. + + Args: + data: file name or a list of file names to read. + kwargs: additional args for `nibabel.load` API, will override `self.kwargs` for existing keys. + More details about available args: + https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py + + """ + img_: list[Nifti1Image] = [] + + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + for name in filenames: + img = nib.load(name, **kwargs_) + img = correct_nifti_header_if_necessary(img) + img_.append(img) # type: ignore + return img_ if len(filenames) > 1 else img_[0] + + def get_data(self, img) -> tuple[np.ndarray, dict]: + """ + Extract data array and metadata from loaded image and return them. + This function returns two objects, first is numpy array of image data, second is dict of metadata. + It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. + When loading a list of files, they are stacked together at a new dimension as the first dimension, + and the metadata of the first image is used to present the output metadata. 
+ + Args: + img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. + + """ + img_array: list[np.ndarray] = [] + compatible_meta: dict = {} + + for i in ensure_tuple(img): + header = self._get_meta_dict(i) + header[MetaKeys.AFFINE] = self._get_affine(i) + header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i) + header["as_closest_canonical"] = self.as_closest_canonical + if self.as_closest_canonical: + i = nib.as_closest_canonical(i) + header[MetaKeys.AFFINE] = self._get_affine(i) + header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) + header[MetaKeys.SPACE] = SpaceKeys.RAS + data = self._get_array_data(i) + if self.squeeze_non_spatial_dims: + for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1): + if data.shape[d - 1] == 1: + data = data.squeeze(axis=d - 1) + img_array.append(data) + if self.channel_dim is None: # default to "no_channel" or -1 + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 + ) + else: + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim + _copy_compatible_dict(header, compatible_meta) + + return _stack_images(img_array, compatible_meta), compatible_meta + + def _get_meta_dict(self, img) -> dict: + """ + Get the all the metadata of the image and convert to dict type. + + Args: + img: a Nibabel image object loaded from an image file. + + """ + # swap to little endian as PyTorch doesn't support big endian + try: + header = img.header.as_byteswapped("<") + except ValueError: + header = img.header + return dict(header) + + def _get_affine(self, img): + """ + Get the affine matrix of the image, it can be used to correct + spacing, orientation or execute spatial transforms. + + Args: + img: a Nibabel image object loaded from an image file. + + """ + return np.array(img.affine, copy=True) + + def _get_spatial_shape(self, img): + """ + Get the spatial shape of image data, it doesn't contain the channel dim. + + Args: + img: a Nibabel image object loaded from an image file. + + """ + # swap to little endian as PyTorch doesn't support big endian + try: + header = img.header.as_byteswapped("<") + except ValueError: + header = img.header + dim = header.get("dim", None) + if dim is None: + dim = header.get("dims") # mgh format? + dim = np.insert(dim, 0, 3) + ndim = dim[0] + size = list(dim[1:]) + if not is_no_channel(self.channel_dim): + size.pop(int(self.channel_dim)) # type: ignore + spatial_rank = max(min(ndim, 3), 1) + return np.asarray(size[:spatial_rank]) + + def _get_array_data(self, img): + """ + Get the raw array data of the image, converted to Numpy array. + + Args: + img: a Nibabel image object loaded from an image file. + + """ + return np.asanyarray(img.dataobj, order="C") + + +class NumpyReader(ImageReader): + """ + Load NPY or NPZ format data based on Numpy library, they can be arrays or pickled objects. + A typical usage is to load the `mask` data for classification task. + It can load part of the npz file with specified `npz_keys`. + + Args: + npz_keys: if loading npz file, only load the specified keys, if None, load all the items. + stack the loaded items together to construct a new first dimension. + channel_dim: if not None, explicitly specify the channel dim, otherwise, treat the array as no channel. + kwargs: additional args for `numpy.load` API except `allow_pickle`. 
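A small usage sketch for the `NibabelReader` API described above, assuming MONAI and nibabel are installed; the file name "example.nii.gz" is made up and the file is written on the fly so the snippet is self-contained.

import numpy as np
import nibabel as nib
from monai.data.image_reader import NibabelReader

# create a tiny NIfTI volume to read back
nib.save(nib.Nifti1Image(np.zeros((4, 4, 4), dtype=np.float32), np.eye(4)), "example.nii.gz")

reader = NibabelReader()
img_obj = reader.read("example.nii.gz")    # a Nifti1Image (or a list when reading several files)
data, meta = reader.get_data(img_obj)      # numpy array plus metadata dict
print(data.shape, meta["spatial_shape"], meta["affine"].shape)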
more details about available args: + https://numpy.org/doc/stable/reference/generated/numpy.load.html + + """ + + def __init__(self, npz_keys: KeysCollection | None = None, channel_dim: str | int | None = None, **kwargs): + super().__init__() + if npz_keys is not None: + npz_keys = ensure_tuple(npz_keys) + self.npz_keys = npz_keys + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim + self.kwargs = kwargs + + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: + """ + Verify whether the specified file or files format is supported by Numpy reader. + + Args: + filename: file name or a list of file names to read. + if a list of files, verify all the suffixes. + """ + suffixes: Sequence[str] = ["npz", "npy"] + return is_supported_format(filename, suffixes) + + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): + """ + Read image data from specified file or files, it can read a list of data files + and stack them together as multi-channel data in `get_data()`. + Note that the returned object is Numpy array or list of Numpy arrays. + + Args: + data: file name or a list of file names to read. + kwargs: additional args for `numpy.load` API except `allow_pickle`, will override `self.kwargs` for existing keys. + More details about available args: + https://numpy.org/doc/stable/reference/generated/numpy.load.html + + """ + img_: list[Nifti1Image] = [] + + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + for name in filenames: + img = np.load(name, allow_pickle=True, **kwargs_) + if Path(name).name.endswith(".npz"): + # load expected items from NPZ file + npz_keys = list(img.keys()) if self.npz_keys is None else self.npz_keys + for k in npz_keys: + img_.append(img[k]) + else: + img_.append(img) + + return img_ if len(img_) > 1 else img_[0] + + def get_data(self, img) -> tuple[np.ndarray, dict]: + """ + Extract data array and metadata from loaded image and return them. + This function returns two objects, first is numpy array of image data, second is dict of metadata. + It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. + When loading a list of files, they are stacked together at a new dimension as the first dimension, + and the metadata of the first image is used to represent the output metadata. + + Args: + img: a Numpy array loaded from a file or a list of Numpy arrays. + + """ + img_array: list[np.ndarray] = [] + compatible_meta: dict = {} + if isinstance(img, np.ndarray): + img = (img,) + + for i in ensure_tuple(img): + header: dict[MetaKeys, Any] = {} + if isinstance(i, np.ndarray): + # if `channel_dim` is None, can not detect the channel dim, use all the dims as spatial_shape + spatial_shape = np.asarray(i.shape) + if isinstance(self.channel_dim, int): + spatial_shape = np.delete(spatial_shape, self.channel_dim) + header[MetaKeys.SPATIAL_SHAPE] = spatial_shape + header[MetaKeys.SPACE] = SpaceKeys.RAS + img_array.append(i) + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + self.channel_dim if isinstance(self.channel_dim, int) else float("nan") + ) + _copy_compatible_dict(header, compatible_meta) + + return _stack_images(img_array, compatible_meta), compatible_meta + + +@require_pkg(pkg_name="PIL") +class PILReader(ImageReader): + """ + Load common 2D image format (supports PNG, JPG, BMP) file or files from provided path. + + Args: + converter: additional function to convert the image data after `read()`. 
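A usage sketch for `NumpyReader` with `npz_keys`, as described above (file name and keys are illustrative): the selected arrays are read as a list and stacked at a new first dimension by `get_data()`.

import numpy as np
from monai.data.image_reader import NumpyReader

np.savez("example.npz", image=np.zeros((8, 8)), mask=np.ones((8, 8)))

reader = NumpyReader(npz_keys=["image", "mask"])
arrays = reader.read("example.npz")        # list with the two selected arrays
data, meta = reader.get_data(arrays)       # stacked: shape (2, 8, 8)
print(data.shape, meta["spatial_shape"])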
+ for example, use `converter=lambda image: image.convert("LA")` to convert image format. + reverse_indexing: whether to swap axis 0 and 1 after loading the array, this is enabled by default, + so that output of the reader is consistent with the other readers. Set this option to ``False`` to use + the PIL backend's original spatial axes convention. + kwargs: additional args for `Image.open` API in `read()`, mode details about available args: + https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open + """ + + def __init__(self, converter: Callable | None = None, reverse_indexing: bool = True, **kwargs): + super().__init__() + self.converter = converter + self.reverse_indexing = reverse_indexing + self.kwargs = kwargs + + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: + """ + Verify whether the specified file or files format is supported by PIL reader. + + Args: + filename: file name or a list of file names to read. + if a list of files, verify all the suffixes. + """ + suffixes: Sequence[str] = ["png", "jpg", "jpeg", "bmp"] + return has_pil and is_supported_format(filename, suffixes) + + def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs): + """ + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. + Note that the returned object is PIL image or list of PIL image. + + Args: + data: file name or a list of file names to read. + kwargs: additional args for `Image.open` API in `read()`, will override `self.kwargs` for existing keys. + Mode details about available args: + https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open + + """ + img_: list[PILImage.Image] = [] + + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + for name in filenames: + img = PILImage.open(name, **kwargs_) + if callable(self.converter): + img = self.converter(img) + img_.append(img) + + return img_ if len(filenames) > 1 else img_[0] + + def get_data(self, img) -> tuple[np.ndarray, dict]: + """ + Extract data array and metadata from loaded image and return them. + This function returns two objects, first is numpy array of image data, second is dict of metadata. + It computes `spatial_shape` and stores it in meta dict. + When loading a list of files, they are stacked together at a new dimension as the first dimension, + and the metadata of the first image is used to represent the output metadata. + Note that by default `self.reverse_indexing` is set to ``True``, which swaps axis 0 and 1 after loading + the array because the spatial axes definition in PIL is different from other common medical packages. + + Args: + img: a PIL Image object loaded from a file or a list of PIL Image objects. + + """ + img_array: list[np.ndarray] = [] + compatible_meta: dict = {} + + for i in ensure_tuple(img): + header = self._get_meta_dict(i) + header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) + data = np.moveaxis(np.asarray(i), 0, 1) if self.reverse_indexing else np.asarray(i) + img_array.append(data) + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 + ) + _copy_compatible_dict(header, compatible_meta) + + return _stack_images(img_array, compatible_meta), compatible_meta + + def _get_meta_dict(self, img) -> dict: + """ + Get the all the metadata of the image and convert to dict type. 
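A usage sketch of the `PILReader` axis convention described above: with the default `reverse_indexing=True` the loaded array is transposed so that the output is indexed as (width, height), matching the reported `spatial_shape`. The PNG file below is generated on the spot and purely illustrative.

import numpy as np
from PIL import Image
from monai.data.image_reader import PILReader

Image.fromarray(np.zeros((32, 64), dtype=np.uint8)).save("example.png")  # height=32, width=64

reader = PILReader()                       # reverse_indexing=True by default
data, meta = reader.get_data(reader.read("example.png"))
print(data.shape, meta["spatial_shape"])   # (64, 32) and [64 32]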
+ Args: + img: a PIL Image object loaded from an image file. + + """ + return {"format": img.format, "mode": img.mode, "width": img.width, "height": img.height} + + def _get_spatial_shape(self, img): + """ + Get the spatial shape of image data, it doesn't contain the channel dim. + Args: + img: a PIL Image object loaded from an image file. + """ + return np.asarray((img.width, img.height)) + + +@dataclass +class NrrdImage: + """Class to wrap nrrd image array and metadata header""" + + array: np.ndarray + header: dict + + +@require_pkg(pkg_name="nrrd") +class NrrdReader(ImageReader): + """ + Load NRRD format images based on pynrrd library. + + Args: + channel_dim: the channel dimension of the input image, default is None. + This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. + If None, `original_channel_dim` will be either `no_channel` or `0`. + NRRD files are usually "channel first". + dtype: dtype of the data array when loading image. + index_order: Specify whether the returned data array should be in C-order (‘C’) or Fortran-order (‘F’). + Numpy is usually in C-order, but default on the NRRD header is F + affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``. + Set to ``True`` to be consistent with ``NibabelReader``, otherwise the affine matrix is unmodified. + + kwargs: additional args for `nrrd.read` API. more details about available args: + https://github.com/mhe/pynrrd/blob/master/nrrd/reader.py + + """ + + def __init__( + self, + channel_dim: str | int | None = None, + dtype: np.dtype | type | str | None = np.float32, + index_order: str = "F", + affine_lps_to_ras: bool = True, + **kwargs, + ): + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim + self.dtype = dtype + self.index_order = index_order + self.affine_lps_to_ras = affine_lps_to_ras + self.kwargs = kwargs + + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: + """ + Verify whether the specified `filename` is supported by pynrrd reader. + + Args: + filename: file name or a list of file names to read. + if a list of files, verify all the suffixes. + + """ + suffixes: Sequence[str] = ["nrrd", "seg.nrrd"] + return has_nrrd and is_supported_format(filename, suffixes) + + def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | Any: + """ + Read image data from specified file or files. + Note that it returns a data object or a sequence of data objects. + + Args: + data: file name or a list of file names to read. + kwargs: additional args for actual `read` API of 3rd party libs. + + """ + img_: list = [] + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + for name in filenames: + nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, **kwargs_)) + img_.append(nrrd_image) + return img_ if len(filenames) > 1 else img_[0] + + def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: + """ + Extract data array and metadata from loaded image and return them. + This function must return two objects, the first is a numpy array of image data, + the second is a dictionary of metadata. + + Args: + img: a `NrrdImage` loaded from an image file or a list of image objects. 
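A usage sketch for `NrrdReader` as documented above, assuming MONAI and pynrrd are available; the header values and file name are illustrative, and the file is written first so the snippet runs on its own.

import nrrd
import numpy as np
from monai.data.image_reader import NrrdReader

header = {
    "space": "left-posterior-superior",
    "space directions": np.eye(3),
    "space origin": np.zeros(3),
}
nrrd.write("example.nrrd", np.zeros((16, 16, 8), dtype=np.float32), header)

reader = NrrdReader(index_order="F")       # keep pynrrd's native Fortran index order
data, meta = reader.get_data(reader.read("example.nrrd"))
print(data.shape, meta["spatial_shape"], meta["space"])  # space is reported as RAS after conversion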
+ + """ + img_array: list[np.ndarray] = [] + compatible_meta: dict = {} + + for i in ensure_tuple(img): + data = i.array.astype(self.dtype) + img_array.append(data) + header = dict(i.header) + if self.index_order == "C": + header = self._convert_f_to_c_order(header) + header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(header) + + if self.affine_lps_to_ras: + header = self._switch_lps_ras(header) + if header.get(MetaKeys.SPACE, "left-posterior-superior") == "left-posterior-superior": + header[MetaKeys.SPACE] = SpaceKeys.LPS # assuming LPS if not specified + + header[MetaKeys.AFFINE] = header[MetaKeys.ORIGINAL_AFFINE].copy() + header[MetaKeys.SPATIAL_SHAPE] = header["sizes"] + [header.pop(k) for k in ("sizes", "space origin", "space directions")] # rm duplicated data in header + + if self.channel_dim is None: # default to "no_channel" or -1 + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else 0 + ) + else: + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim + _copy_compatible_dict(header, compatible_meta) + + return _stack_images(img_array, compatible_meta), compatible_meta + + def _get_affine(self, header: dict) -> np.ndarray: + """ + Get the affine matrix of the image, it can be used to correct + spacing, orientation or execute spatial transforms. + + Args: + img: A `NrrdImage` loaded from image file + + """ + direction = header["space directions"] + origin = header["space origin"] + + x, y = direction.shape + affine_diam = min(x, y) + 1 + affine: np.ndarray = np.eye(affine_diam) + affine[:x, :y] = direction + affine[: (affine_diam - 1), -1] = origin # len origin is always affine_diam - 1 + return affine + + def _switch_lps_ras(self, header: dict) -> dict: + """ + For compatibility with nibabel, switch from LPS to RAS. Adapt affine matrix and + `space` argument in header accordingly. If no information of space is given in the header, + LPS is assumed and thus converted to RAS. If information about space is given, + but is not LPS, the unchanged header is returned. + + Args: + header: The image metadata as dict + + """ + if "space" not in header or header["space"] == "left-posterior-superior": + header[MetaKeys.ORIGINAL_AFFINE] = orientation_ras_lps(header[MetaKeys.ORIGINAL_AFFINE]) + header[MetaKeys.SPACE] = SpaceKeys.RAS + return header + + def _convert_f_to_c_order(self, header: dict) -> dict: + """ + All header fields of a NRRD are specified in `F` (Fortran) order, even if the image was read as C-ordered array. 
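A worked example (editor's sketch, illustrative numbers) of the affine construction in `NrrdReader._get_affine` and the LPS-to-RAS switch above: the 'space directions' block and 'space origin' are copied into a 4x4 homogeneous matrix, and the conversion flips the sign of the first two rows.

import numpy as np

direction = np.diag([1.0, 1.0, 2.5])       # 'space directions' (hypothetical spacings on the diagonal)
origin = np.array([-90.0, -126.0, 72.0])   # 'space origin' (hypothetical)

affine = np.eye(4)
affine[:3, :3] = direction
affine[:3, -1] = origin

# LPS -> RAS: negate the first two rows, as orientation_ras_lps does for a 3D affine
affine_ras = np.diag([-1.0, -1.0, 1.0, 1.0]) @ affine
print(affine_ras)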
+ 1D arrays of header['space origin'] and header['sizes'] become inverted, e.g, [1,2,3] -> [3,2,1] + The 2D Array for header['space directions'] is transposed: [[1,0,0],[0,2,0],[0,0,3]] -> [[3,0,0],[0,2,0],[0,0,1]] + For more details refer to: https://pynrrd.readthedocs.io/en/latest/user-guide.html#index-ordering + + Args: + header: The image metadata as dict + + """ + + header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) + header["space origin"] = header["space origin"][::-1] + header["sizes"] = header["sizes"][::-1] + return header \ No newline at end of file From e9f7565c4adc62b239e01d53c93224d1c04d2085 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 May 2024 09:37:14 +0000 Subject: [PATCH 03/49] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/image_reader.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 257bebc831..e7240e6b96 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -147,8 +147,8 @@ def _stack_images(image_list: list, meta_dict: dict): # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 return np.stack(image_list, axis=0) - - + + def update_json(input_file=None, output_file=None): record_path = "img-label.json" @@ -1433,28 +1433,13 @@ def _convert_f_to_c_order(self, header: dict) -> dict: from __future__ import annotations -import glob -import os -import re -import warnings from abc import ABC, abstractmethod -from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import dataclass -from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import numpy as np -from torch.utils.data._utils.collate import np_str_obj_array_pattern -from monai.config import KeysCollection, PathLike -from monai.data.utils import ( - affine_to_spacing, - correct_nifti_header_if_necessary, - is_no_channel, - is_supported_format, - orientation_ras_lps, -) -from monai.utils import MetaKeys, SpaceKeys, TraceKeys, ensure_tuple, optional_import, require_pkg +from monai.utils import optional_import, require_pkg if TYPE_CHECKING: import itk @@ -2816,4 +2801,4 @@ def _convert_f_to_c_order(self, header: dict) -> dict: header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) header["space origin"] = header["space origin"][::-1] header["sizes"] = header["sizes"][::-1] - return header \ No newline at end of file + return header From 7969d21613087a3432de6782e7590182e9a06614 Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 14 May 2024 17:40:15 +0800 Subject: [PATCH 04/49] Fixes #7557 Remove changes unrelated to this issue. 
Signed-off-by: staydelight --- monai/data/image_writer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 06209c664a..4b7d95e71a 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -297,7 +297,7 @@ def resample_if_needed( # convert back at the end if isinstance(output_array, MetaTensor): output_array.applied_operations = [] - data_array, *_ = convert_data_type(output_array, output_type=orig_type) # type: ignore + data_array, *_ = convert_data_type(output_array, output_type=orig_type) affine, *_ = convert_data_type(output_array.affine, output_type=orig_type) # type: ignore return data_array[0], affine @@ -483,7 +483,6 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs): - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L809 """ - logger.info(f"ITKWriter is processing the file: {filename}") super().write(filename, verbose=verbose) super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( @@ -648,7 +647,6 @@ def write(self, filename: PathLike, verbose: bool = False, **obj_kwargs): - https://nipy.org/nibabel/reference/nibabel.nifti1.html#nibabel.nifti1.save """ - logger.info(f"NibabelWriter is processing the file: {filename}") super().write(filename, verbose=verbose) super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( From 3ce5f30f5034bd9f954676a2a62da56ffc17664d Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 14 May 2024 17:58:19 +0800 Subject: [PATCH 05/49] Fixes #7557 Remove changes unrelated to this issue. Signed-off-by: staydelight --- monai/data/image_reader.py | 1383 +----------------------------------- 1 file changed, 1 insertion(+), 1382 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index e7240e6b96..d11140c110 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1420,1385 +1420,4 @@ def _convert_f_to_c_order(self, header: dict) -> dict: header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) header["space origin"] = header["space origin"][::-1] header["sizes"] = header["sizes"][::-1] - return header# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
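For reference, a sketch (editor's note, paths invented) of the record layout that the `update_json` calls above build up in "img-label.json": the reader side appends an "image" entry, and the writer side appends the saved path to that entry's "label" list.

import json

records = []                                                   # contents of img-label.json
records.append({"image": ["case_001.nii.gz"], "label": []})    # reader: update_json(input_file=...)
records[-1]["label"].append("case_001_seg.nii.gz")             # writer: update_json(output_file=...)
print(json.dumps(records, indent=4))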
- -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import TYPE_CHECKING - -import numpy as np - -from monai.utils import optional_import, require_pkg - -if TYPE_CHECKING: - import itk - import nibabel as nib - import nrrd - import pydicom - from nibabel.nifti1 import Nifti1Image - from PIL import Image as PILImage - - has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True -else: - itk, has_itk = optional_import("itk", allow_namespace_pkg=True) - nib, has_nib = optional_import("nibabel") - Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image") - PILImage, has_pil = optional_import("PIL.Image") - pydicom, has_pydicom = optional_import("pydicom") - nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) - -__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] - - -class ImageReader(ABC): - """ - An abstract class defines APIs to load image files. - - Typical usage of an implementation of this class is: - - .. code-block:: python - - image_reader = MyImageReader() - img_obj = image_reader.read(path_to_image) - img_data, meta_data = image_reader.get_data(img_obj) - - - The `read` call converts image filenames into image objects, - - The `get_data` call fetches the image data, as well as metadata. - - A reader should implement `verify_suffix` with the logic of checking the input filename - by the filename extensions. - - """ - - @abstractmethod - def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: - """ - Verify whether the specified `filename` is supported by the current reader. - This method should return True if the reader is able to read the format suggested by the - `filename`. - - Args: - filename: file name or a list of file names to read. - if a list of files, verify all the suffixes. - - """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - - @abstractmethod - def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | Any: - """ - Read image data from specified file or files. - Note that it returns a data object or a sequence of data objects. - - Args: - data: file name or a list of file names to read. - kwargs: additional args for actual `read` API of 3rd party libs. - - """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - - @abstractmethod - def get_data(self, img) -> tuple[np.ndarray, dict]: - """ - Extract data array and metadata from loaded image and return them. - This function must return two objects, the first is a numpy array of image data, - the second is a dictionary of metadata. - - Args: - img: an image object loaded from an image file or a list of image objects. 
- - """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - - -def _copy_compatible_dict(from_dict: dict, to_dict: dict): - if not isinstance(to_dict, dict): - raise ValueError(f"to_dict must be a Dict, got {type(to_dict)}.") - if not to_dict: - for key in from_dict: - datum = from_dict[key] - if isinstance(datum, np.ndarray) and np_str_obj_array_pattern.search(datum.dtype.str) is not None: - continue - to_dict[key] = str(TraceKeys.NONE) if datum is None else datum # NoneType to string for default_collate - else: - affine_key, shape_key = MetaKeys.AFFINE, MetaKeys.SPATIAL_SHAPE - if affine_key in from_dict and not np.allclose(from_dict[affine_key], to_dict[affine_key]): - raise RuntimeError( - "affine matrix of all images should be the same for channel-wise concatenation. " - f"Got {from_dict[affine_key]} and {to_dict[affine_key]}." - ) - if shape_key in from_dict and not np.allclose(from_dict[shape_key], to_dict[shape_key]): - raise RuntimeError( - "spatial_shape of all images should be the same for channel-wise concatenation. " - f"Got {from_dict[shape_key]} and {to_dict[shape_key]}." - ) - - -def _stack_images(image_list: list, meta_dict: dict): - if len(image_list) <= 1: - return image_list[0] - if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): - channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) - return np.concatenate(image_list, axis=channel_dim) - # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified - meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 - return np.stack(image_list, axis=0) - - -@require_pkg(pkg_name="itk") -class ITKReader(ImageReader): - """ - Load medical images based on ITK library. - All the supported image formats can be found at: - https://github.com/InsightSoftwareConsortium/ITK/tree/master/Modules/IO - The loaded data array will be in C order, for example, a 3D image NumPy - array index order will be `CDWH`. - - Args: - channel_dim: the channel dimension of the input image, default is None. - This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. - If None, `original_channel_dim` will be either `no_channel` or `-1`. - - - Nifti file is usually "channel last", so there is no need to specify this argument. - - PNG file usually has `GetNumberOfComponentsPerPixel()==3`, so there is no need to specify this argument. - - series_name: the name of the DICOM series if there are multiple ones. - used when loading DICOM series. - reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array. - If ``False``, the spatial indexing convention is reversed to be compatible with ITK; - otherwise, the spatial indexing follows the numpy convention. Default is ``False``. - This option does not affect the metadata. - series_meta: whether to load the metadata of the DICOM series (using the metadata from the first slice). - This flag is checked only when loading DICOM series. Default is ``False``. - affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``. - Set to ``True`` to be consistent with ``NibabelReader``, otherwise the affine matrix remains in the ITK convention. - kwargs: additional args for `itk.imread` API. 
more details about available args: - https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py - - """ - - def __init__( - self, - channel_dim: str | int | None = None, - series_name: str = "", - reverse_indexing: bool = False, - series_meta: bool = False, - affine_lps_to_ras: bool = True, - **kwargs, - ): - super().__init__() - self.kwargs = kwargs - self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim - self.series_name = series_name - self.reverse_indexing = reverse_indexing - self.series_meta = series_meta - self.affine_lps_to_ras = affine_lps_to_ras - - def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: - """ - Verify whether the specified file or files format is supported by ITK reader. - - Args: - filename: file name or a list of file names to read. - if a list of files, verify all the suffixes. - - """ - return has_itk - - def read(self, data: Sequence[PathLike] | PathLike, **kwargs): - """ - Read image data from specified file or files, it can read a list of images - and stack them together as multi-channel data in `get_data()`. - If passing directory path instead of file path, will treat it as DICOM images series and read. - Note that the returned object is ITK image object or list of ITK image objects. - - Args: - data: file name or a list of file names to read, - kwargs: additional args for `itk.imread` API, will override `self.kwargs` for existing keys. - More details about available args: - https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py - - """ - img_ = [] - - filenames: Sequence[PathLike] = ensure_tuple(data) - kwargs_ = self.kwargs.copy() - kwargs_.update(kwargs) - for name in filenames: - name = f"{name}" - if Path(name).is_dir(): - # read DICOM series - # https://examples.itk.org/src/io/gdcm/readdicomseriesandwrite3dimage/documentation - names_generator = itk.GDCMSeriesFileNames.New() - names_generator.SetUseSeriesDetails(True) - names_generator.AddSeriesRestriction("0008|0021") # Series Date - names_generator.SetDirectory(name) - series_uid = names_generator.GetSeriesUIDs() - - if len(series_uid) < 1: - raise FileNotFoundError(f"no DICOMs in: {name}.") - if len(series_uid) > 1: - warnings.warn(f"the directory: {name} contains more than one DICOM series.") - series_identifier = series_uid[0] if not self.series_name else self.series_name - name = names_generator.GetFileNames(series_identifier) - - name = name[0] if len(name) == 1 else name # type: ignore - _obj = itk.imread(name, **kwargs_) - if self.series_meta: - _reader = itk.ImageSeriesReader.New(FileNames=name) - _reader.Update() - _meta = _reader.GetMetaDataDictionaryArray() - if len(_meta) > 0: - # TODO: using the first slice's meta. this could be improved to filter unnecessary tags. - _obj.SetMetaDataDictionary(_meta[0]) - img_.append(_obj) - else: - img_.append(itk.imread(name, **kwargs_)) - return img_ if len(filenames) > 1 else img_[0] - - def get_data(self, img) -> tuple[np.ndarray, dict]: - """ - Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. - It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. - When loading a list of files, they are stacked together at a new dimension as the first dimension, - and the metadata of the first image is used to represent the output metadata. 
- - Args: - img: an ITK image object loaded from an image file or a list of ITK image objects. - - """ - img_array: list[np.ndarray] = [] - compatible_meta: dict = {} - - for i in ensure_tuple(img): - data = self._get_array_data(i) - img_array.append(data) - header = self._get_meta_dict(i) - header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i, self.affine_lps_to_ras) - header[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS - header[MetaKeys.AFFINE] = header[MetaKeys.ORIGINAL_AFFINE].copy() - header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) - if self.channel_dim is None: # default to "no_channel" or -1 - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 - ) - else: - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim - _copy_compatible_dict(header, compatible_meta) - - return _stack_images(img_array, compatible_meta), compatible_meta - - def _get_meta_dict(self, img) -> dict: - """ - Get all the metadata of the image and convert to dict type. - - Args: - img: an ITK image object loaded from an image file. - - """ - img_meta_dict = img.GetMetaDataDictionary() - meta_dict = {} - for key in img_meta_dict.GetKeys(): - if key.startswith("ITK_"): - continue - val = img_meta_dict[key] - meta_dict[key] = np.asarray(val) if type(val).__name__.startswith("itk") else val - - meta_dict["spacing"] = np.asarray(img.GetSpacing()) - return meta_dict - - def _get_affine(self, img, lps_to_ras: bool = True): - """ - Get or construct the affine matrix of the image, it can be used to correct - spacing, orientation or execute spatial transforms. - - Args: - img: an ITK image object loaded from an image file. - lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to True. - - """ - direction = itk.array_from_matrix(img.GetDirection()) - spacing = np.asarray(img.GetSpacing()) - origin = np.asarray(img.GetOrigin()) - - direction = np.asarray(direction) - sr = min(max(direction.shape[0], 1), 3) - affine: np.ndarray = np.eye(sr + 1) - affine[:sr, :sr] = direction[:sr, :sr] @ np.diag(spacing[:sr]) - affine[:sr, -1] = origin[:sr] - if lps_to_ras: - affine = orientation_ras_lps(affine) - return affine - - def _get_spatial_shape(self, img): - """ - Get the spatial shape of `img`. - - Args: - img: an ITK image object loaded from an image file. - - """ - sr = itk.array_from_matrix(img.GetDirection()).shape[0] - sr = max(min(sr, 3), 1) - _size = list(itk.size(img)) - if isinstance(self.channel_dim, int): - _size.pop(self.channel_dim) - return np.asarray(_size[:sr]) - - def _get_array_data(self, img): - """ - Get the raw array data of the image, converted to Numpy array. - - Following PyTorch conventions, the returned array data has contiguous channels, - e.g. for an RGB image, all red channel image pixels are contiguous in memory. - The last axis of the returned array is the channel axis. - - See also: - - - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Modules/Bridge/NumPy/wrapping/PyBuffer.i.in - - Args: - img: an ITK image object loaded from an image file. 
- - """ - np_img = itk.array_view_from_image(img, keep_axes=False) - if img.GetNumberOfComponentsPerPixel() == 1: # handling spatial images - return np_img if self.reverse_indexing else np_img.T - # handling multi-channel images - return np_img if self.reverse_indexing else np.moveaxis(np_img.T, 0, -1) - - -@require_pkg(pkg_name="pydicom") -class PydicomReader(ImageReader): - """ - Load medical images based on Pydicom library. - All the supported image formats can be found at: - https://dicom.nema.org/medical/dicom/current/output/chtml/part10/chapter_7.html - - PydicomReader is also able to load segmentations, if a dicom file contains tag: `SegmentSequence`, the reader - will consider it as segmentation data, and to load it successfully, `PerFrameFunctionalGroupsSequence` is required - for dicom file, and for each frame of dicom file, `SegmentIdentificationSequence` is required. - This method refers to the Highdicom library. - - This class refers to: - https://nipy.org/nibabel/dicom/dicom_orientation.html#dicom-affine-formula - https://github.com/pydicom/contrib-pydicom/blob/master/input-output/pydicom_series.py - https://highdicom.readthedocs.io/en/latest/usage.html#parsing-segmentation-seg-images - - Args: - channel_dim: the channel dimension of the input image, default is None. - This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. - If None, `original_channel_dim` will be either `no_channel` or `-1`. - affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``. - Set to ``True`` to be consistent with ``NibabelReader``, - otherwise the affine matrix remains in the Dicom convention. - swap_ij: whether to swap the first two spatial axes. Default to ``True``, so that the outputs - are consistent with the other readers. - prune_metadata: whether to prune the saved information in metadata. This argument is used for - `get_data` function. If True, only items that are related to the affine matrix will be saved. - Default to ``True``. - label_dict: label of the dicom data. If provided, it will be used when loading segmentation data. - Keys of the dict are the classes, and values are the corresponding class number. For example: - for TCIA collection "C4KC-KiTS", it can be: {"Kidney": 0, "Renal Tumor": 1}. - fname_regex: a regular expression to match the file names when the input is a folder. - If provided, only the matched files will be included. For example, to include the file name - "image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`. - Set it to `None` to use `pydicom.misc.is_dicom` to match valid files. - kwargs: additional args for `pydicom.dcmread` API. more details about available args: - https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html - If the `get_data` function will be called - (for example, when using this reader with `monai.transforms.LoadImage`), please ensure that the argument - `stop_before_pixels` is `True`, and `specific_tags` covers all necessary tags, such as `PixelSpacing`, - `ImagePositionPatient`, `ImageOrientationPatient` and all `pixel_array` related tags. 
- """ - - def __init__( - self, - channel_dim: str | int | None = None, - affine_lps_to_ras: bool = True, - swap_ij: bool = True, - prune_metadata: bool = True, - label_dict: dict | None = None, - fname_regex: str = "", - **kwargs, - ): - super().__init__() - self.kwargs = kwargs - self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim - self.affine_lps_to_ras = affine_lps_to_ras - self.swap_ij = swap_ij - self.prune_metadata = prune_metadata - self.label_dict = label_dict - self.fname_regex = fname_regex - - def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: - """ - Verify whether the specified file or files format is supported by Pydicom reader. - - Args: - filename: file name or a list of file names to read. - if a list of files, verify all the suffixes. - - """ - return has_pydicom - - def read(self, data: Sequence[PathLike] | PathLike, **kwargs): - """ - Read image data from specified file or files, it can read a list of images - and stack them together as multi-channel data in `get_data()`. - If passing directory path instead of file path, will treat it as DICOM images series and read. - - Args: - data: file name or a list of file names to read, - kwargs: additional args for `pydicom.dcmread` API, will override `self.kwargs` for existing keys. - - Returns: - If `data` represents a filename: return a pydicom dataset object. - If `data` represents a list of filenames or a directory: return a list of pydicom dataset object. - If `data` represents a list of directories: return a list of list of pydicom dataset object. - - """ - img_ = [] - - filenames: Sequence[PathLike] = ensure_tuple(data) - kwargs_ = self.kwargs.copy() - kwargs_.update(kwargs) - - self.has_series = False - - for name in filenames: - name = f"{name}" - if Path(name).is_dir(): - # read DICOM series - if self.fname_regex is not None: - series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if re.match(self.fname_regex, slc)] - else: - series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)] - slices = [] - for slc in series_slcs: - try: - slices.append(pydicom.dcmread(fp=slc, **kwargs_)) - except pydicom.errors.InvalidDicomError as e: - warnings.warn(f"Failed to read {slc} with exception: \n{e}.", stacklevel=2) - img_.append(slices if len(slices) > 1 else slices[0]) - if len(slices) > 1: - self.has_series = True - else: - ds = pydicom.dcmread(fp=name, **kwargs_) - img_.append(ds) - return img_ if len(filenames) > 1 else img_[0] - - def _combine_dicom_series(self, data: Iterable): - """ - Combine dicom series (a list of pydicom dataset objects). Their data arrays will be stacked together at a new - dimension as the last dimension. - - The stack order depends on Instance Number. The metadata will be based on the - first slice's metadata, and some new items will be added: - - "spacing": the new spacing of the stacked slices. - "lastImagePositionPatient": `ImagePositionPatient` for the last slice, it will be used to achieve the affine - matrix. - "spatial_shape": the spatial shape of the stacked slices. - - Args: - data: a list of pydicom dataset objects. - Returns: - a tuple that consisted with data array and metadata. 
- """ - slices: list = [] - # for a dicom series - for slc_ds in data: - if hasattr(slc_ds, "InstanceNumber"): - slices.append(slc_ds) - else: - warnings.warn(f"slice: {slc_ds.filename} does not have InstanceNumber tag, skip it.") - slices = sorted(slices, key=lambda s: s.InstanceNumber) - - if len(slices) == 0: - raise ValueError("the input does not have valid slices.") - - first_slice = slices[0] - average_distance = 0.0 - first_array = self._get_array_data(first_slice) - shape = first_array.shape - spacing = getattr(first_slice, "PixelSpacing", [1.0, 1.0, 1.0]) - prev_pos = getattr(first_slice, "ImagePositionPatient", (0.0, 0.0, 0.0))[2] - stack_array = [first_array] - for idx in range(1, len(slices)): - slc_array = self._get_array_data(slices[idx]) - slc_shape = slc_array.shape - slc_spacing = getattr(slices[idx], "PixelSpacing", (1.0, 1.0, 1.0)) - slc_pos = getattr(slices[idx], "ImagePositionPatient", (0.0, 0.0, float(idx)))[2] - if not np.allclose(slc_spacing, spacing): - warnings.warn(f"the list contains slices that have different spacings {spacing} and {slc_spacing}.") - if shape != slc_shape: - warnings.warn(f"the list contains slices that have different shapes {shape} and {slc_shape}.") - average_distance += abs(prev_pos - slc_pos) - prev_pos = slc_pos - stack_array.append(slc_array) - - if len(slices) > 1: - average_distance /= len(slices) - 1 - spacing.append(average_distance) - stack_array = np.stack(stack_array, axis=-1) - stack_metadata = self._get_meta_dict(first_slice) - stack_metadata["spacing"] = np.asarray(spacing) - if hasattr(slices[-1], "ImagePositionPatient"): - stack_metadata["lastImagePositionPatient"] = np.asarray(slices[-1].ImagePositionPatient) - stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape + (len(slices),) - else: - stack_array = stack_array[0] - stack_metadata = self._get_meta_dict(first_slice) - stack_metadata["spacing"] = np.asarray(spacing) - stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape - - return stack_array, stack_metadata - - def get_data(self, data) -> tuple[np.ndarray, dict]: - """ - Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. - It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. - For dicom series within the input, all slices will be stacked first, - When loading a list of files (dicom file, or stacked dicom series), they are stacked together at a new - dimension as the first dimension, and the metadata of the first image is used to represent the output metadata. - - To use this function, all pydicom dataset objects (if not segmentation data) should contain: - `pixel_array`, `PixelSpacing`, `ImagePositionPatient` and `ImageOrientationPatient`. - - For segmentation data, we assume that the input is not a dicom series, and the object should contain - `SegmentSequence` in order to identify it. - In addition, tags (5200, 9229) and (5200, 9230) are required to achieve - `PixelSpacing`, `ImageOrientationPatient` and `ImagePositionPatient`. - - Args: - data: a pydicom dataset object, or a list of pydicom dataset objects, or a list of list of - pydicom dataset objects. 
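A small sketch (editor's note, numbers invented) of the series-combining rule in `_combine_dicom_series` above: slices are ordered by InstanceNumber and the through-plane spacing is estimated as the average absolute step of ImagePositionPatient[2] between consecutive slices.

import numpy as np

instance_numbers = [3, 1, 2]                   # order encountered on disk (hypothetical)
z_positions = {1: 50.0, 2: 52.5, 3: 55.0}      # ImagePositionPatient[2] per InstanceNumber (hypothetical)

order = sorted(instance_numbers)
zs = [z_positions[i] for i in order]
through_plane_spacing = np.abs(np.diff(zs)).mean()
print(order, through_plane_spacing)            # [1, 2, 3] 2.5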
- - """ - - dicom_data = [] - # combine dicom series if exists - if self.has_series is True: - # a list, all objects within a list belong to one dicom series - if not isinstance(data[0], list): - dicom_data.append(self._combine_dicom_series(data)) - # a list of list, each inner list represents a dicom series - else: - for series in data: - dicom_data.append(self._combine_dicom_series(series)) - else: - # a single pydicom dataset object - if not isinstance(data, list): - data = [data] - for d in data: - if hasattr(d, "SegmentSequence"): - data_array, metadata = self._get_seg_data(d) - else: - data_array = self._get_array_data(d) - metadata = self._get_meta_dict(d) - metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape - dicom_data.append((data_array, metadata)) - - img_array: list[np.ndarray] = [] - compatible_meta: dict = {} - - for data_array, metadata in ensure_tuple(dicom_data): - img_array.append(np.ascontiguousarray(np.swapaxes(data_array, 0, 1) if self.swap_ij else data_array)) - affine = self._get_affine(metadata, self.affine_lps_to_ras) - metadata[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS - if self.swap_ij: - affine = affine @ np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) - sp_size = list(metadata[MetaKeys.SPATIAL_SHAPE]) - sp_size[0], sp_size[1] = sp_size[1], sp_size[0] - metadata[MetaKeys.SPATIAL_SHAPE] = ensure_tuple(sp_size) - metadata[MetaKeys.ORIGINAL_AFFINE] = affine - metadata[MetaKeys.AFFINE] = affine.copy() - if self.channel_dim is None: # default to "no_channel" or -1 - metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data_array.shape) == len(metadata[MetaKeys.SPATIAL_SHAPE]) else -1 - ) - else: - metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim - metadata["spacing"] = affine_to_spacing( - metadata[MetaKeys.ORIGINAL_AFFINE], r=len(metadata[MetaKeys.SPATIAL_SHAPE]) - ) - - _copy_compatible_dict(metadata, compatible_meta) - - return _stack_images(img_array, compatible_meta), compatible_meta - - def _get_meta_dict(self, img) -> dict: - """ - Get all the metadata of the image and convert to dict type. - - Args: - img: a Pydicom dataset object. - - """ - - metadata = img.to_json_dict(suppress_invalid_tags=True) - - if self.prune_metadata: - prune_metadata = {} - for key in ["00200037", "00200032", "00280030", "52009229", "52009230"]: - if key in metadata.keys(): - prune_metadata[key] = metadata[key] - return prune_metadata - - # always remove Pixel Data "7FE00008" or "7FE00009" or "7FE00010" - # always remove Data Set Trailing Padding "FFFCFFFC" - for key in ["7FE00008", "7FE00009", "7FE00010", "FFFCFFFC"]: - if key in metadata.keys(): - metadata.pop(key) - - return metadata # type: ignore - - def _get_affine(self, metadata: dict, lps_to_ras: bool = True): - """ - Get or construct the affine matrix of the image, it can be used to correct - spacing, orientation or execute spatial transforms. - - Args: - metadata: metadata with dict type. - lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to True. 
- - """ - affine: np.ndarray = np.eye(4) - if not ("00200037" in metadata and "00200032" in metadata): - return affine - # "00200037" is the tag of `ImageOrientationPatient` - rx, ry, rz, cx, cy, cz = metadata["00200037"]["Value"] - # "00200032" is the tag of `ImagePositionPatient` - sx, sy, sz = metadata["00200032"]["Value"] - # "00280030" is the tag of `PixelSpacing` - spacing = metadata["00280030"]["Value"] if "00280030" in metadata else (1.0, 1.0) - dr, dc = metadata.get("spacing", spacing)[:2] - affine[0, 0] = cx * dr - affine[0, 1] = rx * dc - affine[0, 3] = sx - affine[1, 0] = cy * dr - affine[1, 1] = ry * dc - affine[1, 3] = sy - affine[2, 0] = cz * dr - affine[2, 1] = rz * dc - affine[2, 2] = 1.0 - affine[2, 3] = sz - - # 3d - if "lastImagePositionPatient" in metadata: - t1n, t2n, t3n = metadata["lastImagePositionPatient"] - n = metadata[MetaKeys.SPATIAL_SHAPE][-1] - k1, k2, k3 = (t1n - sx) / (n - 1), (t2n - sy) / (n - 1), (t3n - sz) / (n - 1) - affine[0, 2] = k1 - affine[1, 2] = k2 - affine[2, 2] = k3 - - if lps_to_ras: - affine = orientation_ras_lps(affine) - return affine - - def _get_frame_data(self, img) -> Iterator: - """ - yield frames and description from the segmentation image. - This function is adapted from Highdicom: - https://github.com/herrmannlab/highdicom/blob/v0.18.2/src/highdicom/seg/utils.py - - which has the following license... - - # ========================================================================= - # https://github.com/herrmannlab/highdicom/blob/v0.18.2/LICENSE - # - # Copyright 2020 MGH Computational Pathology - # Permission is hereby granted, free of charge, to any person obtaining a - # copy of this software and associated documentation files (the - # "Software"), to deal in the Software without restriction, including - # without limitation the rights to use, copy, modify, merge, publish, - # distribute, sublicense, and/or sell copies of the Software, and to - # permit persons to whom the Software is furnished to do so, subject to - # the following conditions: - # The above copyright notice and this permission notice shall be included - # in all copies or substantial portions of the Software. - # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - # CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - # ========================================================================= - - (https://github.com/herrmannlab/highdicom/issues/188) - - Args: - img: a Pydicom dataset object that has attribute "SegmentSequence". - - """ - - if not hasattr(img, "PerFrameFunctionalGroupsSequence"): - raise NotImplementedError( - f"To read dicom seg: {img.filename}, 'PerFrameFunctionalGroupsSequence' is required." - ) - - frame_seg_nums = [] - for f in img.PerFrameFunctionalGroupsSequence: - if not hasattr(f, "SegmentIdentificationSequence"): - raise NotImplementedError( - f"To read dicom seg: {img.filename}, 'SegmentIdentificationSequence' is required for each frame." 
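A worked example (editor's sketch; all tag values invented) of the DICOM affine formula used in `PydicomReader._get_affine` above: the row/column direction cosines from ImageOrientationPatient are scaled by PixelSpacing, the slice direction is derived from the first and last ImagePositionPatient, and the first-slice position becomes the translation.

import numpy as np

rx, ry, rz, cx, cy, cz = 1.0, 0.0, 0.0, 0.0, 1.0, 0.0   # ImageOrientationPatient (axial identity, hypothetical)
sx, sy, sz = -100.0, -120.0, 50.0                        # ImagePositionPatient of the first slice (hypothetical)
dr, dc = 0.8, 0.8                                        # PixelSpacing (row, column) (hypothetical)
last_pos = np.array([-100.0, -120.0, 110.0])             # ImagePositionPatient of the last slice (hypothetical)
n = 31                                                   # number of slices (hypothetical)

affine = np.eye(4)
affine[:3, 0] = np.array([cx, cy, cz]) * dr                     # column direction cosines * row spacing
affine[:3, 1] = np.array([rx, ry, rz]) * dc                     # row direction cosines * column spacing
affine[:3, 2] = (last_pos - np.array([sx, sy, sz])) / (n - 1)   # slice direction, a 2.0 mm step here
affine[:3, 3] = [sx, sy, sz]
print(affine)  # still in LPS; orientation_ras_lps would flip the first two rows for RAS output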
- ) - frame_seg_nums.append(int(f.SegmentIdentificationSequence[0].ReferencedSegmentNumber)) - - frame_seg_nums_arr = np.array(frame_seg_nums) - - seg_descriptions = {int(f.SegmentNumber): f for f in img.SegmentSequence} - - for i in np.unique(frame_seg_nums_arr): - indices = np.where(frame_seg_nums_arr == i)[0] - yield (img.pixel_array[indices, ...], seg_descriptions[i]) - - def _get_seg_data(self, img): - """ - Get the array data and metadata of the segmentation image. - - Aegs: - img: a Pydicom dataset object that has attribute "SegmentSequence". - - """ - - metadata = self._get_meta_dict(img) - n_classes = len(img.SegmentSequence) - spatial_shape = list(img.pixel_array.shape) - spatial_shape[0] = spatial_shape[0] // n_classes - - if self.label_dict is not None: - metadata["labels"] = self.label_dict - all_segs = np.zeros([*spatial_shape, len(self.label_dict)]) - else: - metadata["labels"] = {} - all_segs = np.zeros([*spatial_shape, n_classes]) - - for i, (frames, description) in enumerate(self._get_frame_data(img)): - segment_label = getattr(description, "SegmentLabel", f"label_{i}") - class_name = getattr(description, "SegmentDescription", segment_label) - if class_name not in metadata["labels"].keys(): - metadata["labels"][class_name] = i - class_num = metadata["labels"][class_name] - all_segs[..., class_num] = frames - - all_segs = all_segs.transpose([1, 2, 0, 3]) - metadata[MetaKeys.SPATIAL_SHAPE] = all_segs.shape[:-1] - - if "52009229" in metadata.keys(): - shared_func_group_seq = metadata["52009229"]["Value"][0] - - # get `ImageOrientationPatient` - if "00209116" in shared_func_group_seq.keys(): - plane_orient_seq = shared_func_group_seq["00209116"]["Value"][0] - if "00200037" in plane_orient_seq.keys(): - metadata["00200037"] = plane_orient_seq["00200037"] - - # get `PixelSpacing` - if "00289110" in shared_func_group_seq.keys(): - pixel_measure_seq = shared_func_group_seq["00289110"]["Value"][0] - - if "00280030" in pixel_measure_seq.keys(): - pixel_spacing = pixel_measure_seq["00280030"]["Value"] - metadata["spacing"] = pixel_spacing - if "00180050" in pixel_measure_seq.keys(): - metadata["spacing"] += pixel_measure_seq["00180050"]["Value"] - - if self.prune_metadata: - metadata.pop("52009229") - - # get `ImagePositionPatient` - if "52009230" in metadata.keys(): - first_frame_func_group_seq = metadata["52009230"]["Value"][0] - if "00209113" in first_frame_func_group_seq.keys(): - plane_position_seq = first_frame_func_group_seq["00209113"]["Value"][0] - if "00200032" in plane_position_seq.keys(): - metadata["00200032"] = plane_position_seq["00200032"] - metadata["lastImagePositionPatient"] = metadata["52009230"]["Value"][-1]["00209113"]["Value"][0][ - "00200032" - ]["Value"] - if self.prune_metadata: - metadata.pop("52009230") - - return all_segs, metadata - - def _get_array_data(self, img): - """ - Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data - will be rescaled. The output data has the dtype np.float32 if the rescaling is applied. - - Args: - img: a Pydicom dataset object. 
- - """ - # process Dicom series - if not hasattr(img, "pixel_array"): - raise ValueError(f"dicom data: {img.filename} does not have pixel_array.") - data = img.pixel_array - - slope, offset = 1.0, 0.0 - rescale_flag = False - if hasattr(img, "RescaleSlope"): - slope = img.RescaleSlope - rescale_flag = True - if hasattr(img, "RescaleIntercept"): - offset = img.RescaleIntercept - rescale_flag = True - if rescale_flag: - data = data.astype(np.float32) * slope + offset - - return data - - -@require_pkg(pkg_name="nibabel") -class NibabelReader(ImageReader): - """ - Load NIfTI format images based on Nibabel library. - - Args: - as_closest_canonical: if True, load the image as closest to canonical axis format. - squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3) - channel_dim: the channel dimension of the input image, default is None. - this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. - if None, `original_channel_dim` will be either `no_channel` or `-1`. - most Nifti files are usually "channel last", no need to specify this argument for them. - kwargs: additional args for `nibabel.load` API. more details about available args: - https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py - - """ - - def __init__( - self, - channel_dim: str | int | None = None, - as_closest_canonical: bool = False, - squeeze_non_spatial_dims: bool = False, - **kwargs, - ): - super().__init__() - self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim - self.as_closest_canonical = as_closest_canonical - self.squeeze_non_spatial_dims = squeeze_non_spatial_dims - self.kwargs = kwargs - - def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: - """ - Verify whether the specified file or files format is supported by Nibabel reader. - - Args: - filename: file name or a list of file names to read. - if a list of files, verify all the suffixes. - - """ - suffixes: Sequence[str] = ["nii", "nii.gz"] - return has_nib and is_supported_format(filename, suffixes) - - def read(self, data: Sequence[PathLike] | PathLike, **kwargs): - """ - Read image data from specified file or files, it can read a list of images - and stack them together as multi-channel data in `get_data()`. - Note that the returned object is Nibabel image object or list of Nibabel image objects. - - Args: - data: file name or a list of file names to read. - kwargs: additional args for `nibabel.load` API, will override `self.kwargs` for existing keys. - More details about available args: - https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py - - """ - img_: list[Nifti1Image] = [] - - filenames: Sequence[PathLike] = ensure_tuple(data) - kwargs_ = self.kwargs.copy() - kwargs_.update(kwargs) - for name in filenames: - img = nib.load(name, **kwargs_) - img = correct_nifti_header_if_necessary(img) - img_.append(img) # type: ignore - return img_ if len(filenames) > 1 else img_[0] - - def get_data(self, img) -> tuple[np.ndarray, dict]: - """ - Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. - It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. - When loading a list of files, they are stacked together at a new dimension as the first dimension, - and the metadata of the first image is used to present the output metadata. 
- - Args: - img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. - - """ - img_array: list[np.ndarray] = [] - compatible_meta: dict = {} - - for i in ensure_tuple(img): - header = self._get_meta_dict(i) - header[MetaKeys.AFFINE] = self._get_affine(i) - header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i) - header["as_closest_canonical"] = self.as_closest_canonical - if self.as_closest_canonical: - i = nib.as_closest_canonical(i) - header[MetaKeys.AFFINE] = self._get_affine(i) - header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) - header[MetaKeys.SPACE] = SpaceKeys.RAS - data = self._get_array_data(i) - if self.squeeze_non_spatial_dims: - for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1): - if data.shape[d - 1] == 1: - data = data.squeeze(axis=d - 1) - img_array.append(data) - if self.channel_dim is None: # default to "no_channel" or -1 - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 - ) - else: - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim - _copy_compatible_dict(header, compatible_meta) - - return _stack_images(img_array, compatible_meta), compatible_meta - - def _get_meta_dict(self, img) -> dict: - """ - Get the all the metadata of the image and convert to dict type. - - Args: - img: a Nibabel image object loaded from an image file. - - """ - # swap to little endian as PyTorch doesn't support big endian - try: - header = img.header.as_byteswapped("<") - except ValueError: - header = img.header - return dict(header) - - def _get_affine(self, img): - """ - Get the affine matrix of the image, it can be used to correct - spacing, orientation or execute spatial transforms. - - Args: - img: a Nibabel image object loaded from an image file. - - """ - return np.array(img.affine, copy=True) - - def _get_spatial_shape(self, img): - """ - Get the spatial shape of image data, it doesn't contain the channel dim. - - Args: - img: a Nibabel image object loaded from an image file. - - """ - # swap to little endian as PyTorch doesn't support big endian - try: - header = img.header.as_byteswapped("<") - except ValueError: - header = img.header - dim = header.get("dim", None) - if dim is None: - dim = header.get("dims") # mgh format? - dim = np.insert(dim, 0, 3) - ndim = dim[0] - size = list(dim[1:]) - if not is_no_channel(self.channel_dim): - size.pop(int(self.channel_dim)) # type: ignore - spatial_rank = max(min(ndim, 3), 1) - return np.asarray(size[:spatial_rank]) - - def _get_array_data(self, img): - """ - Get the raw array data of the image, converted to Numpy array. - - Args: - img: a Nibabel image object loaded from an image file. - - """ - return np.asanyarray(img.dataobj, order="C") - - -class NumpyReader(ImageReader): - """ - Load NPY or NPZ format data based on Numpy library, they can be arrays or pickled objects. - A typical usage is to load the `mask` data for classification task. - It can load part of the npz file with specified `npz_keys`. - - Args: - npz_keys: if loading npz file, only load the specified keys, if None, load all the items. - stack the loaded items together to construct a new first dimension. - channel_dim: if not None, explicitly specify the channel dim, otherwise, treat the array as no channel. - kwargs: additional args for `numpy.load` API except `allow_pickle`. 
more details about available args: - https://numpy.org/doc/stable/reference/generated/numpy.load.html - - """ - - def __init__(self, npz_keys: KeysCollection | None = None, channel_dim: str | int | None = None, **kwargs): - super().__init__() - if npz_keys is not None: - npz_keys = ensure_tuple(npz_keys) - self.npz_keys = npz_keys - self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim - self.kwargs = kwargs - - def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: - """ - Verify whether the specified file or files format is supported by Numpy reader. - - Args: - filename: file name or a list of file names to read. - if a list of files, verify all the suffixes. - """ - suffixes: Sequence[str] = ["npz", "npy"] - return is_supported_format(filename, suffixes) - - def read(self, data: Sequence[PathLike] | PathLike, **kwargs): - """ - Read image data from specified file or files, it can read a list of data files - and stack them together as multi-channel data in `get_data()`. - Note that the returned object is Numpy array or list of Numpy arrays. - - Args: - data: file name or a list of file names to read. - kwargs: additional args for `numpy.load` API except `allow_pickle`, will override `self.kwargs` for existing keys. - More details about available args: - https://numpy.org/doc/stable/reference/generated/numpy.load.html - - """ - img_: list[Nifti1Image] = [] - - filenames: Sequence[PathLike] = ensure_tuple(data) - kwargs_ = self.kwargs.copy() - kwargs_.update(kwargs) - for name in filenames: - img = np.load(name, allow_pickle=True, **kwargs_) - if Path(name).name.endswith(".npz"): - # load expected items from NPZ file - npz_keys = list(img.keys()) if self.npz_keys is None else self.npz_keys - for k in npz_keys: - img_.append(img[k]) - else: - img_.append(img) - - return img_ if len(img_) > 1 else img_[0] - - def get_data(self, img) -> tuple[np.ndarray, dict]: - """ - Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. - It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. - When loading a list of files, they are stacked together at a new dimension as the first dimension, - and the metadata of the first image is used to represent the output metadata. - - Args: - img: a Numpy array loaded from a file or a list of Numpy arrays. - - """ - img_array: list[np.ndarray] = [] - compatible_meta: dict = {} - if isinstance(img, np.ndarray): - img = (img,) - - for i in ensure_tuple(img): - header: dict[MetaKeys, Any] = {} - if isinstance(i, np.ndarray): - # if `channel_dim` is None, can not detect the channel dim, use all the dims as spatial_shape - spatial_shape = np.asarray(i.shape) - if isinstance(self.channel_dim, int): - spatial_shape = np.delete(spatial_shape, self.channel_dim) - header[MetaKeys.SPATIAL_SHAPE] = spatial_shape - header[MetaKeys.SPACE] = SpaceKeys.RAS - img_array.append(i) - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - self.channel_dim if isinstance(self.channel_dim, int) else float("nan") - ) - _copy_compatible_dict(header, compatible_meta) - - return _stack_images(img_array, compatible_meta), compatible_meta - - -@require_pkg(pkg_name="PIL") -class PILReader(ImageReader): - """ - Load common 2D image format (supports PNG, JPG, BMP) file or files from provided path. - - Args: - converter: additional function to convert the image data after `read()`. 
- for example, use `converter=lambda image: image.convert("LA")` to convert image format. - reverse_indexing: whether to swap axis 0 and 1 after loading the array, this is enabled by default, - so that output of the reader is consistent with the other readers. Set this option to ``False`` to use - the PIL backend's original spatial axes convention. - kwargs: additional args for `Image.open` API in `read()`, mode details about available args: - https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open - """ - - def __init__(self, converter: Callable | None = None, reverse_indexing: bool = True, **kwargs): - super().__init__() - self.converter = converter - self.reverse_indexing = reverse_indexing - self.kwargs = kwargs - - def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: - """ - Verify whether the specified file or files format is supported by PIL reader. - - Args: - filename: file name or a list of file names to read. - if a list of files, verify all the suffixes. - """ - suffixes: Sequence[str] = ["png", "jpg", "jpeg", "bmp"] - return has_pil and is_supported_format(filename, suffixes) - - def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs): - """ - Read image data from specified file or files, it can read a list of images - and stack them together as multi-channel data in `get_data()`. - Note that the returned object is PIL image or list of PIL image. - - Args: - data: file name or a list of file names to read. - kwargs: additional args for `Image.open` API in `read()`, will override `self.kwargs` for existing keys. - Mode details about available args: - https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open - - """ - img_: list[PILImage.Image] = [] - - filenames: Sequence[PathLike] = ensure_tuple(data) - kwargs_ = self.kwargs.copy() - kwargs_.update(kwargs) - for name in filenames: - img = PILImage.open(name, **kwargs_) - if callable(self.converter): - img = self.converter(img) - img_.append(img) - - return img_ if len(filenames) > 1 else img_[0] - - def get_data(self, img) -> tuple[np.ndarray, dict]: - """ - Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. - It computes `spatial_shape` and stores it in meta dict. - When loading a list of files, they are stacked together at a new dimension as the first dimension, - and the metadata of the first image is used to represent the output metadata. - Note that by default `self.reverse_indexing` is set to ``True``, which swaps axis 0 and 1 after loading - the array because the spatial axes definition in PIL is different from other common medical packages. - - Args: - img: a PIL Image object loaded from a file or a list of PIL Image objects. - - """ - img_array: list[np.ndarray] = [] - compatible_meta: dict = {} - - for i in ensure_tuple(img): - header = self._get_meta_dict(i) - header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) - data = np.moveaxis(np.asarray(i), 0, 1) if self.reverse_indexing else np.asarray(i) - img_array.append(data) - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 - ) - _copy_compatible_dict(header, compatible_meta) - - return _stack_images(img_array, compatible_meta), compatible_meta - - def _get_meta_dict(self, img) -> dict: - """ - Get the all the metadata of the image and convert to dict type. 
- Args: - img: a PIL Image object loaded from an image file. - - """ - return {"format": img.format, "mode": img.mode, "width": img.width, "height": img.height} - - def _get_spatial_shape(self, img): - """ - Get the spatial shape of image data, it doesn't contain the channel dim. - Args: - img: a PIL Image object loaded from an image file. - """ - return np.asarray((img.width, img.height)) - - -@dataclass -class NrrdImage: - """Class to wrap nrrd image array and metadata header""" - - array: np.ndarray - header: dict - - -@require_pkg(pkg_name="nrrd") -class NrrdReader(ImageReader): - """ - Load NRRD format images based on pynrrd library. - - Args: - channel_dim: the channel dimension of the input image, default is None. - This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. - If None, `original_channel_dim` will be either `no_channel` or `0`. - NRRD files are usually "channel first". - dtype: dtype of the data array when loading image. - index_order: Specify whether the returned data array should be in C-order (‘C’) or Fortran-order (‘F’). - Numpy is usually in C-order, but default on the NRRD header is F - affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``. - Set to ``True`` to be consistent with ``NibabelReader``, otherwise the affine matrix is unmodified. - - kwargs: additional args for `nrrd.read` API. more details about available args: - https://github.com/mhe/pynrrd/blob/master/nrrd/reader.py - - """ - - def __init__( - self, - channel_dim: str | int | None = None, - dtype: np.dtype | type | str | None = np.float32, - index_order: str = "F", - affine_lps_to_ras: bool = True, - **kwargs, - ): - self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim - self.dtype = dtype - self.index_order = index_order - self.affine_lps_to_ras = affine_lps_to_ras - self.kwargs = kwargs - - def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: - """ - Verify whether the specified `filename` is supported by pynrrd reader. - - Args: - filename: file name or a list of file names to read. - if a list of files, verify all the suffixes. - - """ - suffixes: Sequence[str] = ["nrrd", "seg.nrrd"] - return has_nrrd and is_supported_format(filename, suffixes) - - def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | Any: - """ - Read image data from specified file or files. - Note that it returns a data object or a sequence of data objects. - - Args: - data: file name or a list of file names to read. - kwargs: additional args for actual `read` API of 3rd party libs. - - """ - img_: list = [] - filenames: Sequence[PathLike] = ensure_tuple(data) - kwargs_ = self.kwargs.copy() - kwargs_.update(kwargs) - for name in filenames: - nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, **kwargs_)) - img_.append(nrrd_image) - return img_ if len(filenames) > 1 else img_[0] - - def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: - """ - Extract data array and metadata from loaded image and return them. - This function must return two objects, the first is a numpy array of image data, - the second is a dictionary of metadata. - - Args: - img: a `NrrdImage` loaded from an image file or a list of image objects. 
- - """ - img_array: list[np.ndarray] = [] - compatible_meta: dict = {} - - for i in ensure_tuple(img): - data = i.array.astype(self.dtype) - img_array.append(data) - header = dict(i.header) - if self.index_order == "C": - header = self._convert_f_to_c_order(header) - header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(header) - - if self.affine_lps_to_ras: - header = self._switch_lps_ras(header) - if header.get(MetaKeys.SPACE, "left-posterior-superior") == "left-posterior-superior": - header[MetaKeys.SPACE] = SpaceKeys.LPS # assuming LPS if not specified - - header[MetaKeys.AFFINE] = header[MetaKeys.ORIGINAL_AFFINE].copy() - header[MetaKeys.SPATIAL_SHAPE] = header["sizes"] - [header.pop(k) for k in ("sizes", "space origin", "space directions")] # rm duplicated data in header - - if self.channel_dim is None: # default to "no_channel" or -1 - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else 0 - ) - else: - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim - _copy_compatible_dict(header, compatible_meta) - - return _stack_images(img_array, compatible_meta), compatible_meta - - def _get_affine(self, header: dict) -> np.ndarray: - """ - Get the affine matrix of the image, it can be used to correct - spacing, orientation or execute spatial transforms. - - Args: - img: A `NrrdImage` loaded from image file - - """ - direction = header["space directions"] - origin = header["space origin"] - - x, y = direction.shape - affine_diam = min(x, y) + 1 - affine: np.ndarray = np.eye(affine_diam) - affine[:x, :y] = direction - affine[: (affine_diam - 1), -1] = origin # len origin is always affine_diam - 1 - return affine - - def _switch_lps_ras(self, header: dict) -> dict: - """ - For compatibility with nibabel, switch from LPS to RAS. Adapt affine matrix and - `space` argument in header accordingly. If no information of space is given in the header, - LPS is assumed and thus converted to RAS. If information about space is given, - but is not LPS, the unchanged header is returned. - - Args: - header: The image metadata as dict - - """ - if "space" not in header or header["space"] == "left-posterior-superior": - header[MetaKeys.ORIGINAL_AFFINE] = orientation_ras_lps(header[MetaKeys.ORIGINAL_AFFINE]) - header[MetaKeys.SPACE] = SpaceKeys.RAS - return header - - def _convert_f_to_c_order(self, header: dict) -> dict: - """ - All header fields of a NRRD are specified in `F` (Fortran) order, even if the image was read as C-ordered array. 
- 1D arrays of header['space origin'] and header['sizes'] become inverted, e.g, [1,2,3] -> [3,2,1] - The 2D Array for header['space directions'] is transposed: [[1,0,0],[0,2,0],[0,0,3]] -> [[3,0,0],[0,2,0],[0,0,1]] - For more details refer to: https://pynrrd.readthedocs.io/en/latest/user-guide.html#index-ordering - - Args: - header: The image metadata as dict - - """ - - header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) - header["space origin"] = header["space origin"][::-1] - header["sizes"] = header["sizes"][::-1] - return header + return header \ No newline at end of file From 0699eebe43bc9e06b2fa35fb4cf93e8095e69e0a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 May 2024 10:00:44 +0000 Subject: [PATCH 06/49] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/image_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index d11140c110..c33b62681c 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1420,4 +1420,4 @@ def _convert_f_to_c_order(self, header: dict) -> dict: header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) header["space origin"] = header["space origin"][::-1] header["sizes"] = header["sizes"][::-1] - return header \ No newline at end of file + return header From 274cd044a423c8d31cd28a6240547a0225c67be2 Mon Sep 17 00:00:00 2001 From: staydelight Date: Sun, 2 Jun 2024 00:05:51 +0800 Subject: [PATCH 07/49] fix-issue-7557 Signed-off-by: staydelight --- monai/data/image_reader.py | 28 +--------------------------- monai/data/image_writer.py | 26 +------------------------- monai/transforms/io/array.py | 22 ++++++++++++++++++++++ monai/transforms/io/dictionary.py | 2 ++ 4 files changed, 26 insertions(+), 52 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index c33b62681c..488d3df15e 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -12,7 +12,6 @@ from __future__ import annotations import glob -import json import os import re import warnings @@ -149,25 +148,6 @@ def _stack_images(image_list: list, meta_dict: dict): return np.stack(image_list, axis=0) -def update_json(input_file=None, output_file=None): - record_path = "img-label.json" - - if not os.path.exists(record_path) or os.stat(record_path).st_size == 0: - with open(record_path, 'w') as f: - json.dump([], f) - - with open(record_path, 'r+') as f: - records = json.load(f) - if input_file: - new_record = {"image": input_file, "label": []} - records.append(new_record) - elif output_file and records: - records[-1]["label"].append(output_file) - - f.seek(0) - json.dump(records, f, indent=4) - - @require_pkg(pkg_name="itk") class ITKReader(ImageReader): """ @@ -245,7 +225,6 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_ = [] filenames: Sequence[PathLike] = ensure_tuple(data) - update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -486,7 +465,6 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_ = [] filenames: Sequence[PathLike] = ensure_tuple(data) - update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) @@ -938,7 +916,6 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_: list[Nifti1Image] = [] filenames: Sequence[PathLike] = 
ensure_tuple(data) - update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -1099,7 +1076,6 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_: list[Nifti1Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) - update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -1197,7 +1173,6 @@ def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs): img_: list[PILImage.Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) - update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -1322,7 +1297,6 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | """ img_: list = [] filenames: Sequence[PathLike] = ensure_tuple(data) - update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -1420,4 +1394,4 @@ def _convert_f_to_c_order(self, header: dict) -> dict: header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) header["space origin"] = header["space origin"][::-1] header["sizes"] = header["sizes"][::-1] - return header + return header \ No newline at end of file diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 4b7d95e71a..ba1c9dde27 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -15,8 +15,6 @@ from typing import TYPE_CHECKING, Any, cast import numpy as np -import os -import json from monai.apps.utils import get_logger from monai.config import DtypeLike, NdarrayOrTensor, PathLike @@ -198,25 +196,6 @@ def write(self, filename: PathLike, verbose: bool = True, **kwargs): if verbose: logger.info(f"writing: {filename}") - def update_json(self, input_file=None, output_file=None): - record_path = "img-label.json" - - if not os.path.exists(record_path) or os.stat(record_path).st_size == 0: - with open(record_path, 'w') as f: - json.dump([], f) - - with open(record_path, 'r+') as f: - records = json.load(f) - if input_file: - new_record = {"image": input_file, "label": []} - records.append(new_record) - elif output_file and records: - records[-1]["label"].append(output_file) - - f.seek(0) - json.dump(records, f, indent=4) - - @classmethod def create_backend_obj(cls, data_array: NdarrayOrTensor, **kwargs) -> np.ndarray: """ @@ -484,7 +463,6 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs): - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L809 """ super().write(filename, verbose=verbose) - super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( cast(NdarrayOrTensor, self.data_obj), channel_dim=self.channel_dim, @@ -648,7 +626,6 @@ def write(self, filename: PathLike, verbose: bool = False, **obj_kwargs): - https://nipy.org/nibabel/reference/nibabel.nifti1.html#nibabel.nifti1.save """ super().write(filename, verbose=verbose) - super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( cast(NdarrayOrTensor, self.data_obj), affine=self.affine, dtype=self.output_dtype, **obj_kwargs ) @@ -794,7 +771,6 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs): - https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.save """ super().write(filename, verbose=verbose) - super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( data_array=self.data_obj, 
dtype=self.output_dtype, @@ -895,4 +871,4 @@ def init(): for ext in ("nii.gz", "nii"): register_writer(ext, NibabelWriter, ITKWriter) register_writer("nrrd", ITKWriter, NibabelWriter) - register_writer(EXT_WILDCARD, ITKWriter, NibabelWriter, ITKWriter) + register_writer(EXT_WILDCARD, ITKWriter, NibabelWriter, ITKWriter) \ No newline at end of file diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 7222a26fc3..04492112b3 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -393,6 +393,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: FolderLayoutBase | None = None, savepath_in_metadict: bool = False, + mapping_log_path: Union[Path, str, None] = None ) -> None: self.folder_layout: FolderLayoutBase if folder_layout is None: @@ -438,6 +439,11 @@ def __init__( self.write_kwargs = {"verbose": print_log} self._data_index = 0 self.savepath_in_metadict = savepath_in_metadict + if mapping_log_path: + self.mapping_log_path = Path(mapping_log_path) + self.savepath_in_metadict = True + else: + self.mapping_log_path = None def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): """ @@ -506,6 +512,22 @@ def __call__( self._data_index += 1 if self.savepath_in_metadict and meta_data is not None: meta_data["saved_to"] = filename + if self.mapping_log_path and meta_data is not None: + log_data = [] + log_data.append({ + "input": meta_data.get("filename_or_obj", ()), + "output": meta_data.get("saved_to", ()) + }) + + try: + with open(self.mapping_log_path, 'r') as f: + existing_log_data = json.load(f) + except FileNotFoundError: + existing_log_data = [] + + with open(self.mapping_log_path, 'w') as f: + existing_log_data.extend(log_data) + json.dump(existing_log_data, f, indent=4) return img msg = "\n".join([f"{e}" for e in err]) raise RuntimeError( diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 4da1d422ca..966bf305ef 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -281,6 +281,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: monai.data.FolderLayoutBase | None = None, savepath_in_metadict: bool = False, + mapping_log_path: Union[Path, str, None] = None ) -> None: super().__init__(keys, allow_missing_keys) self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) @@ -304,6 +305,7 @@ def __init__( output_name_formatter=output_name_formatter, folder_layout=folder_layout, savepath_in_metadict=savepath_in_metadict, + mapping_log_path= mapping_log_path, ) def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): From d4fb0b7db72d77188d06299d6ad1b5457cc5ab52 Mon Sep 17 00:00:00 2001 From: staydelight Date: Sun, 2 Jun 2024 00:11:38 +0800 Subject: [PATCH 08/49] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 04492112b3..10cb2bca9f 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -14,6 +14,7 @@ from __future__ import annotations +import json import inspect import logging import sys From bfb6d58f4355682c51a0ac03c767ffd2e13daf98 Mon Sep 17 00:00:00 2001 From: staydelight Date: Sun, 2 Jun 2024 00:23:23 +0800 Subject: [PATCH 09/49] Fixes #7557 Add code for generating a mapping json file. 
Signed-off-by: staydelight --- monai/transforms/io/array.py | 14 +++++++------- monai/transforms/io/dictionary.py | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 10cb2bca9f..cd8bf5f638 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -394,7 +394,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_log_path: Union[Path, str, None] = None + mapping_json_path: Union[Path, str, None] = None ) -> None: self.folder_layout: FolderLayoutBase if folder_layout is None: @@ -440,11 +440,11 @@ def __init__( self.write_kwargs = {"verbose": print_log} self._data_index = 0 self.savepath_in_metadict = savepath_in_metadict - if mapping_log_path: - self.mapping_log_path = Path(mapping_log_path) + if mapping_json_path: + self.mapping_json_path = Path(mapping_json_path) self.savepath_in_metadict = True else: - self.mapping_log_path = None + self.mapping_json_path = None def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): """ @@ -513,7 +513,7 @@ def __call__( self._data_index += 1 if self.savepath_in_metadict and meta_data is not None: meta_data["saved_to"] = filename - if self.mapping_log_path and meta_data is not None: + if self.mapping_json_path and meta_data is not None: log_data = [] log_data.append({ "input": meta_data.get("filename_or_obj", ()), @@ -521,12 +521,12 @@ def __call__( }) try: - with open(self.mapping_log_path, 'r') as f: + with open(self.mapping_json_path, 'r') as f: existing_log_data = json.load(f) except FileNotFoundError: existing_log_data = [] - with open(self.mapping_log_path, 'w') as f: + with open(self.mapping_json_path, 'w') as f: existing_log_data.extend(log_data) json.dump(existing_log_data, f, indent=4) return img diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 966bf305ef..927e0ad718 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -281,7 +281,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: monai.data.FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_log_path: Union[Path, str, None] = None + mapping_json_path: Union[Path, str, None] = None ) -> None: super().__init__(keys, allow_missing_keys) self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) @@ -305,7 +305,7 @@ def __init__( output_name_formatter=output_name_formatter, folder_layout=folder_layout, savepath_in_metadict=savepath_in_metadict, - mapping_log_path= mapping_log_path, + mapping_json_path= mapping_json_path, ) def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): From 5ab2521a4dd310d344256548188c164a31b70b7f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 1 Jun 2024 16:32:24 +0000 Subject: [PATCH 10/49] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/image_reader.py | 2 +- monai/data/image_writer.py | 2 +- monai/transforms/io/array.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 488d3df15e..f5e199e2a3 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1394,4 +1394,4 
@@ def _convert_f_to_c_order(self, header: dict) -> dict: header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) header["space origin"] = header["space origin"][::-1] header["sizes"] = header["sizes"][::-1] - return header \ No newline at end of file + return header diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index ba1c9dde27..b9e8b9e68e 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -871,4 +871,4 @@ def init(): for ext in ("nii.gz", "nii"): register_writer(ext, NibabelWriter, ITKWriter) register_writer("nrrd", ITKWriter, NibabelWriter) - register_writer(EXT_WILDCARD, ITKWriter, NibabelWriter, ITKWriter) \ No newline at end of file + register_writer(EXT_WILDCARD, ITKWriter, NibabelWriter, ITKWriter) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index cd8bf5f638..961ca72111 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -443,7 +443,7 @@ def __init__( if mapping_json_path: self.mapping_json_path = Path(mapping_json_path) self.savepath_in_metadict = True - else: + else: self.mapping_json_path = None def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): @@ -519,9 +519,9 @@ def __call__( "input": meta_data.get("filename_or_obj", ()), "output": meta_data.get("saved_to", ()) }) - + try: - with open(self.mapping_json_path, 'r') as f: + with open(self.mapping_json_path) as f: existing_log_data = json.load(f) except FileNotFoundError: existing_log_data = [] From 894854deb0fa3e02f28fb472371a723e9918bc1c Mon Sep 17 00:00:00 2001 From: staydelight Date: Sun, 2 Jun 2024 00:54:54 +0800 Subject: [PATCH 11/49] Fixes #7557 Change mapping_json_path init way. Signed-off-by: staydelight --- monai/transforms/io/array.py | 8 ++++---- monai/transforms/io/dictionary.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 961ca72111..c4f396dc64 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -394,7 +394,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_json_path: Union[Path, str, None] = None + mapping_json_path: Path | str | None = None ) -> None: self.folder_layout: FolderLayoutBase if folder_layout is None: @@ -443,7 +443,7 @@ def __init__( if mapping_json_path: self.mapping_json_path = Path(mapping_json_path) self.savepath_in_metadict = True - else: + else: self.mapping_json_path = None def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): @@ -519,9 +519,9 @@ def __call__( "input": meta_data.get("filename_or_obj", ()), "output": meta_data.get("saved_to", ()) }) - + try: - with open(self.mapping_json_path) as f: + with open(self.mapping_json_path, 'r') as f: existing_log_data = json.load(f) except FileNotFoundError: existing_log_data = [] diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 927e0ad718..e3214777b9 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -281,7 +281,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: monai.data.FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_json_path: Union[Path, str, None] = None + mapping_json_path: Path | str | None = None ) -> None: 
super().__init__(keys, allow_missing_keys) self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) From c37222512ee5131de0190554e684666bd12fd687 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 1 Jun 2024 16:59:13 +0000 Subject: [PATCH 12/49] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/io/array.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index c4f396dc64..5601b2a20a 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -443,7 +443,7 @@ def __init__( if mapping_json_path: self.mapping_json_path = Path(mapping_json_path) self.savepath_in_metadict = True - else: + else: self.mapping_json_path = None def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): @@ -519,9 +519,9 @@ def __call__( "input": meta_data.get("filename_or_obj", ()), "output": meta_data.get("saved_to", ()) }) - + try: - with open(self.mapping_json_path, 'r') as f: + with open(self.mapping_json_path) as f: existing_log_data = json.load(f) except FileNotFoundError: existing_log_data = [] From 682379b2c7a71665a65572e580c933f90f0f3ffc Mon Sep 17 00:00:00 2001 From: staydelight Date: Mon, 3 Jun 2024 16:25:42 +0800 Subject: [PATCH 13/49] Fixes #7557 Fixing unsuccessful checks. Signed-off-by: staydelight --- monai/transforms/io/array.py | 13 ++++++------- monai/transforms/io/dictionary.py | 4 ++-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 5601b2a20a..cdcc4da80d 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -14,8 +14,8 @@ from __future__ import annotations -import json import inspect +import json import logging import sys import traceback @@ -394,7 +394,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_json_path: Path | str | None = None + mapping_json_path: Path | str | None = None, ) -> None: self.folder_layout: FolderLayoutBase if folder_layout is None: @@ -515,10 +515,9 @@ def __call__( meta_data["saved_to"] = filename if self.mapping_json_path and meta_data is not None: log_data = [] - log_data.append({ - "input": meta_data.get("filename_or_obj", ()), - "output": meta_data.get("saved_to", ()) - }) + log_data.append( + {"input": meta_data.get("filename_or_obj", ()), "output": meta_data.get("saved_to", ())} + ) try: with open(self.mapping_json_path) as f: @@ -526,7 +525,7 @@ def __call__( except FileNotFoundError: existing_log_data = [] - with open(self.mapping_json_path, 'w') as f: + with open(self.mapping_json_path, "w") as f: existing_log_data.extend(log_data) json.dump(existing_log_data, f, indent=4) return img diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index e3214777b9..3cf46272c0 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -281,7 +281,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: monai.data.FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_json_path: Path | str | None = None + mapping_json_path: Path | str | None = None, ) -> None: super().__init__(keys, allow_missing_keys) self.meta_keys = 
ensure_tuple_rep(meta_keys, len(self.keys))
@@ -305,7 +305,7 @@ def __init__(
output_name_formatter=output_name_formatter,
folder_layout=folder_layout,
savepath_in_metadict=savepath_in_metadict,
- mapping_json_path= mapping_json_path,
+ mapping_json_path=mapping_json_path,
)
def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None):
From 56d8df5756586782038d9d66e5b8f036310767e0 Mon Sep 17 00:00:00 2001
From: staydelight
Date: Mon, 3 Jun 2024 16:42:35 +0800
Subject: [PATCH 14/49] Fixes #7557
Fixes unsuccessful checks. (if mapping_json_path is not None)
Signed-off-by: staydelight
---
monai/transforms/io/array.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index cdcc4da80d..b8f8cc6ee0 100644
--- a/monai/transforms/io/array.py
+++ b/monai/transforms/io/array.py
@@ -440,7 +440,7 @@ def __init__(
self.write_kwargs = {"verbose": print_log}
self._data_index = 0
self.savepath_in_metadict = savepath_in_metadict
- if mapping_json_path:
+ if mapping_json_path is not None:
self.mapping_json_path = Path(mapping_json_path)
self.savepath_in_metadict = True
else:
self.mapping_json_path = None
From 8bab11b6ea954b055b71a49c800395ab633426db Mon Sep 17 00:00:00 2001
From: staydelight
Date: Mon, 3 Jun 2024 16:53:34 +0800
Subject: [PATCH 15/49] Fixes #7557
Signed-off-by: staydelight
---
monai/transforms/io/array.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index b8f8cc6ee0..d4af4e91a6 100644
--- a/monai/transforms/io/array.py
+++ b/monai/transforms/io/array.py
@@ -513,7 +513,7 @@ def __call__(
self._data_index += 1
if self.savepath_in_metadict and meta_data is not None:
meta_data["saved_to"] = filename
- if self.mapping_json_path and meta_data is not None:
+ if self.mapping_json_path is not None and meta_data is not None:
log_data = []
log_data.append(
{"input": meta_data.get("filename_or_obj", ()), "output": meta_data.get("saved_to", ())}
)
From ca48feca4e3a8b95731ae32df0d227ef1fe06c11 Mon Sep 17 00:00:00 2001
From: staydelight
Date: Mon, 3 Jun 2024 17:58:29 +0800
Subject: [PATCH 16/49] fix-issue-7557
Signed-off-by: staydelight
---
monai/transforms/io/array.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index d4af4e91a6..fee5de6270 100644
--- a/monai/transforms/io/array.py
+++ b/monai/transforms/io/array.py
@@ -440,11 +440,9 @@ def __init__(
self.write_kwargs = {"verbose": print_log}
self._data_index = 0
self.savepath_in_metadict = savepath_in_metadict
+ self.mapping_json_path = Path(mapping_json_path) if mapping_json_path is not None else None
if mapping_json_path is not None:
- self.mapping_json_path = Path(mapping_json_path)
self.savepath_in_metadict = True
- else:
- self.mapping_json_path = None
def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None):
"""
From 117dd7855fd2cf39854b923ed15c0873cc5ac7da Mon Sep 17 00:00:00 2001
From: staydelight
Date: Thu, 13 Jun 2024 20:48:38 +0800
Subject: [PATCH 17/49] fix-issue-7557
Signed-off-by: staydelight
---
monai/transforms/io/array.py | 64 +++++++++++++++++++++----------
monai/transforms/io/dictionary.py | 4 +-
tests/test_mapping_json.py | 64 +++++++++++++++++++++++++++++++
3 files changed, 109 insertions(+), 23 deletions(-)
create mode 100644 tests/test_mapping_json.py
diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index fee5de6270..480ab2b853 100644
--- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -15,8 +15,8 @@ from __future__ import annotations import inspect -import json import logging +import json import sys import traceback import warnings @@ -394,7 +394,6 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_json_path: Path | str | None = None, ) -> None: self.folder_layout: FolderLayoutBase if folder_layout is None: @@ -440,9 +439,6 @@ def __init__( self.write_kwargs = {"verbose": print_log} self._data_index = 0 self.savepath_in_metadict = savepath_in_metadict - self.mapping_json_path = Path(mapping_json_path) if mapping_json_path is not None else None - if mapping_json_path is not None: - self.savepath_in_metadict = True def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): """ @@ -511,21 +507,6 @@ def __call__( self._data_index += 1 if self.savepath_in_metadict and meta_data is not None: meta_data["saved_to"] = filename - if self.mapping_json_path is not None and meta_data is not None: - log_data = [] - log_data.append( - {"input": meta_data.get("filename_or_obj", ()), "output": meta_data.get("saved_to", ())} - ) - - try: - with open(self.mapping_json_path) as f: - existing_log_data = json.load(f) - except FileNotFoundError: - existing_log_data = [] - - with open(self.mapping_json_path, "w") as f: - existing_log_data.extend(log_data) - json.dump(existing_log_data, f, indent=4) return img msg = "\n".join([f"{e}" for e in err]) raise RuntimeError( @@ -534,3 +515,46 @@ def __call__( " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}" ) + +class MappingJson(Transform): + """ + Writes a JSON file that logs the mapping between input image paths and their corresponding output paths. + + Args: + mapping_json_path (Path or str): Path to the JSON file where the mappings will be saved. + """ + + def __init__(self, mapping_json_path: Path | str = "mapping.json"): + self.mapping_json_path = Path(mapping_json_path) + + def write_json(self, input_path: str, output_path: str): + """ + Args: + input_path (str): The path of the input image file. + output_path (str): The path of the output image file. + """ + log_data = {"input": input_path, "output": output_path} + try: + with self.mapping_json_path.open("r") as f: + existing_log_data = json.load(f) + except FileNotFoundError: + existing_log_data = [] + + existing_log_data.append(log_data) + + with self.mapping_json_path.open("w") as f: + json.dump(existing_log_data, f, indent=4) + + def __call__(self, img: MetaTensor): + """ + Args: + img (MetaTensor): The input image with metadata. + """ + if "saved_to" not in img.meta: + raise KeyError("The 'saved_to' key is missing from the image metadata. 
Ensure SaveImage is configured with savepath_in_metadict=True.") + + + input_path = img.meta["filename_or_obj"] + output_path = img.meta["saved_to"] + self.write_json(input_path, output_path) + return img \ No newline at end of file diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 3cf46272c0..6cb01ccb19 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -281,7 +281,6 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: monai.data.FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_json_path: Path | str | None = None, ) -> None: super().__init__(keys, allow_missing_keys) self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) @@ -305,7 +304,6 @@ def __init__( output_name_formatter=output_name_formatter, folder_layout=folder_layout, savepath_in_metadict=savepath_in_metadict, - mapping_json_path=mapping_json_path, ) def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): @@ -323,4 +321,4 @@ def __call__(self, data): LoadImageD = LoadImageDict = LoadImaged -SaveImageD = SaveImageDict = SaveImaged +SaveImageD = SaveImageDict = SaveImaged \ No newline at end of file diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py new file mode 100644 index 0000000000..07254ddc11 --- /dev/null +++ b/tests/test_mapping_json.py @@ -0,0 +1,64 @@ +import unittest +import json +import numpy as np +import tempfile +import nibabel as nib +import os + +from pathlib import Path +from monai.transforms import Compose, LoadImage, SaveImage +from monai.data.meta_tensor import MetaTensor +from parameterized import parameterized +from monai.transforms.io.array import MappingJson + +class TestMappingJson(unittest.TestCase): + def setUp(self): + self.mapping_json_path = "test_mapping.json" + if Path(self.mapping_json_path).exists(): + Path(self.mapping_json_path).unlink() + + TEST_CASE_1 = [{}, ["test_image.nii.gz"], (128, 128, 128), True] + TEST_CASE_2 = [{}, ["test_image.nii.gz"], (128, 128, 128), False] + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_metadict): + test_image = np.random.rand(128, 128, 128) + + with tempfile.TemporaryDirectory() as tempdir: + for i, name in enumerate(filenames): + file_path = os.path.join(tempdir, name) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), file_path) + filenames[i] = file_path + + transforms = Compose([ + LoadImage(image_only=True, **load_params), + SaveImage(output_dir=tempdir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict), + MappingJson(mapping_json_path=self.mapping_json_path) + ]) + + if savepath_in_metadict: + result = transforms(filenames[0]) + + img = result + meta = img.meta + + self.assertEqual(img.shape, expected_shape) + + self.assertTrue(Path(self.mapping_json_path).exists()) + with open(self.mapping_json_path, "r") as f: + mapping_data = json.load(f) + + expected_mapping = [{"input": meta["filename_or_obj"], "output": meta["saved_to"]}] + self.assertEqual(expected_mapping, mapping_data) + else: + with self.assertRaises(RuntimeError) as cm: + transforms(filenames[0]) + the_exception = cm.exception + self.assertIsInstance(the_exception.__cause__, KeyError) + self.assertIn( + "The 'saved_to' key is missing from the image metadata. 
Ensure SaveImage is configured with savepath_in_metadict=True.", + str(the_exception.__cause__) + ) + +if __name__ == "__main__": + unittest.main() From 3908cddc1ae78349de4c5d6ce1f4bd5907e30617 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Jun 2024 12:51:51 +0000 Subject: [PATCH 18/49] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/io/array.py | 6 +++--- monai/transforms/io/dictionary.py | 2 +- tests/test_mapping_json.py | 3 +-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 480ab2b853..d96196226b 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -515,7 +515,7 @@ def __call__( " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}" ) - + class MappingJson(Transform): """ Writes a JSON file that logs the mapping between input image paths and their corresponding output paths. @@ -523,7 +523,7 @@ class MappingJson(Transform): Args: mapping_json_path (Path or str): Path to the JSON file where the mappings will be saved. """ - + def __init__(self, mapping_json_path: Path | str = "mapping.json"): self.mapping_json_path = Path(mapping_json_path) @@ -557,4 +557,4 @@ def __call__(self, img: MetaTensor): input_path = img.meta["filename_or_obj"] output_path = img.meta["saved_to"] self.write_json(input_path, output_path) - return img \ No newline at end of file + return img diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 6cb01ccb19..4da1d422ca 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -321,4 +321,4 @@ def __call__(self, data): LoadImageD = LoadImageDict = LoadImaged -SaveImageD = SaveImageDict = SaveImaged \ No newline at end of file +SaveImageD = SaveImageDict = SaveImaged diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 07254ddc11..0b4b752351 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -7,7 +7,6 @@ from pathlib import Path from monai.transforms import Compose, LoadImage, SaveImage -from monai.data.meta_tensor import MetaTensor from parameterized import parameterized from monai.transforms.io.array import MappingJson @@ -45,7 +44,7 @@ def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_ self.assertEqual(img.shape, expected_shape) self.assertTrue(Path(self.mapping_json_path).exists()) - with open(self.mapping_json_path, "r") as f: + with open(self.mapping_json_path) as f: mapping_data = json.load(f) expected_mapping = [{"input": meta["filename_or_obj"], "output": meta["saved_to"]}] From 36e5af09aa2180829b882ed10f3d8369a146caf6 Mon Sep 17 00:00:00 2001 From: staydelight Date: Thu, 13 Jun 2024 21:29:18 +0800 Subject: [PATCH 19/49] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 8 ++++--- tests/test_mapping_json.py | 43 ++++++++++++++++++++++++++---------- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index d96196226b..0badb23528 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -15,8 +15,8 @@ from __future__ import annotations import inspect -import logging import json +import logging import sys import traceback import 
warnings @@ -516,6 +516,7 @@ def __call__( f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}" ) + class MappingJson(Transform): """ Writes a JSON file that logs the mapping between input image paths and their corresponding output paths. @@ -551,8 +552,9 @@ def __call__(self, img: MetaTensor): img (MetaTensor): The input image with metadata. """ if "saved_to" not in img.meta: - raise KeyError("The 'saved_to' key is missing from the image metadata. Ensure SaveImage is configured with savepath_in_metadict=True.") - + raise KeyError( + "The 'saved_to' key is missing from the image metadata. Ensure SaveImage is configured with savepath_in_metadict=True." + ) input_path = img.meta["filename_or_obj"] output_path = img.meta["saved_to"] diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 0b4b752351..99992c6d1b 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -1,15 +1,31 @@ -import unittest +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + import json -import numpy as np +import os import tempfile +import unittest +from pathlib import Path + import nibabel as nib -import os +import numpy as np +from parameterized import parameterized -from pathlib import Path +from monai.data.meta_tensor import MetaTensor from monai.transforms import Compose, LoadImage, SaveImage -from parameterized import parameterized from monai.transforms.io.array import MappingJson + class TestMappingJson(unittest.TestCase): def setUp(self): self.mapping_json_path = "test_mapping.json" @@ -29,11 +45,13 @@ def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_ nib.save(nib.Nifti1Image(test_image, np.eye(4)), file_path) filenames[i] = file_path - transforms = Compose([ - LoadImage(image_only=True, **load_params), - SaveImage(output_dir=tempdir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict), - MappingJson(mapping_json_path=self.mapping_json_path) - ]) + transforms = Compose( + [ + LoadImage(image_only=True, **load_params), + SaveImage(output_dir=tempdir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict), + MappingJson(mapping_json_path=self.mapping_json_path), + ] + ) if savepath_in_metadict: result = transforms(filenames[0]) @@ -44,7 +62,7 @@ def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_ self.assertEqual(img.shape, expected_shape) self.assertTrue(Path(self.mapping_json_path).exists()) - with open(self.mapping_json_path) as f: + with open(self.mapping_json_path, "r") as f: mapping_data = json.load(f) expected_mapping = [{"input": meta["filename_or_obj"], "output": meta["saved_to"]}] @@ -56,8 +74,9 @@ def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_ self.assertIsInstance(the_exception.__cause__, KeyError) self.assertIn( "The 'saved_to' key is missing from the image metadata. 
Ensure SaveImage is configured with savepath_in_metadict=True.", - str(the_exception.__cause__) + str(the_exception.__cause__), ) + if __name__ == "__main__": unittest.main() From 1a3da3874666c3b1cc6f18a6f5901b514cd6a975 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Jun 2024 13:32:08 +0000 Subject: [PATCH 20/49] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_mapping_json.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 99992c6d1b..0355b2c789 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -21,7 +21,6 @@ import numpy as np from parameterized import parameterized -from monai.data.meta_tensor import MetaTensor from monai.transforms import Compose, LoadImage, SaveImage from monai.transforms.io.array import MappingJson @@ -62,7 +61,7 @@ def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_ self.assertEqual(img.shape, expected_shape) self.assertTrue(Path(self.mapping_json_path).exists()) - with open(self.mapping_json_path, "r") as f: + with open(self.mapping_json_path) as f: mapping_data = json.load(f) expected_mapping = [{"input": meta["filename_or_obj"], "output": meta["saved_to"]}] From 37d19eddd2aeccd9057feb1a6fc8e2497f16ac5d Mon Sep 17 00:00:00 2001 From: staydelight Date: Thu, 13 Jun 2024 22:34:27 +0800 Subject: [PATCH 21/49] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 4 +--- tests/test_mapping_json.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 0badb23528..e58ceaf825 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -552,9 +552,7 @@ def __call__(self, img: MetaTensor): img (MetaTensor): The input image with metadata. """ if "saved_to" not in img.meta: - raise KeyError( - "The 'saved_to' key is missing from the image metadata. Ensure SaveImage is configured with savepath_in_metadict=True." - ) + raise KeyError("Missing 'saved_to' key in metadata. Check SaveImage savepath_in_metadict.") input_path = img.meta["filename_or_obj"] output_path = img.meta["saved_to"] diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 0355b2c789..7ab6820d4d 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -72,7 +72,7 @@ def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_ the_exception = cm.exception self.assertIsInstance(the_exception.__cause__, KeyError) self.assertIn( - "The 'saved_to' key is missing from the image metadata. Ensure SaveImage is configured with savepath_in_metadict=True.", + "Missing 'saved_to' key in metadata. 
Check SaveImage savepath_in_metadict.", str(the_exception.__cause__), ) From cff29264e8ff38157fe59b4a9b8777a1b7158cbc Mon Sep 17 00:00:00 2001 From: staydelight Date: Mon, 15 Jul 2024 16:43:15 +0800 Subject: [PATCH 22/49] fix-issue-7557 Signed-off-by: staydelight --- tests/test_mapping_json.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 7ab6820d4d..92f09a6f93 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -21,6 +21,7 @@ import numpy as np from parameterized import parameterized +from monai.data import NibabelReader from monai.transforms import Compose, LoadImage, SaveImage from monai.transforms.io.array import MappingJson From 33c078b56df2bd3a82f49574c45c9e31c50610a2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Jul 2024 08:48:34 +0000 Subject: [PATCH 23/49] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_mapping_json.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 92f09a6f93..7ab6820d4d 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -21,7 +21,6 @@ import numpy as np from parameterized import parameterized -from monai.data import NibabelReader from monai.transforms import Compose, LoadImage, SaveImage from monai.transforms.io.array import MappingJson From 40b3e2197fa8b74d772b5e65c9914b3ae3d8105f Mon Sep 17 00:00:00 2001 From: staydelight Date: Sun, 21 Jul 2024 22:53:48 +0800 Subject: [PATCH 24/49] fix-issue-7557 Signed-off-by: staydelight --- tests/test_mapping_json.py | 107 ++++++++++++++++++++----------------- 1 file changed, 57 insertions(+), 50 deletions(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 7ab6820d4d..f787db6451 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -13,69 +13,76 @@ import json import os +import shutil import tempfile import unittest from pathlib import Path -import nibabel as nib import numpy as np +import torch from parameterized import parameterized -from monai.transforms import Compose, LoadImage, SaveImage -from monai.transforms.io.array import MappingJson +from monai.tests.utils import TEST_NDARRAYS, make_nifti_image +from monai.transforms import Compose, LoadImage, MappingJson, SaveImage + +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TEST_IMAGE = p(np.arange(24).reshape((2, 4, 3))) + TEST_AFFINE = q( + np.array( + [[-5.3, 0.0, 0.0, 102.01], [0.0, 0.52, 2.17, -7.50], [-0.0, 1.98, -0.26, -23.12], [0.0, 0.0, 0.0, 1.0]] + ) + ) + TESTS.append([TEST_IMAGE, TEST_AFFINE, True]) + TESTS.append([TEST_IMAGE, TEST_AFFINE, False]) class TestMappingJson(unittest.TestCase): def setUp(self): - self.mapping_json_path = "test_mapping.json" - if Path(self.mapping_json_path).exists(): - Path(self.mapping_json_path).unlink() - - TEST_CASE_1 = [{}, ["test_image.nii.gz"], (128, 128, 128), True] - TEST_CASE_2 = [{}, ["test_image.nii.gz"], (128, 128, 128), False] - - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_metadict): - test_image = np.random.rand(128, 128, 128) - - with tempfile.TemporaryDirectory() as tempdir: - for i, name in enumerate(filenames): - file_path = os.path.join(tempdir, name) - nib.save(nib.Nifti1Image(test_image, np.eye(4)), file_path) - filenames[i] = file_path - - 
transforms = Compose( - [ - LoadImage(image_only=True, **load_params), - SaveImage(output_dir=tempdir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict), - MappingJson(mapping_json_path=self.mapping_json_path), - ] + self.test_dir = tempfile.mkdtemp() + self.mapping_json_path = os.path.join(self.test_dir, "mapping.json") + + def tearDown(self): + shutil.rmtree(self.test_dir, ignore_errors=True) + + @parameterized.expand(TESTS) + def test_mapping_json(self, array, affine, savepath_in_metadict): + name = "test_image" + output_ext = ".nii.gz" + test_image_name = make_nifti_image(array, affine, fname=os.path.join(self.test_dir, name)) + + input_file = os.path.join(self.test_dir, test_image_name) + output_file = os.path.join(self.test_dir, name, name + "_trans" + output_ext) + + transforms = Compose( + [ + LoadImage(reader="NibabelReader", image_only=True), + SaveImage(output_dir=self.test_dir, output_ext=output_ext, savepath_in_metadict=savepath_in_metadict), + MappingJson(mapping_json_path=self.mapping_json_path), + ] + ) + + if savepath_in_metadict: + transforms(input_file) + self.assertTrue(Path(self.mapping_json_path).exists()) + with open(self.mapping_json_path, "r") as f: + mapping_data = json.load(f) + + self.assertEqual(len(mapping_data), 1) + self.assertEqual(mapping_data[0]["input"], input_file) + self.assertEqual(mapping_data[0]["output"], output_file) + else: + with self.assertRaises(RuntimeError) as cm: + transforms(input_file) + the_exception = cm.exception + cause_exception = the_exception.__cause__ + + self.assertIsInstance(cause_exception, KeyError) + self.assertIn( + "Missing 'saved_to' key in metadata. Check SaveImage savepath_in_metadict.", str(cause_exception) ) - if savepath_in_metadict: - result = transforms(filenames[0]) - - img = result - meta = img.meta - - self.assertEqual(img.shape, expected_shape) - - self.assertTrue(Path(self.mapping_json_path).exists()) - with open(self.mapping_json_path) as f: - mapping_data = json.load(f) - - expected_mapping = [{"input": meta["filename_or_obj"], "output": meta["saved_to"]}] - self.assertEqual(expected_mapping, mapping_data) - else: - with self.assertRaises(RuntimeError) as cm: - transforms(filenames[0]) - the_exception = cm.exception - self.assertIsInstance(the_exception.__cause__, KeyError) - self.assertIn( - "Missing 'saved_to' key in metadata. 
Check SaveImage savepath_in_metadict.", - str(the_exception.__cause__), - ) - if __name__ == "__main__": unittest.main() From 36047a2e55152ef18d79c6436b4bdbc199dc7f6e Mon Sep 17 00:00:00 2001 From: staydelight Date: Sun, 21 Jul 2024 23:28:14 +0800 Subject: [PATCH 25/49] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index ef1da2d855..bf15a74e04 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -238,7 +238,7 @@ ) from .inverse import InvertibleTransform, TraceableTransform from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict -from .io.array import SUPPORTED_READERS, LoadImage, SaveImage +from .io.array import SUPPORTED_READERS, LoadImage, MappingJson, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .lazy.array import ApplyPending from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict From b3852880990d7a46edbbb2526985c2f395023d2b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 21 Jul 2024 15:34:08 +0000 Subject: [PATCH 26/49] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_mapping_json.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index f787db6451..356ddd067d 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -19,7 +19,6 @@ from pathlib import Path import numpy as np -import torch from parameterized import parameterized from monai.tests.utils import TEST_NDARRAYS, make_nifti_image @@ -66,7 +65,7 @@ def test_mapping_json(self, array, affine, savepath_in_metadict): if savepath_in_metadict: transforms(input_file) self.assertTrue(Path(self.mapping_json_path).exists()) - with open(self.mapping_json_path, "r") as f: + with open(self.mapping_json_path) as f: mapping_data = json.load(f) self.assertEqual(len(mapping_data), 1) From 393744881c990b8fafa70720da41ea5a76bcc89f Mon Sep 17 00:00:00 2001 From: staydelight Date: Mon, 22 Jul 2024 00:02:12 +0800 Subject: [PATCH 27/49] fix-issue-7557 Signed-off-by: staydelight --- tests/test_mapping_json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 356ddd067d..b0dfc97fe0 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -21,7 +21,7 @@ import numpy as np from parameterized import parameterized -from monai.tests.utils import TEST_NDARRAYS, make_nifti_image +from tests.utils import TEST_NDARRAYS, make_nifti_image from monai.transforms import Compose, LoadImage, MappingJson, SaveImage TESTS = [] From 44307fc84deab974302a9bebadf7dea8589e627c Mon Sep 17 00:00:00 2001 From: staydelight Date: Mon, 22 Jul 2024 21:06:41 +0800 Subject: [PATCH 28/49] fix-issue-7557 Signed-off-by: staydelight --- tests/test_mapping_json.py | 45 +++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index b0dfc97fe0..0a31a10682 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -16,55 +16,50 @@ import shutil import tempfile import unittest -from pathlib import Path import numpy as np +import torch from 
parameterized import parameterized -from tests.utils import TEST_NDARRAYS, make_nifti_image +from monai.data import NibabelWriter from monai.transforms import Compose, LoadImage, MappingJson, SaveImage -TESTS = [] -for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: - TEST_IMAGE = p(np.arange(24).reshape((2, 4, 3))) - TEST_AFFINE = q( - np.array( - [[-5.3, 0.0, 0.0, 102.01], [0.0, 0.52, 2.17, -7.50], [-0.0, 1.98, -0.26, -23.12], [0.0, 0.0, 0.0, 1.0]] - ) - ) - TESTS.append([TEST_IMAGE, TEST_AFFINE, True]) - TESTS.append([TEST_IMAGE, TEST_AFFINE, False]) - class TestMappingJson(unittest.TestCase): def setUp(self): - self.test_dir = tempfile.mkdtemp() - self.mapping_json_path = os.path.join(self.test_dir, "mapping.json") + self.temp_dir = tempfile.TemporaryDirectory() + self.mapping_json_path = os.path.join(self.temp_dir.name, "mapping.json") def tearDown(self): - shutil.rmtree(self.test_dir, ignore_errors=True) + self.temp_dir.cleanup() - @parameterized.expand(TESTS) - def test_mapping_json(self, array, affine, savepath_in_metadict): - name = "test_image" + @parameterized.expand([(True,), (False,)]) + def test_mapping_json(self, savepath_in_metadict): + image_data = np.arange(48, dtype=np.uint8).reshape(1, 2, 3, 8) output_ext = ".nii.gz" - test_image_name = make_nifti_image(array, affine, fname=os.path.join(self.test_dir, name)) + name = "test_image" + + input_file = os.path.join(self.temp_dir.name, name + output_ext) + output_file = os.path.join(self.temp_dir.name, name, name + "_trans" + output_ext) - input_file = os.path.join(self.test_dir, test_image_name) - output_file = os.path.join(self.test_dir, name, name + "_trans" + output_ext) + writer = NibabelWriter() + writer.set_data_array(image_data, channel_dim=None) + writer.set_metadata({"affine": np.eye(4), "original_affine": np.eye(4)}) + writer.write(input_file) transforms = Compose( [ LoadImage(reader="NibabelReader", image_only=True), - SaveImage(output_dir=self.test_dir, output_ext=output_ext, savepath_in_metadict=savepath_in_metadict), + SaveImage( + output_dir=self.temp_dir.name, output_ext=output_ext, savepath_in_metadict=savepath_in_metadict + ), MappingJson(mapping_json_path=self.mapping_json_path), ] ) if savepath_in_metadict: transforms(input_file) - self.assertTrue(Path(self.mapping_json_path).exists()) + self.assertTrue(os.path.exists(self.mapping_json_path)) with open(self.mapping_json_path) as f: mapping_data = json.load(f) From cdf4a1bd1819be48fcb9846e6fde5fa2eb81d0a5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 13:10:45 +0000 Subject: [PATCH 29/49] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_mapping_json.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 0a31a10682..2b0c854d8a 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -13,12 +13,10 @@ import json import os -import shutil import tempfile import unittest import numpy as np -import torch from parameterized import parameterized from monai.data import NibabelWriter From 5dd268e3c0496a0bacdf4cecc762557b2e3d5e30 Mon Sep 17 00:00:00 2001 From: staydelight Date: Thu, 25 Jul 2024 12:35:31 +0800 Subject: [PATCH 30/49] fix-issue-7557 Signed-off-by: staydelight --- tests/test_mapping_json.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py 
index 2b0c854d8a..6f7fd740ef 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -15,15 +15,19 @@ import os import tempfile import unittest +from pathlib import Path import numpy as np from parameterized import parameterized -from monai.data import NibabelWriter from monai.transforms import Compose, LoadImage, MappingJson, SaveImage +from monai.utils import optional_import +nib, has_nib = optional_import("nibabel") -class TestMappingJson(unittest.TestCase): + +@unittest.skipUnless(has_nib, "nibabel required") +class TestMappingJsonD(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() self.mapping_json_path = os.path.join(self.temp_dir.name, "mapping.json") @@ -32,22 +36,19 @@ def tearDown(self): self.temp_dir.cleanup() @parameterized.expand([(True,), (False,)]) - def test_mapping_json(self, savepath_in_metadict): - image_data = np.arange(48, dtype=np.uint8).reshape(1, 2, 3, 8) + def test_mapping_jsond(self, savepath_in_metadict): + test_image = np.random.rand(128, 128, 128) output_ext = ".nii.gz" name = "test_image" input_file = os.path.join(self.temp_dir.name, name + output_ext) output_file = os.path.join(self.temp_dir.name, name, name + "_trans" + output_ext) - writer = NibabelWriter() - writer.set_data_array(image_data, channel_dim=None) - writer.set_metadata({"affine": np.eye(4), "original_affine": np.eye(4)}) - writer.write(input_file) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file) transforms = Compose( [ - LoadImage(reader="NibabelReader", image_only=True), + LoadImage(image_only=True), SaveImage( output_dir=self.temp_dir.name, output_ext=output_ext, savepath_in_metadict=savepath_in_metadict ), From 401557a36ca3e830aadbd18e86086070dff77896 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 04:39:21 +0000 Subject: [PATCH 31/49] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_mapping_json.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 6f7fd740ef..898c828d32 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -15,7 +15,6 @@ import os import tempfile import unittest -from pathlib import Path import numpy as np from parameterized import parameterized From 4adb87d104c5c8ab263e22c451d7127cd00a4658 Mon Sep 17 00:00:00 2001 From: staydelight Date: Wed, 7 Aug 2024 00:34:57 +0800 Subject: [PATCH 32/49] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 22 +++++---- tests/test_mapping_json.py | 86 +++++++++++++++++++++++++----------- 2 files changed, 72 insertions(+), 36 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index cf71bd1926..bc8dd9e109 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -27,6 +27,7 @@ import numpy as np import torch +from filelock import FileLock from monai.config import DtypeLike, NdarrayOrTensor, PathLike from monai.data import image_writer @@ -520,6 +521,7 @@ def __call__( class MappingJson(Transform): """ Writes a JSON file that logs the mapping between input image paths and their corresponding output paths. + This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment. Args: mapping_json_path (Path or str): Path to the JSON file where the mappings will be saved. 
@@ -527,6 +529,7 @@ class MappingJson(Transform): def __init__(self, mapping_json_path: Path | str = "mapping.json"): self.mapping_json_path = Path(mapping_json_path) + self.lock = FileLock(str(self.mapping_json_path) + ".lock") def write_json(self, input_path: str, output_path: str): """ @@ -535,16 +538,18 @@ def write_json(self, input_path: str, output_path: str): output_path (str): The path of the output image file. """ log_data = {"input": input_path, "output": output_path} - try: - with self.mapping_json_path.open("r") as f: - existing_log_data = json.load(f) - except FileNotFoundError: - existing_log_data = [] - existing_log_data.append(log_data) + with self.lock: + try: + with self.mapping_json_path.open("r") as f: + existing_log_data = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + existing_log_data = [] + + existing_log_data.append(log_data) - with self.mapping_json_path.open("w") as f: - json.dump(existing_log_data, f, indent=4) + with self.mapping_json_path.open("w") as f: + json.dump(existing_log_data, f, indent=4) def __call__(self, img: MetaTensor): """ @@ -553,7 +558,6 @@ def __call__(self, img: MetaTensor): """ if "saved_to" not in img.meta: raise KeyError("Missing 'saved_to' key in metadata. Check SaveImage savepath_in_metadict.") - input_path = img.meta["filename_or_obj"] output_path = img.meta["saved_to"] self.write_json(input_path, output_path) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 898c828d32..9deb79ce31 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -12,8 +12,12 @@ from __future__ import annotations import json +import multiprocessing import os +import random +import shutil import tempfile +import time import unittest import numpy as np @@ -25,56 +29,84 @@ nib, has_nib = optional_import("nibabel") +def create_input_file(temp_dir, name): + test_image = np.random.rand(128, 128, 128) + output_ext = ".nii.gz" + input_file = os.path.join(temp_dir, name + output_ext) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file) + return input_file + + +def create_transform(temp_dir, mapping_json_path, savepath_in_metadict=True): + return Compose( + [ + LoadImage(image_only=True), + SaveImage(output_dir=temp_dir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict), + MappingJson(mapping_json_path=mapping_json_path), + ] + ) + + +def process_image(args): + temp_dir, mapping_json_path, i = args + time.sleep(random.uniform(0, 0.1)) + input_file = create_input_file(temp_dir, f"test_image_{i}") + transform = create_transform(temp_dir, mapping_json_path) + transform(input_file) + time.sleep(random.uniform(0, 0.1)) + + @unittest.skipUnless(has_nib, "nibabel required") -class TestMappingJsonD(unittest.TestCase): +class TestMappingJson(unittest.TestCase): def setUp(self): - self.temp_dir = tempfile.TemporaryDirectory() - self.mapping_json_path = os.path.join(self.temp_dir.name, "mapping.json") + self.temp_dir = tempfile.mkdtemp() + self.mapping_json_path = os.path.join(self.temp_dir, "mapping.json") def tearDown(self): - self.temp_dir.cleanup() + shutil.rmtree(self.temp_dir) @parameterized.expand([(True,), (False,)]) - def test_mapping_jsond(self, savepath_in_metadict): - test_image = np.random.rand(128, 128, 128) - output_ext = ".nii.gz" + def test_mapping_json(self, savepath_in_metadict): name = "test_image" + input_file = create_input_file(self.temp_dir, name) + output_file = os.path.join(self.temp_dir, name, name + "_trans.nii.gz") - input_file = os.path.join(self.temp_dir.name, 
name + output_ext) - output_file = os.path.join(self.temp_dir.name, name, name + "_trans" + output_ext) - - nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file) - - transforms = Compose( - [ - LoadImage(image_only=True), - SaveImage( - output_dir=self.temp_dir.name, output_ext=output_ext, savepath_in_metadict=savepath_in_metadict - ), - MappingJson(mapping_json_path=self.mapping_json_path), - ] - ) + transform = create_transform(self.temp_dir, self.mapping_json_path, savepath_in_metadict) if savepath_in_metadict: - transforms(input_file) + transform(input_file) self.assertTrue(os.path.exists(self.mapping_json_path)) with open(self.mapping_json_path) as f: mapping_data = json.load(f) - self.assertEqual(len(mapping_data), 1) self.assertEqual(mapping_data[0]["input"], input_file) self.assertEqual(mapping_data[0]["output"], output_file) else: with self.assertRaises(RuntimeError) as cm: - transforms(input_file) - the_exception = cm.exception - cause_exception = the_exception.__cause__ - + transform(input_file) + cause_exception = cm.exception.__cause__ self.assertIsInstance(cause_exception, KeyError) self.assertIn( "Missing 'saved_to' key in metadata. Check SaveImage savepath_in_metadict.", str(cause_exception) ) + def test_multiprocess_mapping_json(self): + num_processes, num_images = 16, 1000 + + with multiprocessing.Pool(processes=num_processes) as pool: + args = [(self.temp_dir, self.mapping_json_path, i) for i in range(num_images)] + pool.map(process_image, args) + + with open(self.mapping_json_path) as f: + mapping_data = json.load(f) + + self.assertEqual(len(mapping_data), num_images, f"Expected {num_images} entries, but got {len(mapping_data)}") + unique_entries = set(tuple(sorted(entry.items())) for entry in mapping_data) + self.assertEqual(len(mapping_data), len(unique_entries), "Duplicate entries exist") + for entry in mapping_data: + self.assertIn("input", entry, "Entry missing 'input' key") + self.assertIn("output", entry, "Entry missing 'output' key") + if __name__ == "__main__": unittest.main() From 0c14f4fc2ac6b2852e60e33776a5d9bb48791dee Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Aug 2024 16:36:42 +0000 Subject: [PATCH 33/49] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_mapping_json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 9deb79ce31..2e57c0b360 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -101,7 +101,7 @@ def test_multiprocess_mapping_json(self): mapping_data = json.load(f) self.assertEqual(len(mapping_data), num_images, f"Expected {num_images} entries, but got {len(mapping_data)}") - unique_entries = set(tuple(sorted(entry.items())) for entry in mapping_data) + unique_entries = {tuple(sorted(entry.items())) for entry in mapping_data} self.assertEqual(len(mapping_data), len(unique_entries), "Duplicate entries exist") for entry in mapping_data: self.assertIn("input", entry, "Entry missing 'input' key") From 8fafc050d59253c2dd0594360239b5e5dc9a654c Mon Sep 17 00:00:00 2001 From: staydelight Date: Wed, 7 Aug 2024 01:51:54 +0800 Subject: [PATCH 34/49] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 2 +- tests/test_mapping_json.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 
bc8dd9e109..cc45e791fc 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -27,7 +27,6 @@ import numpy as np import torch -from filelock import FileLock from monai.config import DtypeLike, NdarrayOrTensor, PathLike from monai.data import image_writer @@ -52,6 +51,7 @@ nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") nrrd, _ = optional_import("nrrd") +FileLock, _ = optional_import("filelock") __all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"] diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 2e57c0b360..aaaccf1a0e 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -91,7 +91,7 @@ def test_mapping_json(self, savepath_in_metadict): ) def test_multiprocess_mapping_json(self): - num_processes, num_images = 16, 1000 + num_processes, num_images = 8, 300 with multiprocessing.Pool(processes=num_processes) as pool: args = [(self.temp_dir, self.mapping_json_path, i) for i in range(num_images)] From b238987d28443d367f8e0a524e961ba6e0efcc38 Mon Sep 17 00:00:00 2001 From: staydelight Date: Wed, 7 Aug 2024 02:35:10 +0800 Subject: [PATCH 35/49] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index cc45e791fc..e6a92c522d 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -51,7 +51,7 @@ nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") nrrd, _ = optional_import("nrrd") -FileLock, _ = optional_import("filelock") +FileLock, _ = optional_import("filelock", name="FileLock") __all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"] From 5c7599057e4bd3cc46374dc3fcee645100fb7f70 Mon Sep 17 00:00:00 2001 From: staydelight Date: Wed, 7 Aug 2024 13:30:19 +0800 Subject: [PATCH 36/49] fix-issue-7557 Signed-off-by: staydelight --- tests/test_mapping_json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index aaaccf1a0e..212bd80903 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -91,7 +91,7 @@ def test_mapping_json(self, savepath_in_metadict): ) def test_multiprocess_mapping_json(self): - num_processes, num_images = 8, 300 + num_processes, num_images = 3, 50 with multiprocessing.Pool(processes=num_processes) as pool: args = [(self.temp_dir, self.mapping_json_path, i) for i in range(num_images)] From f7deb8635ee43c5d759845f4a4f4ef2fbaecf302 Mon Sep 17 00:00:00 2001 From: staydelight Date: Mon, 19 Aug 2024 15:39:48 +0800 Subject: [PATCH 37/49] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/__init__.py | 2 +- monai/transforms/io/array.py | 48 +++++------ monai/utils/enums.py | 1 + ...t_mapping_json.py => test_mapping_file.py} | 81 ++++++++++--------- 4 files changed, 66 insertions(+), 66 deletions(-) rename tests/{test_mapping_json.py => test_mapping_file.py} (51%) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 63f44d04c1..bcafcb753d 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -238,7 +238,7 @@ ) from .inverse import InvertibleTransform, TraceableTransform from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict -from .io.array import SUPPORTED_READERS, LoadImage, MappingJson, SaveImage +from .io.array import SUPPORTED_READERS, LoadImage, WriteFileMapping, SaveImage from .io.dictionary 
import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .lazy.array import ApplyPending from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index e6a92c522d..bfd6f32c8f 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -46,6 +46,7 @@ from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import GridSamplePadMode from monai.utils import ImageMetaKey as Key +from monai.utils import MetaKeys from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import nib, _ = optional_import("nibabel") @@ -507,7 +508,7 @@ def __call__( else: self._data_index += 1 if self.savepath_in_metadict and meta_data is not None: - meta_data["saved_to"] = filename + meta_data[MetaKeys.SAVED_TO] = filename return img msg = "\n".join([f"{e}" for e in err]) raise RuntimeError( @@ -518,47 +519,40 @@ def __call__( ) -class MappingJson(Transform): +class WriteFileMapping(Transform): """ Writes a JSON file that logs the mapping between input image paths and their corresponding output paths. This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment. - + Args: - mapping_json_path (Path or str): Path to the JSON file where the mappings will be saved. + mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved. """ + def __init__(self, mapping_file_path: Path | str = "mapping.json"): + self.mapping_file_path = Path(mapping_file_path) + self.lock = FileLock(str(self.mapping_file_path) + ".lock") - def __init__(self, mapping_json_path: Path | str = "mapping.json"): - self.mapping_json_path = Path(mapping_json_path) - self.lock = FileLock(str(self.mapping_json_path) + ".lock") - - def write_json(self, input_path: str, output_path: str): + def __call__(self, img: MetaTensor): """ Args: - input_path (str): The path of the input image file. - output_path (str): The path of the output image file. + img (MetaTensor): The input image with metadata. """ + if MetaKeys.SAVED_TO not in img.meta: + raise KeyError("Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.") + + input_path = img.meta[Key.FILENAME_OR_OBJ] + output_path = img.meta[MetaKeys.SAVED_TO] log_data = {"input": input_path, "output": output_path} - + with self.lock: try: - with self.mapping_json_path.open("r") as f: + with self.mapping_file_path.open("r") as f: existing_log_data = json.load(f) except (FileNotFoundError, json.JSONDecodeError): existing_log_data = [] - + existing_log_data.append(log_data) - - with self.mapping_json_path.open("w") as f: + + with self.mapping_file_path.open("w") as f: json.dump(existing_log_data, f, indent=4) - - def __call__(self, img: MetaTensor): - """ - Args: - img (MetaTensor): The input image with metadata. - """ - if "saved_to" not in img.meta: - raise KeyError("Missing 'saved_to' key in metadata. 
Check SaveImage savepath_in_metadict.") - input_path = img.meta["filename_or_obj"] - output_path = img.meta["saved_to"] - self.write_json(input_path, output_path) + return img diff --git a/monai/utils/enums.py b/monai/utils/enums.py index b786e92151..eba1be18ed 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -543,6 +543,7 @@ class MetaKeys(StrEnum): SPATIAL_SHAPE = "spatial_shape" # optional key for the length in each spatial dimension SPACE = "space" # possible values of space type are defined in `SpaceKeys` ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or float("nan") + SAVED_TO = "saved_to" class ColorOrder(StrEnum): diff --git a/tests/test_mapping_json.py b/tests/test_mapping_file.py similarity index 51% rename from tests/test_mapping_json.py rename to tests/test_mapping_file.py index 212bd80903..ca71dcb72a 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_file.py @@ -12,7 +12,6 @@ from __future__ import annotations import json -import multiprocessing import os import random import shutil @@ -23,7 +22,8 @@ import numpy as np from parameterized import parameterized -from monai.transforms import Compose, LoadImage, MappingJson, SaveImage +from monai.data import Dataset, DataLoader +from monai.transforms import Compose, LoadImage, WriteFileMapping, SaveImage from monai.utils import optional_import nib, has_nib = optional_import("nibabel") @@ -37,46 +37,37 @@ def create_input_file(temp_dir, name): return input_file -def create_transform(temp_dir, mapping_json_path, savepath_in_metadict=True): +def create_transform(temp_dir, mapping_file_path, savepath_in_metadict=True): return Compose( [ LoadImage(image_only=True), SaveImage(output_dir=temp_dir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict), - MappingJson(mapping_json_path=mapping_json_path), + WriteFileMapping(mapping_file_path=mapping_file_path), ] ) -def process_image(args): - temp_dir, mapping_json_path, i = args - time.sleep(random.uniform(0, 0.1)) - input_file = create_input_file(temp_dir, f"test_image_{i}") - transform = create_transform(temp_dir, mapping_json_path) - transform(input_file) - time.sleep(random.uniform(0, 0.1)) - - @unittest.skipUnless(has_nib, "nibabel required") -class TestMappingJson(unittest.TestCase): +class TestWriteFileMapping(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.mkdtemp() - self.mapping_json_path = os.path.join(self.temp_dir, "mapping.json") def tearDown(self): shutil.rmtree(self.temp_dir) @parameterized.expand([(True,), (False,)]) - def test_mapping_json(self, savepath_in_metadict): + def test_mapping_file(self, savepath_in_metadict): + mapping_file_path = os.path.join(self.temp_dir, "mapping.json") name = "test_image" input_file = create_input_file(self.temp_dir, name) output_file = os.path.join(self.temp_dir, name, name + "_trans.nii.gz") - transform = create_transform(self.temp_dir, self.mapping_json_path, savepath_in_metadict) + transform = create_transform(self.temp_dir, mapping_file_path, savepath_in_metadict) if savepath_in_metadict: transform(input_file) - self.assertTrue(os.path.exists(self.mapping_json_path)) - with open(self.mapping_json_path) as f: + self.assertTrue(os.path.exists(mapping_file_path)) + with open(mapping_file_path) as f: mapping_data = json.load(f) self.assertEqual(len(mapping_data), 1) self.assertEqual(mapping_data[0]["input"], input_file) @@ -87,26 +78,40 @@ def test_mapping_json(self, savepath_in_metadict): cause_exception = cm.exception.__cause__ self.assertIsInstance(cause_exception, 
KeyError) self.assertIn( - "Missing 'saved_to' key in metadata. Check SaveImage savepath_in_metadict.", str(cause_exception) + "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.", str(cause_exception) ) - def test_multiprocess_mapping_json(self): - num_processes, num_images = 3, 50 - - with multiprocessing.Pool(processes=num_processes) as pool: - args = [(self.temp_dir, self.mapping_json_path, i) for i in range(num_images)] - pool.map(process_image, args) - - with open(self.mapping_json_path) as f: - mapping_data = json.load(f) - - self.assertEqual(len(mapping_data), num_images, f"Expected {num_images} entries, but got {len(mapping_data)}") - unique_entries = {tuple(sorted(entry.items())) for entry in mapping_data} - self.assertEqual(len(mapping_data), len(unique_entries), "Duplicate entries exist") - for entry in mapping_data: - self.assertIn("input", entry, "Entry missing 'input' key") - self.assertIn("output", entry, "Entry missing 'output' key") - + def test_multiprocess_mapping_file(self): + num_images = 50 + + single_mapping_file = os.path.join(self.temp_dir, "single_mapping.json") + multi_mapping_file = os.path.join(self.temp_dir, "multi_mapping.json") + + data = [create_input_file(self.temp_dir, f"test_image_{i}") for i in range(num_images)] + + # single process + single_transform = create_transform(self.temp_dir, single_mapping_file) + single_dataset = Dataset(data=data, transform=single_transform) + single_loader = DataLoader(single_dataset, batch_size=1, num_workers=0, shuffle=True) + for _ in single_loader: + pass + + # multiple processes + multi_transform = create_transform(self.temp_dir, multi_mapping_file) + multi_dataset = Dataset(data=data, transform=multi_transform) + multi_loader = DataLoader(multi_dataset, batch_size=2, num_workers=2, shuffle=True) + for _ in multi_loader: + pass + + with open(single_mapping_file) as f: + single_mapping_data = json.load(f) + with open(multi_mapping_file) as f: + multi_mapping_data = json.load(f) + + single_set = set((entry['input'], entry['output']) for entry in single_mapping_data) + multi_set = set((entry['input'], entry['output']) for entry in multi_mapping_data) + + self.assertEqual(single_set, multi_set) if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file From 3607b1829174e31c633e30501339680ad73e7b5a Mon Sep 17 00:00:00 2001 From: staydelight Date: Mon, 19 Aug 2024 15:57:56 +0800 Subject: [PATCH 38/49] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/__init__.py | 2 +- monai/transforms/io/array.py | 27 ++++++++++++++++++--------- tests/test_mapping_file.py | 28 +++++++++++++++------------- 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index bcafcb753d..69d4426c57 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -238,7 +238,7 @@ ) from .inverse import InvertibleTransform, TraceableTransform from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict -from .io.array import SUPPORTED_READERS, LoadImage, WriteFileMapping, SaveImage +from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .lazy.array import ApplyPending from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py 
index bfd6f32c8f..d63366f1a2 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -46,8 +46,14 @@ from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import GridSamplePadMode from monai.utils import ImageMetaKey as Key -from monai.utils import MetaKeys -from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import +from monai.utils import ( + MetaKeys, + OptionalImportError, + convert_to_dst_type, + ensure_tuple, + look_up_option, + optional_import, +) nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") @@ -523,10 +529,11 @@ class WriteFileMapping(Transform): """ Writes a JSON file that logs the mapping between input image paths and their corresponding output paths. This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment. - + Args: mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved. """ + def __init__(self, mapping_file_path: Path | str = "mapping.json"): self.mapping_file_path = Path(mapping_file_path) self.lock = FileLock(str(self.mapping_file_path) + ".lock") @@ -537,22 +544,24 @@ def __call__(self, img: MetaTensor): img (MetaTensor): The input image with metadata. """ if MetaKeys.SAVED_TO not in img.meta: - raise KeyError("Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.") - + raise KeyError( + "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True." + ) + input_path = img.meta[Key.FILENAME_OR_OBJ] output_path = img.meta[MetaKeys.SAVED_TO] log_data = {"input": input_path, "output": output_path} - + with self.lock: try: with self.mapping_file_path.open("r") as f: existing_log_data = json.load(f) except (FileNotFoundError, json.JSONDecodeError): existing_log_data = [] - + existing_log_data.append(log_data) - + with self.mapping_file_path.open("w") as f: json.dump(existing_log_data, f, indent=4) - + return img diff --git a/tests/test_mapping_file.py b/tests/test_mapping_file.py index ca71dcb72a..ffec460a61 100644 --- a/tests/test_mapping_file.py +++ b/tests/test_mapping_file.py @@ -22,8 +22,8 @@ import numpy as np from parameterized import parameterized -from monai.data import Dataset, DataLoader -from monai.transforms import Compose, LoadImage, WriteFileMapping, SaveImage +from monai.data import DataLoader, Dataset +from monai.transforms import Compose, LoadImage, SaveImage, WriteFileMapping from monai.utils import optional_import nib, has_nib = optional_import("nibabel") @@ -78,40 +78,42 @@ def test_mapping_file(self, savepath_in_metadict): cause_exception = cm.exception.__cause__ self.assertIsInstance(cause_exception, KeyError) self.assertIn( - "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.", str(cause_exception) + "Missing 'saved_to' key in metadata. 
Check SaveImage argument 'savepath_in_metadict' is True.", + str(cause_exception), ) def test_multiprocess_mapping_file(self): num_images = 50 - + single_mapping_file = os.path.join(self.temp_dir, "single_mapping.json") multi_mapping_file = os.path.join(self.temp_dir, "multi_mapping.json") - + data = [create_input_file(self.temp_dir, f"test_image_{i}") for i in range(num_images)] - + # single process single_transform = create_transform(self.temp_dir, single_mapping_file) single_dataset = Dataset(data=data, transform=single_transform) single_loader = DataLoader(single_dataset, batch_size=1, num_workers=0, shuffle=True) for _ in single_loader: pass - + # multiple processes multi_transform = create_transform(self.temp_dir, multi_mapping_file) multi_dataset = Dataset(data=data, transform=multi_transform) multi_loader = DataLoader(multi_dataset, batch_size=2, num_workers=2, shuffle=True) for _ in multi_loader: pass - + with open(single_mapping_file) as f: single_mapping_data = json.load(f) with open(multi_mapping_file) as f: multi_mapping_data = json.load(f) - - single_set = set((entry['input'], entry['output']) for entry in single_mapping_data) - multi_set = set((entry['input'], entry['output']) for entry in multi_mapping_data) - + + single_set = set((entry["input"], entry["output"]) for entry in single_mapping_data) + multi_set = set((entry["input"], entry["output"]) for entry in multi_mapping_data) + self.assertEqual(single_set, multi_set) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 130eaa1153298c543bcd153c4315a38a019be8ed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 08:00:26 +0000 Subject: [PATCH 39/49] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_mapping_file.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_mapping_file.py b/tests/test_mapping_file.py index ffec460a61..8a9d03565c 100644 --- a/tests/test_mapping_file.py +++ b/tests/test_mapping_file.py @@ -13,10 +13,8 @@ import json import os -import random import shutil import tempfile -import time import unittest import numpy as np @@ -109,8 +107,8 @@ def test_multiprocess_mapping_file(self): with open(multi_mapping_file) as f: multi_mapping_data = json.load(f) - single_set = set((entry["input"], entry["output"]) for entry in single_mapping_data) - multi_set = set((entry["input"], entry["output"]) for entry in multi_mapping_data) + single_set = {(entry["input"], entry["output"]) for entry in single_mapping_data} + multi_set = {(entry["input"], entry["output"]) for entry in multi_mapping_data} self.assertEqual(single_set, multi_set) From b1475be40f954d6c2314cd60d47bc2611c0a6942 Mon Sep 17 00:00:00 2001 From: staydelight Date: Mon, 19 Aug 2024 17:43:04 +0800 Subject: [PATCH 40/49] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 5 +++-- tests/test_mapping_file.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index d63366f1a2..42b90523b3 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -536,7 +536,6 @@ class WriteFileMapping(Transform): def __init__(self, mapping_file_path: Path | str = "mapping.json"): self.mapping_file_path = Path(mapping_file_path) - self.lock = FileLock(str(self.mapping_file_path) + ".lock") def __call__(self, img: MetaTensor): """ @@ 
-552,7 +551,9 @@ def __call__(self, img: MetaTensor): output_path = img.meta[MetaKeys.SAVED_TO] log_data = {"input": input_path, "output": output_path} - with self.lock: + lock = FileLock(str(self.mapping_file_path) + ".lock") + + with lock: try: with self.mapping_file_path.open("r") as f: existing_log_data = json.load(f) diff --git a/tests/test_mapping_file.py b/tests/test_mapping_file.py index 8a9d03565c..97fa4312ed 100644 --- a/tests/test_mapping_file.py +++ b/tests/test_mapping_file.py @@ -98,7 +98,7 @@ def test_multiprocess_mapping_file(self): # multiple processes multi_transform = create_transform(self.temp_dir, multi_mapping_file) multi_dataset = Dataset(data=data, transform=multi_transform) - multi_loader = DataLoader(multi_dataset, batch_size=2, num_workers=2, shuffle=True) + multi_loader = DataLoader(multi_dataset, batch_size=4, num_workers=3, shuffle=True) for _ in multi_loader: pass From ca1515622fde962885bf8ab15f10e865dfd70ab8 Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 27 Aug 2024 15:44:52 +0800 Subject: [PATCH 41/49] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/__init__.py | 10 ++- monai/transforms/io/dictionary.py | 31 +++++++- tests/test_mapping_filed.py | 118 ++++++++++++++++++++++++++++++ 3 files changed, 156 insertions(+), 3 deletions(-) create mode 100644 tests/test_mapping_filed.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 69d4426c57..cf6f35dfe0 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -239,7 +239,15 @@ from .inverse import InvertibleTransform, TraceableTransform from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping -from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict +from .io.dictionary import ( + LoadImaged, + LoadImageD, + LoadImageDict, + SaveImaged, + SaveImageD, + SaveImageDict, + WriteFileMappingd, +) from .lazy.array import ApplyPending from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict from .lazy.functional import apply_pending diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 4da1d422ca..eb5178b3ba 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -17,16 +17,18 @@ from __future__ import annotations +from collections.abc import Hashable, Mapping, Sequence from pathlib import Path from typing import Callable import numpy as np +from filelock import FileLock import monai -from monai.config import DtypeLike, KeysCollection +from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor, PathLike from monai.data import image_writer from monai.data.image_reader import ImageReader -from monai.transforms.io.array import LoadImage, SaveImage +from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping from monai.transforms.transform import MapTransform, Transform from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep from monai.utils.enums import PostFix @@ -320,5 +322,30 @@ def __call__(self, data): return d +class WriteFileMappingd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.WriteFileMapping`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + mapping_file_path: Path to the JSON file where the mappings will be saved. 
+ Defaults to "mapping.json". + allow_missing_keys: don't raise exception if key is missing. + """ + + def __init__( + self, keys: KeysCollection, mapping_file_path: Path | str = "mapping.json", allow_missing_keys: bool = False + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mapping = WriteFileMapping(mapping_file_path) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.mapping(d[key]) + return d + + LoadImageD = LoadImageDict = LoadImaged SaveImageD = SaveImageDict = SaveImaged diff --git a/tests/test_mapping_filed.py b/tests/test_mapping_filed.py new file mode 100644 index 0000000000..a9b4409d7c --- /dev/null +++ b/tests/test_mapping_filed.py @@ -0,0 +1,118 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +import os +import shutil +import tempfile +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import DataLoader, Dataset, decollate_batch +from monai.inferers import sliding_window_inference +from monai.networks.nets import UNet +from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, SaveImaged, WriteFileMappingd +from monai.utils import optional_import + +nib, has_nib = optional_import("nibabel") + + +def create_input_file(temp_dir, name): + test_image = np.random.rand(128, 128, 128) + input_file = os.path.join(temp_dir, name + ".nii.gz") + nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file) + return input_file + + +TEST_CASE_1 = [["seg"], ["seg"]] +TEST_CASE_2 = [["seg"], ["image"]] +TEST_CASE_3 = [["image"], ["seg"]] +TEST_CASE_4 = [["image", "seg"], ["seg"]] +TEST_CASE_5 = [["seg"], ["image", "seg"]] + + +@unittest.skipUnless(has_nib, "nibabel required") +class TestWriteFileMappingd(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.output_dir = os.path.join(self.temp_dir, "output") + os.makedirs(self.output_dir) + self.mapping_file_path = os.path.join(self.temp_dir, "mapping.json") + + def tearDown(self): + shutil.rmtree(self.temp_dir) + if os.path.exists(self.mapping_file_path): + os.remove(self.mapping_file_path) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_mapping_filed(self, save_keys, write_keys): + + name = "test_image" + input_file = create_input_file(self.temp_dir, name) + output_file = os.path.join(self.output_dir, name, name + "_seg.nii.gz") + data = [{"image": input_file}] + + test_transforms = Compose([LoadImaged(keys=["image"]), EnsureChannelFirstd(keys=["image"])]) + + post_transforms = Compose( + [ + SaveImaged( + keys=save_keys, + meta_keys="image_meta_dict", + output_dir=self.output_dir, + output_postfix="seg", + savepath_in_metadict=True, + ), + WriteFileMappingd(keys=write_keys, mapping_file_path=self.mapping_file_path), + ] + ) + + dataset = Dataset(data=data, 
transform=test_transforms) + dataloader = DataLoader(dataset, batch_size=1) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(16, 32), strides=(2,)).to(device) + model.eval() + + try: + with torch.no_grad(): + for batch_data in dataloader: + test_inputs = batch_data["image"].to(device) + roi_size = (64, 64, 64) + sw_batch_size = 2 + batch_data["seg"] = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model) + batch_data = [post_transforms(i) for i in decollate_batch(batch_data)] + + self.assertTrue(os.path.exists(self.mapping_file_path)) + + with open(self.mapping_file_path, "r") as f: + mapping_data = json.load(f) + + self.assertEqual(len(mapping_data), len(write_keys)) + for entry in mapping_data: + self.assertEqual(entry["input"], input_file) + self.assertEqual(entry["output"], output_file) + + except RuntimeError as cm: + cause_exception = cm.__cause__ + self.assertIsInstance(cause_exception, KeyError) + self.assertIn( + "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.", + str(cause_exception), + ) + + +if __name__ == "__main__": + unittest.main() From 3dc9f4931dec855580998a85c1071ad482112425 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 07:46:40 +0000 Subject: [PATCH 42/49] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/io/dictionary.py | 5 ++--- tests/test_mapping_filed.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index eb5178b3ba..807371583e 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -17,15 +17,14 @@ from __future__ import annotations -from collections.abc import Hashable, Mapping, Sequence +from collections.abc import Hashable, Mapping from pathlib import Path from typing import Callable import numpy as np -from filelock import FileLock import monai -from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor, PathLike +from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor from monai.data import image_writer from monai.data.image_reader import ImageReader from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping diff --git a/tests/test_mapping_filed.py b/tests/test_mapping_filed.py index a9b4409d7c..835750cb72 100644 --- a/tests/test_mapping_filed.py +++ b/tests/test_mapping_filed.py @@ -97,7 +97,7 @@ def test_mapping_filed(self, save_keys, write_keys): self.assertTrue(os.path.exists(self.mapping_file_path)) - with open(self.mapping_file_path, "r") as f: + with open(self.mapping_file_path) as f: mapping_data = json.load(f) self.assertEqual(len(mapping_data), len(write_keys)) From 3ea0df2682dfd63b3fa23ee858812785e19abb90 Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 27 Aug 2024 16:47:33 +0800 Subject: [PATCH 43/49] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 11 +++++++---- monai/transforms/io/dictionary.py | 1 + 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 93d7c155d9..7836382646 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -537,18 +537,21 @@ class WriteFileMapping(Transform): def __init__(self, mapping_file_path: Path | str = 
"mapping.json"): self.mapping_file_path = Path(mapping_file_path) - def __call__(self, img: MetaTensor): + def __call__(self, img: MetaTensor | torch.Tensor | np.ndarray): """ Args: img (MetaTensor): The input image with metadata. """ - if MetaKeys.SAVED_TO not in img.meta: + if isinstance(img, MetaTensor): + meta_data = img.meta + + if MetaKeys.SAVED_TO not in meta_data: raise KeyError( "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True." ) - input_path = img.meta[Key.FILENAME_OR_OBJ] - output_path = img.meta[MetaKeys.SAVED_TO] + input_path = meta_data[Key.FILENAME_OR_OBJ] + output_path = meta_data[MetaKeys.SAVED_TO] log_data = {"input": input_path, "output": output_path} lock = FileLock(str(self.mapping_file_path) + ".lock") diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 807371583e..be1e78db8a 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -348,3 +348,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N LoadImageD = LoadImageDict = LoadImaged SaveImageD = SaveImageDict = SaveImaged +WriteFileMappingD = WriteFileMappingDict = WriteFileMappingd From 8ad9808256454829af7318d6a5f20e0cb77fdee5 Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 27 Aug 2024 17:02:54 +0800 Subject: [PATCH 44/49] fix-issue-7557 Signed-off-by: staydelight --- docs/source/transforms.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 637f0873f1..84f7cb267f 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -553,6 +553,12 @@ IO .. autoclass:: SaveImage :members: :special-members: __call__ + +`WriteFileMapping` +"""""""""""" +.. autoclass:: WriteFileMapping + :members: + :special-members: __call__ NVIDIA Tool Extension (NVTX) @@ -1641,6 +1647,12 @@ IO (Dict) .. autoclass:: SaveImaged :members: :special-members: __call__ + +`WriteFileMappingd` +"""""""""""" +.. autoclass:: WriteFileMappingd + :members: + :special-members: __call__ Post-processing (Dict) ^^^^^^^^^^^^^^^^^^^^^^ From 802e554513ce5b735b892796ce420b2f989ccfe9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 09:04:23 +0000 Subject: [PATCH 45/49] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/transforms.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 84f7cb267f..1a5b2a738e 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -553,7 +553,7 @@ IO .. autoclass:: SaveImage :members: :special-members: __call__ - + `WriteFileMapping` """""""""""" .. autoclass:: WriteFileMapping @@ -1647,7 +1647,7 @@ IO (Dict) .. autoclass:: SaveImaged :members: :special-members: __call__ - + `WriteFileMappingd` """""""""""" .. 
autoclass:: WriteFileMappingd From 60f5b79acec0a43eb77eeca6a30f880caa581ff3 Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 27 Aug 2024 17:16:40 +0800 Subject: [PATCH 46/49] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index cf6f35dfe0..f37016e63f 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -247,6 +247,8 @@ SaveImageD, SaveImageDict, WriteFileMappingd, + WriteFileMappingD, + WriteFileMappingDict, ) from .lazy.array import ApplyPending from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict From b7957b6bc713d6ae484f849e949c313974e06eef Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 27 Aug 2024 17:49:47 +0800 Subject: [PATCH 47/49] fix-issue-7557 Signed-off-by: staydelight --- docs/source/transforms.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 1a5b2a738e..3e45d899ec 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -555,7 +555,7 @@ IO :special-members: __call__ `WriteFileMapping` -"""""""""""" +"""""""""""""""""" .. autoclass:: WriteFileMapping :members: :special-members: __call__ @@ -1649,7 +1649,7 @@ IO (Dict) :special-members: __call__ `WriteFileMappingd` -"""""""""""" +""""""""""""""""""" .. autoclass:: WriteFileMappingd :members: :special-members: __call__ From 773a218d0655670ab82ef9099c742c135d67d1ca Mon Sep 17 00:00:00 2001 From: staydelight Date: Wed, 28 Aug 2024 00:57:18 +0800 Subject: [PATCH 48/49] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 7836382646..cde0727dc0 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -537,10 +537,10 @@ class WriteFileMapping(Transform): def __init__(self, mapping_file_path: Path | str = "mapping.json"): self.mapping_file_path = Path(mapping_file_path) - def __call__(self, img: MetaTensor | torch.Tensor | np.ndarray): + def __call__(self, img: NdarrayOrTensor): """ Args: - img (MetaTensor): The input image with metadata. + img: The input image with metadata. 
""" if isinstance(img, MetaTensor): meta_data = img.meta From b28b184ce759cf5b6c425c2e1ef4921bc8c56ba8 Mon Sep 17 00:00:00 2001 From: staydelight Date: Wed, 28 Aug 2024 15:40:48 +0800 Subject: [PATCH 49/49] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 31 ++++++++-------- tests/test_mapping_filed.py | 72 +++++++++++++++++++----------------- 2 files changed, 54 insertions(+), 49 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index cde0727dc0..4e71870fc9 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -58,7 +58,7 @@ nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") nrrd, _ = optional_import("nrrd") -FileLock, _ = optional_import("filelock", name="FileLock") +FileLock, has_filelock = optional_import("filelock", name="FileLock") __all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"] @@ -554,18 +554,19 @@ def __call__(self, img: NdarrayOrTensor): output_path = meta_data[MetaKeys.SAVED_TO] log_data = {"input": input_path, "output": output_path} - lock = FileLock(str(self.mapping_file_path) + ".lock") - - with lock: - try: - with self.mapping_file_path.open("r") as f: - existing_log_data = json.load(f) - except (FileNotFoundError, json.JSONDecodeError): - existing_log_data = [] - - existing_log_data.append(log_data) - - with self.mapping_file_path.open("w") as f: - json.dump(existing_log_data, f, indent=4) - + if has_filelock: + with FileLock(str(self.mapping_file_path) + ".lock"): + self._write_to_file(log_data) + else: + self._write_to_file(log_data) return img + + def _write_to_file(self, log_data): + try: + with self.mapping_file_path.open("r") as f: + existing_log_data = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + existing_log_data = [] + existing_log_data.append(log_data) + with self.mapping_file_path.open("w") as f: + json.dump(existing_log_data, f, indent=4) diff --git a/tests/test_mapping_filed.py b/tests/test_mapping_filed.py index 835750cb72..d0f8bcf938 100644 --- a/tests/test_mapping_filed.py +++ b/tests/test_mapping_filed.py @@ -37,11 +37,11 @@ def create_input_file(temp_dir, name): return input_file -TEST_CASE_1 = [["seg"], ["seg"]] -TEST_CASE_2 = [["seg"], ["image"]] -TEST_CASE_3 = [["image"], ["seg"]] -TEST_CASE_4 = [["image", "seg"], ["seg"]] -TEST_CASE_5 = [["seg"], ["image", "seg"]] +# Test cases that should succeed +SUCCESS_CASES = [(["seg"], ["seg"]), (["image", "seg"], ["seg"])] + +# Test cases that should fail +FAILURE_CASES = [(["seg"], ["image"]), (["image"], ["seg"]), (["seg"], ["image", "seg"])] @unittest.skipUnless(has_nib, "nibabel required") @@ -57,9 +57,7 @@ def tearDown(self): if os.path.exists(self.mapping_file_path): os.remove(self.mapping_file_path) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_mapping_filed(self, save_keys, write_keys): - + def run_test(self, save_keys, write_keys): name = "test_image" input_file = create_input_file(self.temp_dir, name) output_file = os.path.join(self.output_dir, name, name + "_seg.nii.gz") @@ -86,32 +84,38 @@ def test_mapping_filed(self, save_keys, write_keys): model = UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(16, 32), strides=(2,)).to(device) model.eval() - try: - with torch.no_grad(): - for batch_data in dataloader: - test_inputs = batch_data["image"].to(device) - roi_size = (64, 64, 64) - sw_batch_size = 2 - batch_data["seg"] = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model) - 
batch_data = [post_transforms(i) for i in decollate_batch(batch_data)] - - self.assertTrue(os.path.exists(self.mapping_file_path)) - - with open(self.mapping_file_path) as f: - mapping_data = json.load(f) - - self.assertEqual(len(mapping_data), len(write_keys)) - for entry in mapping_data: - self.assertEqual(entry["input"], input_file) - self.assertEqual(entry["output"], output_file) - - except RuntimeError as cm: - cause_exception = cm.__cause__ - self.assertIsInstance(cause_exception, KeyError) - self.assertIn( - "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.", - str(cause_exception), - ) + with torch.no_grad(): + for batch_data in dataloader: + test_inputs = batch_data["image"].to(device) + roi_size = (64, 64, 64) + sw_batch_size = 2 + batch_data["seg"] = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model) + batch_data = [post_transforms(i) for i in decollate_batch(batch_data)] + + return input_file, output_file + + @parameterized.expand(SUCCESS_CASES) + def test_successful_mapping_filed(self, save_keys, write_keys): + input_file, output_file = self.run_test(save_keys, write_keys) + self.assertTrue(os.path.exists(self.mapping_file_path)) + with open(self.mapping_file_path) as f: + mapping_data = json.load(f) + self.assertEqual(len(mapping_data), len(write_keys)) + for entry in mapping_data: + self.assertEqual(entry["input"], input_file) + self.assertEqual(entry["output"], output_file) + + @parameterized.expand(FAILURE_CASES) + def test_failure_mapping_filed(self, save_keys, write_keys): + with self.assertRaises(RuntimeError) as cm: + self.run_test(save_keys, write_keys) + + cause_exception = cm.exception.__cause__ + self.assertIsInstance(cause_exception, KeyError) + self.assertIn( + "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.", + str(cause_exception), + ) if __name__ == "__main__":
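For reference, a minimal usage sketch of the file-mapping transforms introduced by this series, written against the API as it appears in the patches above rather than a released MONAI version. It assumes the dictionary wrapper follows the usual MONAI MapTransform signature, i.e. WriteFileMappingd(keys=..., mapping_file_path=...); the temporary directory, the dummy NIfTI volume, and the "out" postfix are illustrative only. SaveImaged must be constructed with savepath_in_metadict=True, otherwise WriteFileMapping raises the KeyError exercised by the failure test cases.

import json
import tempfile

import nibabel as nib
import numpy as np

from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, SaveImaged, WriteFileMappingd

with tempfile.TemporaryDirectory() as tmp_dir:
    # a small dummy NIfTI volume stands in for real data
    img_path = f"{tmp_dir}/img.nii.gz"
    nib.save(nib.Nifti1Image(np.random.rand(16, 16, 16).astype(np.float32), np.eye(4)), img_path)

    mapping_path = f"{tmp_dir}/mapping.json"
    pipeline = Compose(
        [
            LoadImaged(keys="image"),
            EnsureChannelFirstd(keys="image"),
            # savepath_in_metadict=True stores the output path in the image metadata,
            # which the mapping transform reads back as MetaKeys.SAVED_TO
            SaveImaged(keys="image", output_dir=tmp_dir, output_postfix="out", savepath_in_metadict=True),
            # assumed signature; verify against the merged WriteFileMappingd
            WriteFileMappingd(keys="image", mapping_file_path=mapping_path),
        ]
    )
    pipeline({"image": img_path})

    # mapping.json holds a JSON list of {"input": ..., "output": ...} records;
    # repeated runs append to the list rather than overwriting it
    with open(mapping_path) as f:
        print(json.dumps(json.load(f), indent=2))

Note that the filelock dependency stays optional in the final patch: when the package is importable the JSON file is updated under a FileLock, otherwise _write_to_file appends to the list directly, so concurrent writers are only serialized when filelock is installed.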