diff --git a/docs/source/highlights.md b/docs/source/highlights.md index d8fe5c2ff9..29302bda77 100644 --- a/docs/source/highlights.md +++ b/docs/source/highlights.md @@ -39,7 +39,7 @@ There is a rich set of transforms in six categories: Crop & Pad, Intensity, IO, ### 2. Medical specific transforms MONAI aims at providing a comprehensive medical image specific transformations. These currently include, for example: -- `LoadNifti`: Load Nifti format file from provided path +- `LoadImage`: Load medical specific formats file from provided path - `Spacing`: Resample input image into the specified `pixdim` - `Orientation`: Change the image's orientation into the specified `axcodes` - `RandGaussianNoise`: Perturb image intensities by adding statistical noises diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 57170a33a9..c769771f4a 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -231,24 +231,6 @@ IO :members: :special-members: __call__ -`LoadNifti` -""""""""""" -.. autoclass:: LoadNifti - :members: - :special-members: __call__ - -`LoadPNG` -""""""""" -.. autoclass:: LoadPNG - :members: - :special-members: __call__ - -`LoadNumpy` -""""""""""" -.. autoclass:: LoadNumpy - :members: - :special-members: __call__ - Post-processing ^^^^^^^^^^^^^^^ @@ -708,36 +690,12 @@ Instensity (Dict) IO (Dict) ^^^^^^^^^ -`LoadDatad` -""""""""""" -.. autoclass:: LoadDatad - :members: - :special-members: __call__ - `LoadImaged` """""""""""" .. autoclass:: LoadImaged :members: :special-members: __call__ -`LoadNiftid` -"""""""""""" -.. autoclass:: LoadNiftid - :members: - :special-members: __call__ - -`LoadPNGd` -"""""""""" -.. autoclass:: LoadPNGd - :members: - :special-members: __call__ - -`LoadNumpyd` -"""""""""""" -.. autoclass:: LoadNumpyd - :members: - :special-members: __call__ - Post-processing (Dict) ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index b4bc40ae1f..1291dac25a 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -37,9 +37,7 @@ class MedNISTDataset(Randomizable, CacheDataset): Args: root_dir: target directory to download and load MedNIST dataset. section: expected data section, can be: `training`, `validation` or `test`. - transform: transforms to execute operations on input data. the default transform is `LoadPNGd`, - which can load data into numpy array with [H, W] shape. for further usage, use `AddChanneld` - to convert the shape to [C, H, W, D]. + transform: transforms to execute operations on input data. download: whether to download and extract the MedNIST from resource link, default is False. if expected file already exists, skip downloading even set it to True. user can manually copy `MedNIST.tar.gz` file or `MedNIST` folder to root directory. @@ -158,8 +156,7 @@ class DecathlonDataset(Randomizable, CacheDataset): "Task03_Liver", "Task04_Hippocampus", "Task05_Prostate", "Task06_Lung", "Task07_Pancreas", "Task08_HepaticVessel", "Task09_Spleen", "Task10_Colon"). section: expected data section, can be: `training`, `validation` or `test`. - transform: transforms to execute operations on input data. the default transform is `LoadNiftid`, - which can load Nifti format data into numpy array with [H, W, D] or [H, W, D, C] shape. + transform: transforms to execute operations on input data. for further usage, use `AddChanneld` or `AsChannelFirstd` to convert the shape to [C, H, W, D]. download: whether to download and extract the Decathlon from resource link, default is False. if expected file already exists, skip downloading even set it to True. @@ -185,7 +182,7 @@ class DecathlonDataset(Randomizable, CacheDataset): transform = Compose( [ - LoadNiftid(keys=["image", "label"]), + LoadImaged(keys=["image", "label"]), AddChanneld(keys=["image", "label"]), ScaleIntensityd(keys="image"), ToTensord(keys=["image", "label"]), diff --git a/monai/data/dataset.py b/monai/data/dataset.py index ed0d590bf7..047587119f 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -89,7 +89,7 @@ class PersistentDataset(Dataset): .. code-block:: python - [ LoadNiftid(keys=['image', 'label']), + [ LoadImaged(keys=['image', 'label']), Orientationd(keys=['image', 'label'], axcodes='RAS'), ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True), RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', spatial_size=(96, 96, 96), @@ -97,7 +97,7 @@ class PersistentDataset(Dataset): ToTensord(keys=['image', 'label'])] Upon first use a filename based dataset will be processed by the transform for the - [LoadNiftid, Orientationd, ScaleIntensityRanged] and the resulting tensor written to + [LoadImaged, Orientationd, ScaleIntensityRanged] and the resulting tensor written to the `cache_dir` before applying the remaining random dependant transforms [RandCropByPosNegLabeld, ToTensord] elements for use in the analysis. @@ -446,7 +446,7 @@ class CacheDataset(Dataset): For example, if the transform is a `Compose` of:: transforms = Compose([ - LoadNiftid(), + LoadImaged(), AddChanneld(), Spacingd(), Orientationd(), @@ -457,7 +457,7 @@ class CacheDataset(Dataset): when `transforms` is used in a multi-epoch training pipeline, before the first training epoch, this dataset will cache the results up to ``ScaleIntensityRanged``, as - all non-random transforms `LoadNiftid`, `AddChanneld`, `Spacingd`, `Orientationd`, `ScaleIntensityRanged` + all non-random transforms `LoadImaged`, `AddChanneld`, `Spacingd`, `Orientationd`, `ScaleIntensityRanged` can be cached. During training, the dataset will load the cached results and run ``RandCropByPosNegLabeld`` and ``ToTensord``, as ``RandCropByPosNegLabeld`` is a randomized transform and the outcome not cached. @@ -825,7 +825,7 @@ class ArrayDataset(Randomizable, _TorchDataset): img_transform = Compose( [ - LoadNifti(image_only=True), + LoadImage(image_only=True), AddChannel(), RandAdjustContrast() ] @@ -834,7 +834,7 @@ class ArrayDataset(Randomizable, _TorchDataset): If training based on images and the metadata, the array transforms can not be composed because several transforms receives multiple parameters or return multiple values. Then Users need - to define their own callable method to parse metadata from `LoadNifti` or set `affine` matrix + to define their own callable method to parse metadata from `LoadImage` or set `affine` matrix to `Spacing` transform:: class TestCompose(Compose): @@ -845,7 +845,7 @@ def __call__(self, input_): return self.transforms[3](img), metadata img_transform = TestCompose( [ - LoadNifti(image_only=False), + LoadImage(image_only=False), AddChannel(), Spacing(pixdim=(1.5, 1.5, 3.0)), RandAdjustContrast() diff --git a/monai/data/nifti_reader.py b/monai/data/nifti_reader.py index 9df63d9dbd..1378fb25a0 100644 --- a/monai/data/nifti_reader.py +++ b/monai/data/nifti_reader.py @@ -14,7 +14,7 @@ import numpy as np from torch.utils.data import Dataset -from monai.transforms import LoadNifti, Randomizable, apply_transform +from monai.transforms import LoadImage, Randomizable, apply_transform from monai.utils import MAX_SEED, get_seed @@ -81,8 +81,11 @@ def randomize(self, data: Optional[Any] = None) -> None: def __getitem__(self, index: int): self.randomize() meta_data = None - img_loader = LoadNifti( - as_closest_canonical=self.as_closest_canonical, image_only=self.image_only, dtype=self.dtype + img_loader = LoadImage( + reader="NibabelReader", + image_only=self.image_only, + dtype=self.dtype, + as_closest_canonical=self.as_closest_canonical, ) if self.image_only: img = img_loader(self.image_files[index]) @@ -90,7 +93,7 @@ def __getitem__(self, index: int): img, meta_data = img_loader(self.image_files[index]) seg = None if self.seg_files is not None: - seg_loader = LoadNifti(image_only=True) + seg_loader = LoadImage(image_only=True) seg = seg_loader(self.seg_files[index]) label = None if self.labels is not None: diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index e79b2ce38b..8f46abf522 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -137,22 +137,8 @@ ThresholdIntensityD, ThresholdIntensityDict, ) -from .io.array import LoadImage, LoadNifti, LoadNumpy, LoadPNG -from .io.dictionary import ( - LoadDatad, - LoadImaged, - LoadImageD, - LoadImageDict, - LoadNiftid, - LoadNiftiD, - LoadNiftiDict, - LoadNumpyd, - LoadNumpyD, - LoadNumpyDict, - LoadPNGd, - LoadPNGD, - LoadPNGDict, -) +from .io.array import LoadImage +from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict from .post.array import ( Activations, AsDiscrete, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index afe72391e1..3e23377b36 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -55,7 +55,7 @@ def __call__(self, data: Any): - ``data`` is a Numpy ndarray, PyTorch Tensor or string - the data shape can be: - #. string data without shape, `LoadNifti` and `LoadPNG` transforms expect file paths + #. string data without shape, `LoadImage` transform expects file paths #. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, except that `AddChannel` expects (spatial_dim_1[, spatial_dim_2, ...]) and `AsChannelFirst` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels) @@ -282,7 +282,7 @@ def __call__(self, data): - ``data[key]`` is a Numpy ndarray, PyTorch Tensor or string, where ``key`` is an element of ``self.keys``, the data shape can be: - #. string data without shape, `LoadNiftid` and `LoadPNGd` transforms expect file paths + #. string data without shape, `LoadImaged` transform expects file paths #. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, except that `AddChanneld` expects (spatial_dim_1[, spatial_dim_2, ...]) and `AsChannelFirstd` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index f1b92025a7..3b359cc460 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -13,23 +13,18 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ -import warnings -from pathlib import Path -from typing import Dict, List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Union import numpy as np -from torch.utils.data._utils.collate import np_str_obj_array_pattern -from monai.config import KeysCollection from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader -from monai.data.utils import correct_nifti_header_if_necessary from monai.transforms.compose import Transform from monai.utils import ensure_tuple, optional_import nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") -__all__ = ["LoadImage", "LoadNifti", "LoadPNG", "LoadNumpy"] +__all__ = ["LoadImage"] class LoadImage(Transform): @@ -132,214 +127,3 @@ def __call__( return img_array meta_data["filename_or_obj"] = ensure_tuple(filename)[0] return img_array, meta_data - - -class LoadNifti(Transform): - """ - Load Nifti format file or files from provided path. If loading a list of - files, stack them together and add a new dimension as first dimension, and - use the meta data of the first image to represent the stacked result. Note - that the affine transform of all the images should be same if ``image_only=False``. - """ - - def __init__( - self, as_closest_canonical: bool = False, image_only: bool = False, dtype: Optional[np.dtype] = np.float32 - ) -> None: - """ - Args: - as_closest_canonical: if True, load the image as closest to canonical axis format. - image_only: if True return only the image volume, otherwise return image data array and header dict. - dtype: if not None convert the loaded image to this data type. - - Note: - The transform returns image data array if `image_only` is True, - or a tuple of two elements containing the data array, and the Nifti - header in a dict format otherwise. - if a dictionary header is returned: - - - header['affine'] stores the affine of the image. - - header['original_affine'] will be additionally created to store the original affine. - """ - warnings.warn("LoadNifti will be deprecated in v0.5, please use LoadImage instead.", DeprecationWarning) - self.as_closest_canonical = as_closest_canonical - self.image_only = image_only - self.dtype = dtype - - def __call__(self, filename: Union[Sequence[Union[Path, str]], Path, str]): - """ - Args: - filename: path file or file-like object or a list of files. - """ - filename = ensure_tuple(filename) - img_array = [] - compatible_meta: Dict = {} - for name in filename: - img = nib.load(name) - img = correct_nifti_header_if_necessary(img) - header = dict(img.header) - header["filename_or_obj"] = name - header["affine"] = img.affine - header["original_affine"] = img.affine.copy() - header["as_closest_canonical"] = self.as_closest_canonical - ndim = img.header["dim"][0] - spatial_rank = min(ndim, 3) - header["spatial_shape"] = img.header["dim"][1 : spatial_rank + 1] - - if self.as_closest_canonical: - img = nib.as_closest_canonical(img) - header["affine"] = img.affine - - img_array.append(np.array(img.get_fdata(dtype=self.dtype))) - img.uncache() - - if self.image_only: - continue - - if not compatible_meta: - for meta_key in header: - meta_datum = header[meta_key] - if ( - isinstance(meta_datum, np.ndarray) - and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None - ): - continue - compatible_meta[meta_key] = meta_datum - else: - assert np.allclose( - header["affine"], compatible_meta["affine"] - ), "affine data of all images should be same." - - img_array = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - if self.image_only: - return img_array - return img_array, compatible_meta - - -class LoadPNG(Transform): - """ - Load common 2D image format (PNG, JPG, etc. using PIL) file or files from provided path. - If loading a list of files, stack them together and add a new dimension as first dimension, - and use the meta data of the first image to represent the stacked result. - It's based on the Image module in PIL library: - https://pillow.readthedocs.io/en/stable/reference/Image.html - """ - - def __init__(self, image_only: bool = False, dtype: Optional[np.dtype] = np.float32) -> None: - """ - Args: - image_only: if True return only the image volume, otherwise return image data array and metadata. - dtype: if not None convert the loaded image to this data type. - """ - warnings.warn("LoadPNG will be deprecated in v0.5, please use LoadImage instead.", DeprecationWarning) - self.image_only = image_only - self.dtype = dtype - - def __call__(self, filename: Union[Sequence[Union[Path, str]], Path, str]): - """ - Args: - filename: path file or file-like object or a list of files. - """ - filename = ensure_tuple(filename) - img_array = [] - compatible_meta = None - for name in filename: - img = Image.open(name) - data = np.asarray(img) - if self.dtype: - data = data.astype(self.dtype) - img_array.append(data) - - if self.image_only: - continue - - meta = {} - meta["filename_or_obj"] = name - meta["spatial_shape"] = data.shape[:2] - meta["format"] = img.format - meta["mode"] = img.mode - meta["width"] = img.width - meta["height"] = img.height - if not compatible_meta: - compatible_meta = meta - else: - assert np.allclose( - meta["spatial_shape"], compatible_meta["spatial_shape"] - ), "all the images in the list should have same spatial shape." - - img_array = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - return img_array if self.image_only else (img_array, compatible_meta) - - -class LoadNumpy(Transform): - """ - Load arrays or pickled objects from .npy, .npz or pickled files, file or files are from provided path. - A typical usage is to load the `mask` data for classification task. - If loading a list of files or loading npz file, stack results together and add a new dimension as first dimension, - and use the meta data of the first file to represent the stacked result. - It can load part of the npz file with specified `npz_keys`. - It's based on the Numpy load/read API: - https://numpy.org/doc/stable/reference/generated/numpy.load.html - - """ - - def __init__( - self, data_only: bool = False, dtype: Optional[np.dtype] = np.float32, npz_keys: Optional[KeysCollection] = None - ) -> None: - """ - Args: - data_only: if True return only the data array, otherwise return data array and metadata. - dtype: if not None convert the loaded data to this data type. - npz_keys: if loading npz file, only load the specified keys, if None, load all the items. - stack the loaded items together to construct a new first dimension. - - """ - warnings.warn("LoadNumpy will be deprecated in v0.5, please use LoadImage instead.", DeprecationWarning) - self.data_only = data_only - self.dtype = dtype - if npz_keys is not None: - npz_keys = ensure_tuple(npz_keys) - self.npz_keys = npz_keys - - def __call__(self, filename: Union[Sequence[Union[Path, str]], Path, str]): - """ - Args: - filename: path file or file-like object or a list of files. - - Raises: - ValueError: When ``filename`` is a sequence and contains a "npz" file extension. - - """ - if isinstance(filename, (tuple, list)): - for name in filename: - if name.endswith(".npz"): - raise ValueError("Cannot load a sequence of npz files.") - filename = ensure_tuple(filename) - data_array: List = [] - compatible_meta = None - - def _save_data_meta(data_array, name, data, compatible_meta): - data_array.append(data if self.dtype is None else data.astype(self.dtype)) - if not self.data_only: - meta = {} - meta["filename_or_obj"] = name - meta["spatial_shape"] = data.shape - if not compatible_meta: - compatible_meta = meta - else: - assert np.allclose( - meta["spatial_shape"], compatible_meta["spatial_shape"] - ), "all the data in the list should have same shape." - return compatible_meta - - for name in filename: - data = np.load(name, allow_pickle=True) - if name.endswith(".npz"): - # load expected items from NPZ file - npz_keys = [f"arr_{i}" for i in range(len(data))] if self.npz_keys is None else self.npz_keys - for k in npz_keys: - compatible_meta = _save_data_meta(data_array, name, data[k], compatible_meta) - else: - compatible_meta = _save_data_meta(data_array, name, data, compatible_meta) - - data_array = np.stack(data_array, axis=0) if len(data_array) > 1 else data_array[0] - return data_array if self.data_only else (data_array, compatible_meta) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 474fbd0a50..62ac4c8562 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -15,29 +15,19 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -from typing import Callable, Optional, Union +from typing import Optional, Union import numpy as np from monai.config import KeysCollection from monai.data.image_reader import ImageReader from monai.transforms.compose import MapTransform -from monai.transforms.io.array import LoadImage, LoadNifti, LoadNumpy, LoadPNG +from monai.transforms.io.array import LoadImage __all__ = [ "LoadImaged", - "LoadDatad", - "LoadNiftid", - "LoadPNGd", - "LoadNumpyd", "LoadImageD", "LoadImageDict", - "LoadNiftiD", - "LoadNiftiDict", - "LoadPNGD", - "LoadPNGDict", - "LoadNumpyD", - "LoadNumpyDict", ] @@ -116,161 +106,4 @@ def __call__(self, data, reader: Optional[ImageReader] = None): return d -class LoadDatad(MapTransform): - """ - Base class for dictionary-based wrapper of IO loader transforms. - It must load image and metadata together. If loading a list of files in one key, - stack them together and add a new dimension as the first dimension, and use the - meta data of the first image to represent the stacked result. Note that the affine - transform of all the stacked images should be same. The output metadata field will - be created as ``key_{meta_key_postfix}``. - """ - - def __init__( - self, - keys: KeysCollection, - loader: Callable, - meta_key_postfix: str = "meta_dict", - overwriting: bool = False, - ) -> None: - """ - Args: - keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` - loader: callable function to load data from expected source. - typically, it's array level transform, for example: `LoadNifti`, - `LoadPNG` and `LoadNumpy`, etc. - meta_key_postfix: use `key_{postfix}` to store the metadata of the loaded data, - default is `meta_dict`. The meta data is a dictionary object. - For example, load Nifti file for `image`, store the metadata into `image_meta_dict`. - overwriting: whether allow to overwrite existing meta data of same key. - default is False, which will raise exception if encountering existing key. - - Raises: - TypeError: When ``loader`` is not ``callable``. - TypeError: When ``meta_key_postfix`` is not a ``str``. - - """ - super().__init__(keys) - if not callable(loader): - raise TypeError(f"loader must be callable but is {type(loader).__name__}.") - self.loader = loader - if not isinstance(meta_key_postfix, str): - raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.") - self.meta_key_postfix = meta_key_postfix - self.overwriting = overwriting - - def __call__(self, data): - """ - Raises: - KeyError: When not ``self.overwriting`` and key already exists in ``data``. - - """ - d = dict(data) - for key in self.keys: - data = self.loader(d[key]) - assert isinstance(data, (tuple, list)), "loader must return a tuple or list." - d[key] = data[0] - assert isinstance(data[1], dict), "metadata must be a dict." - key_to_add = f"{key}_{self.meta_key_postfix}" - if key_to_add in d and not self.overwriting: - raise KeyError(f"Meta data with key {key_to_add} already exists and overwriting=False.") - d[key_to_add] = data[1] - return d - - -class LoadNiftid(LoadDatad): - """ - Dictionary-based wrapper of :py:class:`monai.transforms.LoadNifti`, - must load image and metadata together. If loading a list of files in one key, - stack them together and add a new dimension as the first dimension, and use the - meta data of the first image to represent the stacked result. Note that the affine - transform of all the stacked images should be same. The output metadata field will - be created as ``key_{meta_key_postfix}``. - """ - - def __init__( - self, - keys: KeysCollection, - as_closest_canonical: bool = False, - dtype: Optional[np.dtype] = np.float32, - meta_key_postfix: str = "meta_dict", - overwriting: bool = False, - ) -> None: - """ - Args: - keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` - as_closest_canonical: if True, load the image as closest to canonical axis format. - dtype: if not None convert the loaded image data to this data type. - meta_key_postfix: use `key_{postfix}` to store the metadata of the nifti image, - default is `meta_dict`. The meta data is a dictionary object. - For example, load nifti file for `image`, store the metadata into `image_meta_dict`. - overwriting: whether allow to overwrite existing meta data of same key. - default is False, which will raise exception if encountering existing key. - """ - loader = LoadNifti(as_closest_canonical, False, dtype) - super().__init__(keys, loader, meta_key_postfix, overwriting) - - -class LoadPNGd(LoadDatad): - """ - Dictionary-based wrapper of :py:class:`monai.transforms.LoadPNG`. - """ - - def __init__( - self, - keys: KeysCollection, - dtype: Optional[np.dtype] = np.float32, - meta_key_postfix: str = "meta_dict", - overwriting: bool = False, - ) -> None: - """ - Args: - keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` - dtype: if not None convert the loaded image data to this data type. - meta_key_postfix: use `key_{postfix}` to store the metadata of the PNG image, - default is `meta_dict`. The meta data is a dictionary object. - For example, load PNG file for `image`, store the metadata into `image_meta_dict`. - overwriting: whether allow to overwrite existing meta data of same key. - default is False, which will raise exception if encountering existing key. - """ - loader = LoadPNG(False, dtype) - super().__init__(keys, loader, meta_key_postfix, overwriting) - - -class LoadNumpyd(LoadDatad): - """ - Dictionary-based wrapper of :py:class:`monai.transforms.LoadNumpy`. - """ - - def __init__( - self, - keys: KeysCollection, - dtype: Optional[np.dtype] = np.float32, - npz_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", - overwriting: bool = False, - ) -> None: - """ - Args: - keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` - dtype: if not None convert the loaded data to this data type. - npz_keys: if loading npz file, only load the specified keys, if None, load all the items. - stack the loaded items together to construct a new first dimension. - meta_key_postfix: use `key_{postfix}` to store the metadata of the Numpy data, - default is `meta_dict`. The meta data is a dictionary object. - For example, load Numpy file for `mask`, store the metadata into `mask_meta_dict`. - overwriting: whether allow to overwrite existing meta data of same key. - default is False, which will raise exception if encountering existing key. - """ - loader = LoadNumpy(data_only=False, dtype=dtype, npz_keys=npz_keys) - super().__init__(keys, loader, meta_key_postfix, overwriting) - - LoadImageD = LoadImageDict = LoadImaged -LoadNiftiD = LoadNiftiDict = LoadNiftid -LoadPNGD = LoadPNGDict = LoadPNGd -LoadNumpyD = LoadNumpyDict = LoadNumpyd diff --git a/tests/min_tests.py b/tests/min_tests.py index 4cca8b5af8..daf238a154 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -67,10 +67,6 @@ def run_testsuit(): "test_lmdbdataset", "test_load_image", "test_load_imaged", - "test_load_nifti", - "test_load_niftid", - "test_load_png", - "test_load_pngd", "test_load_spacing_orientation", "test_mednistdataset", "test_nifti_dataset", diff --git a/tests/test_load_nifti.py b/tests/test_load_nifti.py deleted file mode 100644 index 325dbd1f1b..0000000000 --- a/tests/test_load_nifti.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest - -import nibabel as nib -import numpy as np -from parameterized import parameterized - -from monai.transforms import LoadNifti - -TEST_CASE_1 = [{"as_closest_canonical": False, "image_only": True}, ["test_image.nii.gz"], (128, 128, 128)] - -TEST_CASE_2 = [{"as_closest_canonical": False, "image_only": False}, ["test_image.nii.gz"], (128, 128, 128)] - -TEST_CASE_3 = [ - {"as_closest_canonical": False, "image_only": True}, - ["test_image1.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], - (3, 128, 128, 128), -] - -TEST_CASE_4 = [ - {"as_closest_canonical": False, "image_only": False}, - ["test_image1.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], - (3, 128, 128, 128), -] - -TEST_CASE_5 = [{"as_closest_canonical": True, "image_only": False}, ["test_image.nii.gz"], (128, 128, 128)] - - -class TestLoadNifti(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_shape(self, input_param, filenames, expected_shape): - test_image = np.random.randint(0, 2, size=[128, 128, 128]) - with tempfile.TemporaryDirectory() as tempdir: - for i, name in enumerate(filenames): - filenames[i] = os.path.join(tempdir, name) - nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) - result = LoadNifti(**input_param)(filenames) - - if isinstance(result, tuple): - result, header = result - self.assertTrue("affine" in header) - np.testing.assert_allclose(header["affine"], np.eye(4)) - if input_param["as_closest_canonical"]: - np.testing.assert_allclose(header["original_affine"], np.eye(4)) - self.assertTupleEqual(result.shape, expected_shape) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_load_niftid.py b/tests/test_load_niftid.py deleted file mode 100644 index b29b4f221c..0000000000 --- a/tests/test_load_niftid.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest - -import nibabel as nib -import numpy as np -from parameterized import parameterized - -from monai.transforms import LoadImaged - -KEYS = ["image", "label", "extra"] - -TEST_CASE_1 = [{"keys": KEYS, "as_closest_canonical": False}, (128, 128, 128)] - - -class TestLoadNiftid(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) - def test_shape(self, input_param, expected_shape): - test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) - test_data = {} - with tempfile.TemporaryDirectory() as tempdir: - for key in KEYS: - nib.save(test_image, os.path.join(tempdir, key + ".nii.gz")) - test_data.update({key: os.path.join(tempdir, key + ".nii.gz")}) - result = LoadImaged(**input_param)(test_data) - - for key in KEYS: - self.assertTupleEqual(result[key].shape, expected_shape) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_load_numpy.py b/tests/test_load_numpy.py deleted file mode 100644 index 628ba43203..0000000000 --- a/tests/test_load_numpy.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest - -import numpy as np - -from monai.transforms import LoadNumpy - - -class TestLoadNumpy(unittest.TestCase): - def test_npy(self): - test_data = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data) - - result = LoadNumpy()(filepath) - self.assertTupleEqual(result[1]["spatial_shape"], test_data.shape) - self.assertTupleEqual(result[0].shape, test_data.shape) - np.testing.assert_allclose(result[0], test_data) - - def test_npz1(self): - test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data1) - - result = LoadNumpy()(filepath) - self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result[0].shape, test_data1.shape) - np.testing.assert_allclose(result[0], test_data1) - - def test_npz2(self): - test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) - test_data2 = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npz") - np.savez(filepath, test_data1, test_data2) - - result = LoadNumpy()(filepath) - self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result[0].shape, (2, 3, 4, 4)) - np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2])) - - def test_npz3(self): - test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) - test_data2 = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npz") - np.savez(filepath, test1=test_data1, test2=test_data2) - - result = LoadNumpy(npz_keys=["test1", "test2"])(filepath) - self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result[0].shape, (2, 3, 4, 4)) - np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2])) - - def test_npy_pickle(self): - test_data = {"test": np.random.randint(0, 256, size=[3, 4, 4])} - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data, allow_pickle=True) - - result = LoadNumpy(data_only=True, dtype=None)(filepath).item() - self.assertTupleEqual(result["test"].shape, test_data["test"].shape) - np.testing.assert_allclose(result["test"], test_data["test"]) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_load_numpyd.py b/tests/test_load_numpyd.py deleted file mode 100644 index f2179d7388..0000000000 --- a/tests/test_load_numpyd.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest - -import numpy as np - -from monai.transforms import LoadNumpyd - - -class TestLoadNumpyd(unittest.TestCase): - def test_npy(self): - test_data = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data) - - result = LoadNumpyd(keys="mask")({"mask": filepath}) - self.assertTupleEqual(result["mask_meta_dict"]["spatial_shape"], test_data.shape) - self.assertTupleEqual(result["mask"].shape, test_data.shape) - np.testing.assert_allclose(result["mask"], test_data) - - def test_npz1(self): - test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data1) - - result = LoadNumpyd(keys="mask")({"mask": filepath}) - self.assertTupleEqual(result["mask_meta_dict"]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result["mask"].shape, test_data1.shape) - np.testing.assert_allclose(result["mask"], test_data1) - - def test_npz2(self): - test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) - test_data2 = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npz") - np.savez(filepath, test_data1, test_data2) - - result = LoadNumpyd(keys="mask")({"mask": filepath}) - self.assertTupleEqual(result["mask_meta_dict"]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result["mask"].shape, (2, 3, 4, 4)) - np.testing.assert_allclose(result["mask"], np.stack([test_data1, test_data2])) - - def test_npz3(self): - test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) - test_data2 = np.random.randint(0, 256, size=[3, 4, 4]) - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npz") - np.savez(filepath, test1=test_data1, test2=test_data2) - - result = LoadNumpyd(keys="mask", npz_keys=["test1", "test2"])({"mask": filepath}) - self.assertTupleEqual(result["mask_meta_dict"]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result["mask"].shape, (2, 3, 4, 4)) - np.testing.assert_allclose(result["mask"], np.stack([test_data1, test_data2])) - - def test_npy_pickle(self): - test_data = {"test": np.random.randint(0, 256, size=[3, 4, 4])} - with tempfile.TemporaryDirectory() as tempdir: - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data, allow_pickle=True) - - result = LoadNumpyd(keys="mask", dtype=None)({"mask": filepath})["mask"].item() - self.assertTupleEqual(result["test"].shape, test_data["test"].shape) - np.testing.assert_allclose(result["test"], test_data["test"]) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_load_png.py b/tests/test_load_png.py deleted file mode 100644 index 2e3f60f4cd..0000000000 --- a/tests/test_load_png.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest - -import numpy as np -from parameterized import parameterized -from PIL import Image - -from monai.transforms import LoadPNG - -TEST_CASE_1 = [(128, 128), ["test_image.png"], (128, 128), (128, 128)] - -TEST_CASE_2 = [(128, 128, 3), ["test_image.png"], (128, 128, 3), (128, 128)] - -TEST_CASE_3 = [(128, 128), ["test_image1.png", "test_image2.png", "test_image3.png"], (3, 128, 128), (128, 128)] - - -class TestLoadPNG(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, data_shape, filenames, expected_shape, meta_shape): - test_image = np.random.randint(0, 256, size=data_shape) - with tempfile.TemporaryDirectory() as tempdir: - for i, name in enumerate(filenames): - filenames[i] = os.path.join(tempdir, name) - Image.fromarray(test_image.astype("uint8")).save(filenames[i]) - result = LoadPNG()(filenames) - self.assertTupleEqual(result[1]["spatial_shape"], meta_shape) - self.assertTupleEqual(result[0].shape, expected_shape) - if result[0].shape == test_image.shape: - np.testing.assert_allclose(result[0], test_image) - else: - np.testing.assert_allclose(result[0], np.tile(test_image, [result[0].shape[0], 1, 1])) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_load_pngd.py b/tests/test_load_pngd.py deleted file mode 100644 index 1b84dab983..0000000000 --- a/tests/test_load_pngd.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest - -import numpy as np -from parameterized import parameterized -from PIL import Image - -from monai.transforms import LoadPNGd - -KEYS = ["image", "label", "extra"] - -TEST_CASE_1 = [{"keys": KEYS}, (128, 128, 3)] - - -class TestLoadPNGd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) - def test_shape(self, input_param, expected_shape): - test_image = np.random.randint(0, 256, size=[128, 128, 3]) - with tempfile.TemporaryDirectory() as tempdir: - test_data = {} - for key in KEYS: - Image.fromarray(test_image.astype("uint8")).save(os.path.join(tempdir, key + ".png")) - test_data.update({key: os.path.join(tempdir, key + ".png")}) - result = LoadPNGd(**input_param)(test_data) - for key in KEYS: - self.assertTupleEqual(result[key].shape, expected_shape) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py index cf27f81f5a..7bfa10c6c5 100644 --- a/tests/test_nifti_rw.py +++ b/tests/test_nifti_rw.py @@ -18,7 +18,7 @@ from parameterized import parameterized from monai.data import write_nifti -from monai.transforms import LoadNifti, Orientation, Spacing +from monai.transforms import LoadImage, Orientation, Spacing from tests.utils import make_nifti_image TEST_IMAGE = np.arange(24).reshape((2, 4, 3)) @@ -27,11 +27,16 @@ ) TEST_CASES = [ - [TEST_IMAGE, TEST_AFFINE, dict(as_closest_canonical=True, image_only=False), np.arange(24).reshape((2, 4, 3))], [ TEST_IMAGE, TEST_AFFINE, - dict(as_closest_canonical=True, image_only=True), + dict(reader="NibabelReader", image_only=False, as_closest_canonical=True), + np.arange(24).reshape((2, 4, 3)), + ], + [ + TEST_IMAGE, + TEST_AFFINE, + dict(reader="NibabelReader", image_only=True, as_closest_canonical=True), np.array( [ [[12.0, 15.0, 18.0, 21.0], [13.0, 16.0, 19.0, 22.0], [14.0, 17.0, 20.0, 23.0]], @@ -39,9 +44,24 @@ ] ), ], - [TEST_IMAGE, TEST_AFFINE, dict(as_closest_canonical=False, image_only=True), np.arange(24).reshape((2, 4, 3))], - [TEST_IMAGE, TEST_AFFINE, dict(as_closest_canonical=False, image_only=False), np.arange(24).reshape((2, 4, 3))], - [TEST_IMAGE, None, dict(as_closest_canonical=False, image_only=False), np.arange(24).reshape((2, 4, 3))], + [ + TEST_IMAGE, + TEST_AFFINE, + dict(reader="NibabelReader", image_only=True, as_closest_canonical=False), + np.arange(24).reshape((2, 4, 3)), + ], + [ + TEST_IMAGE, + TEST_AFFINE, + dict(reader="NibabelReader", image_only=False, as_closest_canonical=False), + np.arange(24).reshape((2, 4, 3)), + ], + [ + TEST_IMAGE, + None, + dict(reader="NibabelReader", image_only=False, as_closest_canonical=False), + np.arange(24).reshape((2, 4, 3)), + ], ] @@ -51,7 +71,7 @@ def test_orientation(self, array, affine, reader_param, expected): test_image = make_nifti_image(array, affine) # read test cases - loader = LoadNifti(**reader_param) + loader = LoadImage(**reader_param) load_result = loader(test_image) if isinstance(load_result, tuple): data_array, header = load_result @@ -79,7 +99,7 @@ def test_orientation(self, array, affine, reader_param, expected): def test_consistency(self): np.set_printoptions(suppress=True, precision=3) test_image = make_nifti_image(np.arange(64).reshape(1, 8, 8), np.diag([1.5, 1.5, 1.5, 1])) - data, header = LoadNifti(as_closest_canonical=False)(test_image) + data, header = LoadImage(reader="NibabelReader", as_closest_canonical=False)(test_image) data, original_affine, new_affine = Spacing([0.8, 0.8, 0.8])(data[None], header["affine"], mode="nearest") data, _, new_affine = Orientation("ILP")(data, new_affine) if os.path.exists(test_image):