Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 69 additions & 35 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union

import numpy as np
from torch.utils.data._utils.collate import np_str_obj_array_pattern
Expand All @@ -38,14 +38,17 @@
from monai.utils import MetaKeys, SpaceKeys, TraceKeys, ensure_tuple, optional_import, require_pkg

if TYPE_CHECKING:
import cupy as cp
import itk
import nibabel as nib
import nrrd
import pydicom
from nibabel.nifti1 import Nifti1Image
from PIL import Image as PILImage

has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True
has_nrrd = has_itk = has_nib = has_pil = has_pydicom = has_cp = True
Ndarray = Union[np.ndarray, cp.ndarray]

else:
itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
nib, has_nib = optional_import("nibabel")
Expand All @@ -54,7 +57,12 @@
pydicom, has_pydicom = optional_import("pydicom")
nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True)

cp, has_cp = optional_import("cupy")
cp, has_cp = optional_import("cupy")
if has_cp:
Ndarray = Union[np.ndarray, cp.ndarray]
else:
Ndarray = np.ndarray

kvikio, has_kvikio = optional_import("kvikio")

__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]
Expand Down Expand Up @@ -107,15 +115,18 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] |
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

@abstractmethod
def get_data(self, img) -> tuple[np.ndarray, dict]:
def get_data(self, img) -> tuple[Ndarray, dict]:
"""
Extract data array and metadata from loaded image and return them.
This function must return two objects, the first is a numpy array of image data,
This function must return two objects, the first is a NumPy or CuPy array of image data,
the second is a dictionary of metadata.

Args:
img: an image object loaded from an image file or a list of image objects.

Returns:
A tuple of (image_array, metadata_dict).

"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

Expand Down Expand Up @@ -143,7 +154,18 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict):
)


def _stack_images(image_list: list, meta_dict: dict, to_cupy: bool = False):
def _stack_images(image_list: list[Ndarray], meta_dict: dict, to_cupy: bool = False) -> Ndarray:
"""
Stack image arrays and update channel metadata.

Args:
image_list: List of image arrays to stack.
meta_dict: Metadata dict to update.
to_cupy: If True, stack using CuPy.

Returns:
The stacked image array.
"""
if len(image_list) <= 1:
return image_list[0]
if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
Expand Down Expand Up @@ -269,19 +291,22 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
img_.append(itk.imread(name, **kwargs_))
return img_ if len(filenames) > 1 else img_[0]

def get_data(self, img) -> tuple[np.ndarray, dict]:
def get_data(self, img) -> tuple[Ndarray, dict]:
"""
Extract data array and metadata from loaded image and return them.
This function returns two objects, first is numpy array of image data, second is dict of metadata.
This function returns two objects, first is numpy or cupy array of image data, second is dict of metadata.
It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict.
When loading a list of files, they are stacked together at a new dimension as the first dimension,
and the metadata of the first image is used to represent the output metadata.

Args:
img: an ITK image object loaded from an image file or a list of ITK image objects.

Returns:
A tuple of (image_array, metadata_dict).

"""
img_array: list[np.ndarray] = []
img_array: list[Ndarray] = []
compatible_meta: dict = {}

for i in ensure_tuple(img):
Expand Down Expand Up @@ -616,10 +641,10 @@ def _combine_dicom_series(self, data: Iterable, filenames: Sequence[PathLike]):

return stack_array, stack_metadata

def get_data(self, data) -> tuple[np.ndarray, dict]:
def get_data(self, data) -> tuple[Ndarray, dict]:
"""
Extract data array and metadata from loaded image and return them.
This function returns two objects, first is numpy array of image data, second is dict of metadata.
This function returns two objects, first is numpy or cupy array of image data, second is dict of metadata.
It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict.
For dicom series within the input, all slices will be stacked first.
When loading a list of files (dicom file, or stacked dicom series), they are stacked together at a new
Expand All @@ -637,6 +662,9 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
data: a pydicom dataset object, or a list of pydicom dataset objects, or a list of list of
pydicom dataset objects.

Returns:
A tuple of (image_array, metadata_dict).

"""

dicom_data = []
Expand All @@ -663,10 +691,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape
dicom_data.append((data_array, metadata))

# TODO: the actual type is list[np.ndarray | cp.ndarray]
# should figure out how to define correct types without having cupy not found error
# https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
img_array: list[np.ndarray] = []
img_array: list[Ndarray] = []
compatible_meta: dict = {}

for data_array, metadata in ensure_tuple(dicom_data):
Expand Down Expand Up @@ -841,7 +866,7 @@ def _get_seg_data(self, img, filename):
if self.label_dict is not None:
metadata["labels"] = self.label_dict
if self.to_gpu:
all_segs = cp.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype)
all_segs: Ndarray = cp.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype)
else:
all_segs = np.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype)
else:
Expand Down Expand Up @@ -899,7 +924,7 @@ def _get_seg_data(self, img, filename):

return all_segs, metadata

def _get_array_data_from_gpu(self, img, filename):
def _get_array_data_from_gpu(self, img, filename) -> Ndarray:
"""
Get the raw array data of the image. This function is used when `to_gpu` is set to True.

Expand Down Expand Up @@ -954,7 +979,7 @@ def _get_array_data_from_gpu(self, img, filename):

return data

def _get_array_data(self, img, filename):
def _get_array_data(self, img, filename) -> Ndarray:
"""
Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data
will be rescaled. The output data has the dtype float32 if the rescaling is applied.
Expand Down Expand Up @@ -1092,22 +1117,22 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
img_.append(img) # type: ignore
return img_ if len(filenames) > 1 else img_[0]

def get_data(self, img) -> tuple[np.ndarray, dict]:
def get_data(self, img) -> tuple[Ndarray, dict]:
"""
Extract data array and metadata from loaded image and return them.
This function returns two objects, first is numpy array of image data, second is dict of metadata.
This function returns two objects, first is numpy or cupy array of image data, second is dict of metadata.
It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict.
When loading a list of files, they are stacked together at a new dimension as the first dimension,
and the metadata of the first image is used to represent the output metadata.

Args:
img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.

Returns:
A tuple of (image_array, metadata_dict).

"""
# TODO: the actual type is list[np.ndarray | cp.ndarray]
# should figure out how to define correct types without having cupy not found error
# https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
img_array: list[np.ndarray] = []
img_array: list[Ndarray] = []
compatible_meta: dict = {}

for i, filename in zip(ensure_tuple(img), self.filenames):
Expand Down Expand Up @@ -1186,7 +1211,7 @@ def _get_spatial_shape(self, img):
spatial_rank = max(min(ndim, 3), 1)
return np.asarray(size[:spatial_rank])

def _get_array_data(self, img, filename):
def _get_array_data(self, img, filename) -> Ndarray:
"""
Get the raw array data of the image, converted to a NumPy array.

Expand Down Expand Up @@ -1281,19 +1306,22 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):

return img_ if len(img_) > 1 else img_[0]

def get_data(self, img) -> tuple[np.ndarray, dict]:
def get_data(self, img) -> tuple[Ndarray, dict]:
"""
Extract data array and metadata from loaded image and return them.
This function returns two objects, first is numpy array of image data, second is dict of metadata.
This function returns two objects, first is numpy or cupy array of image data, second is dict of metadata.
It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict.
When loading a list of files, they are stacked together at a new dimension as the first dimension,
and the metadata of the first image is used to represent the output metadata.

Args:
img: a NumPy array loaded from a file or a list of NumPy arrays.

Returns:
A tuple of (image_array, metadata_dict).

"""
img_array: list[np.ndarray] = []
img_array: list[Ndarray] = []
compatible_meta: dict = {}
if isinstance(img, np.ndarray):
img = (img,)
Expand Down Expand Up @@ -1374,10 +1402,10 @@ def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs):

return img_ if len(filenames) > 1 else img_[0]

def get_data(self, img) -> tuple[np.ndarray, dict]:
def get_data(self, img) -> tuple[Ndarray, dict]:
"""
Extract data array and metadata from loaded image and return them.
This function returns two objects, first is numpy array of image data, second is dict of metadata.
This function returns two objects, first is numpy or cupy array of image data, second is dict of metadata.
It computes `spatial_shape` and stores it in meta dict.
When loading a list of files, they are stacked together at a new dimension as the first dimension,
and the metadata of the first image is used to represent the output metadata.
Expand All @@ -1387,8 +1415,11 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
Args:
img: a PIL Image object loaded from a file or a list of PIL Image objects.

Returns:
A tuple of (image_array, metadata_dict).

"""
img_array: list[np.ndarray] = []
img_array: list[Ndarray] = []
compatible_meta: dict = {}

for i in ensure_tuple(img):
Expand Down Expand Up @@ -1425,7 +1456,7 @@ def _get_spatial_shape(self, img):
class NrrdImage:
"""Class to wrap nrrd image array and metadata header"""

array: np.ndarray
array: Ndarray
header: dict


Expand Down Expand Up @@ -1495,17 +1526,20 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] |
img_.append(nrrd_image)
return img_ if len(filenames) > 1 else img_[0]

def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]:
def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[Ndarray, dict]:
"""
Extract data array and metadata from loaded image and return them.
This function must return two objects, the first is a numpy array of image data,
This function must return two objects, the first is a NumPy or CuPy array of image data,
the second is a dictionary of metadata.

Args:
img: a `NrrdImage` loaded from an image file or a list of image objects.

Returns:
A tuple of (image_array, metadata_dict).

"""
img_array: list[np.ndarray] = []
img_array: list[Ndarray] = []
compatible_meta: dict = {}

for i in ensure_tuple(img):
Expand Down
Loading