From 8a07c9e95e2a6621e675a9fda8ebdc5a994cd1cb Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 16 Jan 2026 14:54:59 +0800 Subject: [PATCH] Improve CuPy type hinting in ImageReader Signed-off-by: ytl0623 --- monai/data/image_reader.py | 104 ++++++++++++++++++++++++------------- 1 file changed, 69 insertions(+), 35 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 515bf38a39..c22da719f9 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -22,7 +22,7 @@ from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Union import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern @@ -38,6 +38,7 @@ from monai.utils import MetaKeys, SpaceKeys, TraceKeys, ensure_tuple, optional_import, require_pkg if TYPE_CHECKING: + import cupy as cp import itk import nibabel as nib import nrrd @@ -45,7 +46,9 @@ from nibabel.nifti1 import Nifti1Image from PIL import Image as PILImage - has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True + has_nrrd = has_itk = has_nib = has_pil = has_pydicom = has_cp = True + Ndarray = Union[np.ndarray, cp.ndarray] + else: itk, has_itk = optional_import("itk", allow_namespace_pkg=True) nib, has_nib = optional_import("nibabel") @@ -54,7 +57,12 @@ pydicom, has_pydicom = optional_import("pydicom") nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) -cp, has_cp = optional_import("cupy") + cp, has_cp = optional_import("cupy") + if has_cp: + Ndarray = Union[np.ndarray, cp.ndarray] + else: + Ndarray = np.ndarray + kvikio, has_kvikio = optional_import("kvikio") __all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] @@ -107,15 +115,18 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @abstractmethod - def get_data(self, img) -> tuple[np.ndarray, dict]: + def get_data(self, img) -> tuple[Ndarray, dict]: """ Extract data array and metadata from loaded image and return them. - This function must return two objects, the first is a numpy array of image data, + This function must return two objects, the first is a NumPy or CuPy array of image data, the second is a dictionary of metadata. Args: img: an image object loaded from an image file or a list of image objects. + Returns: + A tuple of (image_array, metadata_dict). + """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @@ -143,7 +154,18 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict): ) -def _stack_images(image_list: list, meta_dict: dict, to_cupy: bool = False): +def _stack_images(image_list: list[Ndarray], meta_dict: dict, to_cupy: bool = False) -> Ndarray: + """ + Stack image arrays and update channel metadata. + + Args: + image_list: List of image arrays to stack. + meta_dict: Metadata dict to update. + to_cupy: If True, stack using CuPy. + + Returns: + The stacked image array. + """ if len(image_list) <= 1: return image_list[0] if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): @@ -269,10 +291,10 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_.append(itk.imread(name, **kwargs_)) return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img) -> tuple[np.ndarray, dict]: + def get_data(self, img) -> tuple[Ndarray, dict]: """ Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. + This function returns two objects, first is numpy or cupy array of image data, second is dict of metadata. It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. When loading a list of files, they are stacked together at a new dimension as the first dimension, and the metadata of the first image is used to represent the output metadata. @@ -280,8 +302,11 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: Args: img: an ITK image object loaded from an image file or a list of ITK image objects. + Returns: + A tuple of (image_array, metadata_dict). + """ - img_array: list[np.ndarray] = [] + img_array: list[Ndarray] = [] compatible_meta: dict = {} for i in ensure_tuple(img): @@ -616,10 +641,10 @@ def _combine_dicom_series(self, data: Iterable, filenames: Sequence[PathLike]): return stack_array, stack_metadata - def get_data(self, data) -> tuple[np.ndarray, dict]: + def get_data(self, data) -> tuple[Ndarray, dict]: """ Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. + This function returns two objects, first is numpy or cupy array of image data, second is dict of metadata. It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. For dicom series within the input, all slices will be stacked first, When loading a list of files (dicom file, or stacked dicom series), they are stacked together at a new @@ -637,6 +662,9 @@ def get_data(self, data) -> tuple[np.ndarray, dict]: data: a pydicom dataset object, or a list of pydicom dataset objects, or a list of list of pydicom dataset objects. + Returns: + A tuple of (image_array, metadata_dict). + """ dicom_data = [] @@ -663,10 +691,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]: metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape dicom_data.append((data_array, metadata)) - # TODO: the actual type is list[np.ndarray | cp.ndarray] - # should figure out how to define correct types without having cupy not found error - # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918 - img_array: list[np.ndarray] = [] + img_array: list[Ndarray] = [] compatible_meta: dict = {} for data_array, metadata in ensure_tuple(dicom_data): @@ -841,7 +866,7 @@ def _get_seg_data(self, img, filename): if self.label_dict is not None: metadata["labels"] = self.label_dict if self.to_gpu: - all_segs = cp.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype) + all_segs: Ndarray = cp.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype) else: all_segs = np.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype) else: @@ -899,7 +924,7 @@ def _get_seg_data(self, img, filename): return all_segs, metadata - def _get_array_data_from_gpu(self, img, filename): + def _get_array_data_from_gpu(self, img, filename) -> Ndarray: """ Get the raw array data of the image. This function is used when `to_gpu` is set to True. @@ -954,7 +979,7 @@ def _get_array_data_from_gpu(self, img, filename): return data - def _get_array_data(self, img, filename): + def _get_array_data(self, img, filename) -> Ndarray: """ Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data will be rescaled. The output data has the dtype float32 if the rescaling is applied. @@ -1092,10 +1117,10 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_.append(img) # type: ignore return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img) -> tuple[np.ndarray, dict]: + def get_data(self, img) -> tuple[Ndarray, dict]: """ Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. + This function returns two objects, first is numpy or cupy array of image data, second is dict of metadata. It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. When loading a list of files, they are stacked together at a new dimension as the first dimension, and the metadata of the first image is used to present the output metadata. @@ -1103,11 +1128,11 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: Args: img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. + Returns: + A tuple of (image_array, metadata_dict). + """ - # TODO: the actual type is list[np.ndarray | cp.ndarray] - # should figure out how to define correct types without having cupy not found error - # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918 - img_array: list[np.ndarray] = [] + img_array: list[Ndarray] = [] compatible_meta: dict = {} for i, filename in zip(ensure_tuple(img), self.filenames): @@ -1186,7 +1211,7 @@ def _get_spatial_shape(self, img): spatial_rank = max(min(ndim, 3), 1) return np.asarray(size[:spatial_rank]) - def _get_array_data(self, img, filename): + def _get_array_data(self, img, filename) -> Ndarray: """ Get the raw array data of the image, converted to Numpy array. @@ -1281,10 +1306,10 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): return img_ if len(img_) > 1 else img_[0] - def get_data(self, img) -> tuple[np.ndarray, dict]: + def get_data(self, img) -> tuple[Ndarray, dict]: """ Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. + This function returns two objects, first is numpy or cupy array of image data, second is dict of metadata. It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. When loading a list of files, they are stacked together at a new dimension as the first dimension, and the metadata of the first image is used to represent the output metadata. @@ -1292,8 +1317,11 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: Args: img: a Numpy array loaded from a file or a list of Numpy arrays. + Returns: + A tuple of (image_array, metadata_dict). + """ - img_array: list[np.ndarray] = [] + img_array: list[Ndarray] = [] compatible_meta: dict = {} if isinstance(img, np.ndarray): img = (img,) @@ -1374,10 +1402,10 @@ def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs): return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img) -> tuple[np.ndarray, dict]: + def get_data(self, img) -> tuple[Ndarray, dict]: """ Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. + This function returns two objects, first is numpy or cupy array of image data, second is dict of metadata. It computes `spatial_shape` and stores it in meta dict. When loading a list of files, they are stacked together at a new dimension as the first dimension, and the metadata of the first image is used to represent the output metadata. @@ -1387,8 +1415,11 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: Args: img: a PIL Image object loaded from a file or a list of PIL Image objects. + Returns: + A tuple of (image_array, metadata_dict). + """ - img_array: list[np.ndarray] = [] + img_array: list[Ndarray] = [] compatible_meta: dict = {} for i in ensure_tuple(img): @@ -1425,7 +1456,7 @@ def _get_spatial_shape(self, img): class NrrdImage: """Class to wrap nrrd image array and metadata header""" - array: np.ndarray + array: Ndarray header: dict @@ -1495,17 +1526,20 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | img_.append(nrrd_image) return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: + def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[Ndarray, dict]: """ Extract data array and metadata from loaded image and return them. - This function must return two objects, the first is a numpy array of image data, + This function must return two objects, the first is a NumPy or CuPy array of image data, the second is a dictionary of metadata. Args: img: a `NrrdImage` loaded from an image file or a list of image objects. + Returns: + A tuple of (image_array, metadata_dict). + """ - img_array: list[np.ndarray] = [] + img_array: list[Ndarray] = [] compatible_meta: dict = {} for i in ensure_tuple(img):