diff --git a/docs/requirements.txt b/docs/requirements.txt index db6ea92bd8..9d9cdebb3e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -2,7 +2,7 @@ torch>=1.5 pytorch-ignite==0.4.5 numpy>=1.17 -itk>=5.0, <=5.1.2 +itk>=5.2 nibabel parameterized scikit-image>=0.14.2 diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index 27bd135471..273431fc72 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -21,14 +21,6 @@ import monai from monai.utils.module import OptionalImportError, get_package_version, optional_import -try: - import itk # type: ignore - - itk_version = itk.Version.GetITKVersion() - del itk -except (ImportError, AttributeError): - itk_version = "NOT INSTALLED or UNKNOWN VERSION." - try: _, HAS_EXT = optional_import("monai._C") USE_COMPILED = HAS_EXT and os.getenv("BUILD_MONAI", "0") == "1" @@ -76,7 +68,6 @@ def get_optional_config_values(): output["Tensorboard"] = get_package_version("tensorboard") output["gdown"] = get_package_version("gdown") output["TorchVision"] = get_package_version("torchvision") - output["ITK"] = itk_version output["tqdm"] = get_package_version("tqdm") output["lmdb"] = get_package_version("lmdb") output["psutil"] = psutil_version diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index cfe8f29f04..11ed768eb7 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -29,14 +29,12 @@ import itk # type: ignore import nibabel as nib import openslide - from itk import Image # type: ignore from nibabel.nifti1 import Nifti1Image from PIL import Image as PILImage has_itk = has_nib = has_pil = has_cim = has_osl = True else: itk, has_itk = optional_import("itk", allow_namespace_pkg=True) - Image, _ = optional_import("itk", allow_namespace_pkg=True, name="Image") nib, has_nib = optional_import("nibabel") Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image") PILImage, has_pil = optional_import("PIL.Image") @@ -85,7 +83,7 @@ 
def get_data(self, img) -> Tuple[np.ndarray, Dict]: This function must return 2 objects, first is numpy array of image data, second is dict of meta data. Args: - img: an image object loaded from a image file or a list of image objects. + img: an image object loaded from an image file or a list of image objects. """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @@ -132,19 +130,23 @@ class ITKReader(ImageReader): array index order will be `CDWH`. Args: + channel_dim: the channel dimension of the input image, default is None. + This is used to set `original_channel_dim` in the meta data, `EnsureChannelFirstD` reads this field. + If None, `original_channel_dim` will be either `no_channel` or `-1`. + - Nifti file is usually "channel last", so there is no need to specify this argument. + - PNG file usually has `GetNumberOfComponentsPerPixel()==3`, so there is no need to specify this argument. + series_name: the name of the DICOM series if there are multiple ones. + used when loading DICOM series. kwargs: additional args for `itk.imread` API. more details about available args: https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itkExtras.py """ - def __init__(self, **kwargs): + def __init__(self, channel_dim: Optional[int] = None, series_name: str = "", **kwargs): super().__init__() self.kwargs = kwargs - if has_itk and int(itk.Version.GetITKMajorVersion()) == 5 and int(itk.Version.GetITKMinorVersion()) < 2: - # warning the ITK LazyLoading mechanism was not threadsafe until version 5.2.0, - # requesting access to the itk.imread function triggers the lazy loading of the relevant itk modules - # before the parallel use of the function. 
- _ = itk.imread + self.channel_dim = channel_dim + self.series_name = series_name def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: """ @@ -169,26 +171,26 @@ def read(self, data: Union[Sequence[str], str], **kwargs): https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itkExtras.py """ - img_: List[Image] = [] + img_ = [] filenames: Sequence[str] = ensure_tuple(data) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: if os.path.isdir(name): - # read DICOM series of 1 image in a folder, refer to: https://github.com/RSIP-Vision/medio + # read DICOM series + # https://itk.org/ITKExamples/src/IO/GDCM/ReadDICOMSeriesAndWrite3DImage names_generator = itk.GDCMSeriesFileNames.New() names_generator.SetUseSeriesDetails(True) names_generator.AddSeriesRestriction("0008|0021") # Series Date names_generator.SetDirectory(name) series_uid = names_generator.GetSeriesUIDs() - if len(series_uid) == 0: + if len(series_uid) < 1: raise FileNotFoundError(f"no DICOMs in: {name}.") if len(series_uid) > 1: - raise OSError(f"the directory: {name} contains more than one DICOM series.") - - series_identifier = series_uid[0] + warnings.warn(f"the directory: {name} contains more than one DICOM series.") + series_identifier = series_uid[0] if not self.series_name else self.series_name name = names_generator.GetFileNames(series_identifier) img_.append(itk.imread(name, **kwargs_)) @@ -197,26 +199,29 @@ def read(self, data: Union[Sequence[str], str], **kwargs): def get_data(self, img): """ Extract data array and meta data from loaded image and return them. - This function returns 2 objects, first is numpy array of image data, second is dict of meta data. - It constructs `affine`, `original_affine`, and `spatial_shape` and stores in meta dict. - If loading a list of files, stack them together and add a new dimension as first dimension, - and use the meta data of the first image to represent the stacked result. 
+ This function returns two objects, first is numpy array of image data, second is dict of meta data. + It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. + When loading a list of files, they are concatenated together at a new dimension as the first dimension, + and the meta data of the first image is used to represent the output meta data. Args: - img: a ITK image object loaded from a image file or a list of ITK image objects. + img: an ITK image object loaded from an image file or a list of ITK image objects. """ img_array: List[np.ndarray] = [] compatible_meta: Dict = {} for i in ensure_tuple(img): + data = self._get_array_data(i) + img_array.append(data) header = self._get_meta_dict(i) header["original_affine"] = self._get_affine(i) header["affine"] = header["original_affine"].copy() header["spatial_shape"] = self._get_spatial_shape(i) - data = self._get_array_data(i) - img_array.append(data) - header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 + if self.channel_dim is None: # default to "no_channel" or -1 + header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 + else: + header["original_channel_dim"] = self.channel_dim _copy_compatible_dict(header, compatible_meta) return _stack_images(img_array, compatible_meta), compatible_meta @@ -226,40 +231,22 @@ def _get_meta_dict(self, img) -> Dict: Get all the meta data of the image and convert to dict type. Args: - img: a ITK image object loaded from a image file. + img: an ITK image object loaded from an image file. 
""" img_meta_dict = img.GetMetaDataDictionary() - meta_dict = {} - for key in img_meta_dict.GetKeys(): - # ignore deprecated, legacy members that cause issues - if key.startswith("ITK_original_"): - continue - if ( - key == "NRRD_measurement frame" - and int(itk.Version.GetITKMajorVersion()) == 5 - and int(itk.Version.GetITKMinorVersion()) < 2 - ): - warnings.warn( - "Ignoring 'measurement frame' field. " - "Correct reading of NRRD05 files requires ITK >= 5.2: `pip install --upgrade --pre itk`" - ) - continue - meta_dict[key] = img_meta_dict[key] - meta_dict["origin"] = np.asarray(img.GetOrigin()) + meta_dict = {key: img_meta_dict[key] for key in img_meta_dict.GetKeys() if not key.startswith("ITK_")} + meta_dict["spacing"] = np.asarray(img.GetSpacing()) - meta_dict["direction"] = itk.array_from_matrix(img.GetDirection()) return meta_dict def _get_affine(self, img): """ Get or construct the affine matrix of the image, it can be used to correct spacing, orientation or execute spatial transforms. - Construct Affine matrix based on direction, spacing, origin information. - Refer to: https://github.com/RSIP-Vision/medio Args: - img: a ITK image object loaded from a image file. + img: an ITK image object loaded from an image file. 
""" direction = itk.array_from_matrix(img.GetDirection()) @@ -267,23 +254,32 @@ def _get_affine(self, img): origin = np.asarray(img.GetOrigin()) direction = np.asarray(direction) - affine: np.ndarray = np.eye(direction.shape[0] + 1) - affine[(slice(-1), slice(-1))] = direction @ np.diag(spacing) - affine[(slice(-1), -1)] = origin + sr = min(max(direction.shape[0], 1), 3) + affine: np.ndarray = np.eye(sr + 1) + affine[:sr, :sr] = direction[:sr, :sr] @ np.diag(spacing[:sr]) + affine[:sr, -1] = origin[:sr] + flip_diag = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]][sr - 1] # itk to nibabel affine + affine = np.diag(flip_diag) @ affine return affine def _get_spatial_shape(self, img): """ - Get the spatial shape of image data, it doesn't contain the channel dim. + Get the spatial shape of `img`. Args: - img: a ITK image object loaded from a image file. + img: an ITK image object loaded from an image file. """ - # the img data should have no channel dim or the last dim is channel - shape = list(itk.size(img)) - shape.reverse() - return np.asarray(shape) + # the img data should have no channel dim + + sr = itk.array_from_matrix(img.GetDirection()).shape[0] + sr = max(min(sr, 3), 1) + _size = list(itk.size(img)) + if self.channel_dim is not None: + # channel_dim is given in the numpy convention, which is different from ITK + # size is reversed + _size.pop(-self.channel_dim) + return np.asarray(_size[:sr]) def _get_array_data(self, img): """ @@ -294,21 +290,16 @@ def _get_array_data(self, img): The first axis of the returned array is the channel axis. Args: - img: a ITK image object loaded from a image file. + img: an ITK image object loaded from an image file. """ channels = img.GetNumberOfComponentsPerPixel() + np_data = itk.array_view_from_image(img).T if channels == 1: - return itk.array_view_from_image(img, keep_axes=False) - # The memory layout of itk.Image has all pixel's channels adjacent - # in memory, i.e. R1G1B1R2G2B2R3G3B3. 
For PyTorch/MONAI, we need - # channels to be contiguous, i.e. R1R2R3G1G2G3B1B2B3. - arr = itk.array_view_from_image(img, keep_axes=False) - dest = list(range(img.ndim)) - source = dest.copy() - end = source.pop() - source.insert(0, end) - return np.moveaxis(arr, source, dest) + return np_data + if channels != np_data.shape[0]: + warnings.warn("itk_img.GetNumberOfComponentsPerPixel != numpy data channels") + return np.moveaxis(np_data, 0, -1) # channel last is compatible with `write_nifti` class NibabelReader(ImageReader): @@ -366,13 +357,13 @@ def read(self, data: Union[Sequence[str], str], **kwargs): def get_data(self, img): """ Extract data array and meta data from loaded image and return them. - This function returns 2 objects, first is numpy array of image data, second is dict of meta data. - It constructs `affine`, `original_affine`, and `spatial_shape` and stores in meta dict. - If loading a list of files, stack them together and add a new dimension as first dimension, - and use the meta data of the first image to represent the stacked result. + This function returns two objects, first is numpy array of image data, second is dict of meta data. + It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. + When loading a list of files, they are concatenated together at a new dimension as the first dimension, + and the meta data of the first image is used to represent the output meta data. Args: - img: a Nibabel image object loaded from a image file or a list of Nibabel image objects. + img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. """ img_array: List[np.ndarray] = [] @@ -399,7 +390,7 @@ def _get_meta_dict(self, img) -> Dict: Get the all the meta data of the image and convert to dict type. Args: - img: a Nibabel image object loaded from a image file. + img: a Nibabel image object loaded from an image file. 
""" # swap to little endian as PyTorch doesn't support big endian @@ -412,7 +403,7 @@ def _get_affine(self, img): spacing, orientation or execute spatial transforms. Args: - img: a Nibabel image object loaded from a image file. + img: a Nibabel image object loaded from an image file. """ return np.array(img.affine, copy=True) @@ -422,7 +413,7 @@ def _get_spatial_shape(self, img): Get the spatial shape of image data, it doesn't contain the channel dim. Args: - img: a Nibabel image object loaded from a image file. + img: a Nibabel image object loaded from an image file. """ # swap to little endian as PyTorch doesn't support big endian @@ -437,7 +428,7 @@ def _get_array_data(self, img): Get the raw array data of the image, converted to Numpy array. Args: - img: a Nibabel image object loaded from a image file. + img: a Nibabel image object loaded from an image file. """ _array = np.array(img.get_fdata(dtype=self.dtype)) @@ -508,11 +499,11 @@ def read(self, data: Union[Sequence[str], str], **kwargs): def get_data(self, img): """ - Extract data array and meta data from loaded data and return them. - This function returns 2 objects, first is numpy array of image data, second is dict of meta data. - It constructs `spatial_shape=data.shape` and stores in meta dict if the data is numpy array. - If loading a list of files, stack them together and add a new dimension as first dimension, - and use the meta data of the first image to represent the stacked result. + Extract data array and meta data from loaded image and return them. + This function returns two objects, first is numpy array of image data, second is dict of meta data. + It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. + When loading a list of files, they are concatenated together at a new dimension as the first dimension, + and the meta data of the first image is used to represent the output meta data. Args: img: a Numpy array loaded from a file or a list of Numpy arrays. 
@@ -588,11 +579,11 @@ def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): def get_data(self, img): """ - Extract data array and meta data from loaded data and return them. - This function returns 2 objects, first is numpy array of image data, second is dict of meta data. - It constructs `spatial_shape` and stores in meta dict. - If loading a list of files, stack them together and add a new dimension as first dimension, - and use the meta data of the first image to represent the stacked result. + Extract data array and meta data from loaded image and return them. + This function returns two objects, first is numpy array of image data, second is dict of meta data. + It computes `spatial_shape` and stores it in meta dict. + When loading a list of files, they are concatenated together at a new dimension as the first dimension, + and the meta data of the first image is used to represent the output meta data. Args: img: a PIL Image object loaded from a file or a list of PIL Image objects. @@ -604,7 +595,7 @@ def get_data(self, img): for i in ensure_tuple(img): header = self._get_meta_dict(i) header["spatial_shape"] = self._get_spatial_shape(i) - data = np.asarray(i) + data = np.moveaxis(np.asarray(i), 0, 1) img_array.append(data) header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 _copy_compatible_dict(header, compatible_meta) @@ -615,7 +606,7 @@ def _get_meta_dict(self, img) -> Dict: """ Get the all the meta data of the image and convert to dict type. Args: - img: a PIL Image object loaded from a image file. + img: a PIL Image object loaded from an image file. """ return { @@ -629,9 +620,8 @@ def _get_spatial_shape(self, img): """ Get the spatial shape of image data, it doesn't contain the channel dim. Args: - img: a PIL Image object loaded from a image file. + img: a PIL Image object loaded from an image file. 
""" - # the img data should have no channel dim or the last dim is channel return np.asarray((img.width, img.height)) diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index a3d3135eef..2baec3b872 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -80,6 +80,7 @@ def write_png( if data.dtype not in (np.uint8, np.uint16): # type: ignore data = data.astype(np.uint8) + data = np.moveaxis(data, 0, 1) img = Image.fromarray(data) img.save(file_name, "PNG") return diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 346533f79c..5d6b4d87fd 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -97,7 +97,7 @@ def check_transforms_match(self, transform: dict) -> None: return # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID) if ( - torch.multiprocessing.get_start_method(allow_none=False) == "spawn" + torch.multiprocessing.get_start_method() in ("spawn", None) and transform[InverseKeys.CLASS_NAME] == self.__class__.__name__ ): return diff --git a/requirements-dev.txt b/requirements-dev.txt index 71dcaa3dca..785454ad5d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,7 +3,7 @@ pytorch-ignite==0.4.5 gdown>=3.6.4 scipy -itk>=5.0, <=5.1.2 +itk>=5.2 nibabel pillow!=8.3.0 # https://github.com/python-pillow/Pillow/issues/5571 tensorboard diff --git a/setup.cfg b/setup.cfg index e8763b3318..6efe768a6f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,7 +36,7 @@ all = gdown>=3.6.4 pytorch-ignite==0.4.5 torchvision - itk>=5.0, <=5.1.2 + itk>=5.2 tqdm>=4.47.0 lmdb psutil @@ -59,7 +59,7 @@ ignite = torchvision = torchvision itk = - itk>=5.0, <=5.1.2 + itk>=5.2 tqdm = tqdm>=4.47.0 lmdb = diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index 4afba5f136..db435ee4e4 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -35,6 +35,7 @@ RandZoom, 
ScaleIntensity, ToTensor, + Transpose, ) from monai.utils import set_determinism from tests.testing_data.integration_answers import test_integration_value @@ -66,6 +67,7 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", [ LoadImage(image_only=True), AddChannel(), + Transpose(indices=[0, 2, 1]), ScaleIntensity(), RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True), RandFlip(spatial_axis=0, prob=0.5), @@ -74,7 +76,9 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", ] ) train_transforms.set_random_state(1234) - val_transforms = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), ToTensor()]) + val_transforms = Compose( + [LoadImage(image_only=True), AddChannel(), Transpose(indices=[0, 2, 1]), ScaleIntensity(), ToTensor()] + ) y_pred_trans = Compose([ToTensor(), Activations(softmax=True)]) y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=True, n_classes=len(np.unique(train_y)))]) auc_metric = ROCAUCMetric() diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index 32c1ad941d..d5eb69f7af 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -28,9 +28,9 @@ from monai.networks.nets import UNet from monai.transforms import ( Activations, - AsChannelFirstd, AsDiscrete, Compose, + EnsureChannelFirstd, LoadImaged, RandCropByPosNegLabeld, RandRotate90d, @@ -47,7 +47,7 @@ TASK = "integration_segmentation_3d" -def run_training_test(root_dir, device="cuda:0", cachedataset=0): +def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, None)): monai.config.print_config() images = sorted(glob(os.path.join(root_dir, "img*.nii.gz"))) segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz"))) @@ -57,8 +57,8 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0): # define transforms for image and segmentation train_transforms = Compose( [ - LoadImaged(keys=["img", 
"seg"]), - AsChannelFirstd(keys=["img", "seg"], channel_dim=-1), + LoadImaged(keys=["img", "seg"], reader=readers[0]), + EnsureChannelFirstd(keys=["img", "seg"]), # resampling with align_corners=True or dtype=float64 will generate # slight different results between PyTorch 1.5 an 1.6 Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32), @@ -73,8 +73,8 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0): train_transforms.set_random_state(1234) val_transforms = Compose( [ - LoadImaged(keys=["img", "seg"]), - AsChannelFirstd(keys=["img", "seg"], channel_dim=-1), + LoadImaged(keys=["img", "seg"], reader=readers[1]), + EnsureChannelFirstd(keys=["img", "seg"]), # resampling with align_corners=True or dtype=float64 will generate # slight different results between PyTorch 1.5 an 1.6 Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32), @@ -184,7 +184,7 @@ def run_inference_test(root_dir, device="cuda:0"): val_transforms = Compose( [ LoadImaged(keys=["img", "seg"]), - AsChannelFirstd(keys=["img", "seg"], channel_dim=-1), + EnsureChannelFirstd(keys=["img", "seg"]), # resampling with align_corners=True or dtype=float64 will generate # slight different results between PyTorch 1.5 an 1.6 Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32), @@ -249,21 +249,25 @@ def tearDown(self): def train_and_infer(self, idx=0): results = [] set_determinism(0) - losses, best_metric, best_metric_epoch = run_training_test(self.data_dir, device=self.device, cachedataset=idx) + _readers = (None, None) + if idx == 1: + _readers = ("itkreader", "itkreader") + elif idx == 2: + _readers = ("itkreader", "nibabelreader") + losses, best_metric, best_metric_epoch = run_training_test( + self.data_dir, device=self.device, cachedataset=idx, readers=_readers + ) infer_metric = run_inference_test(self.data_dir, device=self.device) # check training 
properties print("losses", losses) print("best metric", best_metric) print("infer metric", infer_metric) - self.assertTrue(test_integration_value(TASK, key="losses", data=losses, rtol=1e-3)) - self.assertTrue(test_integration_value(TASK, key="best_metric", data=best_metric, rtol=1e-2)) self.assertTrue(len(glob(os.path.join(self.data_dir, "runs"))) > 0) model_file = os.path.join(self.data_dir, "best_metric_model.pth") self.assertTrue(os.path.exists(model_file)) # check inference properties - self.assertTrue(test_integration_value(TASK, key="infer_metric", data=infer_metric, rtol=1e-2)) output_files = sorted(glob(os.path.join(self.data_dir, "output", "img*", "*.nii.gz"))) print([np.mean(nib.load(output).get_fdata()) for output in output_files]) results.extend(losses) @@ -272,6 +276,9 @@ def train_and_infer(self, idx=0): for output in output_files: ave = np.mean(nib.load(output).get_fdata()) results.append(ave) + self.assertTrue(test_integration_value(TASK, key="losses", data=results[:6], rtol=1e-3)) + self.assertTrue(test_integration_value(TASK, key="best_metric", data=results[6], rtol=1e-2)) + self.assertTrue(test_integration_value(TASK, key="infer_metric", data=results[7], rtol=1e-2)) self.assertTrue(test_integration_value(TASK, key="output_sums", data=results[8:], rtol=1e-2)) return results diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 31ef971078..fd1afbd857 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -623,7 +623,7 @@ def test_inverse(self, _, data_name, acceptable_diff, *transforms): self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) # skip this test if multiprocessing uses 'spawn', as the check is only basic anyway - @skipUnless(torch.multiprocessing.get_start_method(allow_none=False) == "spawn", "requires spawn") + @skipUnless(torch.multiprocessing.get_start_method() == "spawn", "requires spawn") def test_fail(self): t1 = SpatialPadd("image", [10, 5]) diff --git 
a/tests/test_load_image.py b/tests/test_load_image.py index b7743f86ad..7b325e7565 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -63,13 +63,13 @@ TEST_CASE_10 = [ {"image_only": False, "reader": ITKReader(pixel_type=itk.UC)}, "tests/testing_data/CT_DICOM", - (4, 16, 16), + (16, 16, 4), ] TEST_CASE_11 = [ {"image_only": False, "reader": "ITKReader", "pixel_type": itk.UC}, "tests/testing_data/CT_DICOM", - (4, 16, 16), + (16, 16, 4), ] @@ -105,8 +105,9 @@ def test_itk_reader(self, input_param, filenames, expected_shape): result, header = result self.assertTrue("affine" in header) self.assertEqual(header["filename_or_obj"], os.path.join(tempdir, "test_image.nii.gz")) - np.testing.assert_allclose(header["affine"], np.eye(4)) - np.testing.assert_allclose(header["original_affine"], np.eye(4)) + np_diag = np.diag([-1, -1, 1, 1]) + np.testing.assert_allclose(header["affine"], np_diag) + np.testing.assert_allclose(header["original_affine"], np_diag) self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand([TEST_CASE_10, TEST_CASE_11]) @@ -118,8 +119,8 @@ def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape): header["affine"], np.array( [ - [0.488281, 0.0, 0.0, -125.0], - [0.0, 0.488281, 0.0, -128.100006], + [-0.488281, 0.0, 0.0, 125.0], + [0.0, -0.488281, 0.0, 128.100006], [0.0, 0.0, 68.33333333, -99.480003], [0.0, 0.0, 0.0, 1.0], ] @@ -129,28 +130,28 @@ def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape): self.assertTupleEqual(tuple(header["spatial_shape"]), expected_shape) def test_itk_reader_multichannel(self): - test_image = np.random.randint(0, 256, size=(256, 256, 3)).astype("uint8") + test_image = np.random.randint(0, 256, size=(256, 224, 3)).astype("uint8") with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "test_image.png") itk_np_view = itk.image_view_from_array(test_image, is_vector=True) itk.imwrite(itk_np_view, filename) result, header = 
LoadImage(reader=ITKReader())(filename) - self.assertTupleEqual(tuple(header["spatial_shape"]), (256, 256)) - np.testing.assert_allclose(result[0, :, :], test_image[:, :, 0]) - np.testing.assert_allclose(result[1, :, :], test_image[:, :, 1]) - np.testing.assert_allclose(result[2, :, :], test_image[:, :, 2]) + self.assertTupleEqual(tuple(header["spatial_shape"]), (224, 256)) + np.testing.assert_allclose(result[:, :, 0], test_image[:, :, 0].T) + np.testing.assert_allclose(result[:, :, 1], test_image[:, :, 1].T) + np.testing.assert_allclose(result[:, :, 2], test_image[:, :, 2].T) def test_load_png(self): - spatial_size = (256, 256) + spatial_size = (256, 224) test_image = np.random.randint(0, 256, size=spatial_size) with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "test_image.png") Image.fromarray(test_image.astype("uint8")).save(filename) result, header = LoadImage(image_only=False)(filename) - self.assertTupleEqual(tuple(header["spatial_shape"]), spatial_size) - self.assertTupleEqual(result.shape, spatial_size) - np.testing.assert_allclose(result, test_image) + self.assertTupleEqual(tuple(header["spatial_shape"]), spatial_size[::-1]) + self.assertTupleEqual(result.shape, spatial_size[::-1]) + np.testing.assert_allclose(result.T, test_image) def test_register(self): spatial_size = (32, 64, 128) @@ -163,8 +164,8 @@ def test_register(self): loader = LoadImage(image_only=False) loader.register(ITKReader()) result, header = loader(filename) - self.assertTupleEqual(tuple(header["spatial_shape"]), spatial_size) - self.assertTupleEqual(result.shape, spatial_size) + self.assertTupleEqual(tuple(header["spatial_shape"]), spatial_size[::-1]) + self.assertTupleEqual(result.shape, spatial_size[::-1]) def test_kwargs(self): spatial_size = (32, 64, 128) diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py index 978c3b6551..2877b1cd57 100644 --- a/tests/test_load_imaged.py +++ b/tests/test_load_imaged.py @@ -19,7 +19,7 @@ from 
parameterized import parameterized from monai.data import ITKReader -from monai.transforms import LoadImaged +from monai.transforms import Compose, EnsureChannelFirstD, LoadImaged, SaveImageD KEYS = ["image", "label", "extra"] @@ -53,8 +53,90 @@ def test_register(self): loader = LoadImaged(keys="img") loader.register(ITKReader()) result = loader({"img": filename}) - self.assertTupleEqual(tuple(result["img_meta_dict"]["spatial_shape"]), spatial_size) - self.assertTupleEqual(result["img"].shape, spatial_size) + self.assertTupleEqual(tuple(result["img_meta_dict"]["spatial_shape"]), spatial_size[::-1]) + self.assertTupleEqual(result["img"].shape, spatial_size[::-1]) + + def test_channel_dim(self): + spatial_size = (32, 64, 3, 128) + test_image = np.random.rand(*spatial_size) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, "test_image.nii.gz") + nib.save(nib.Nifti1Image(test_image, affine=np.eye(4)), filename) + + loader = LoadImaged(keys="img") + loader.register(ITKReader(channel_dim=2)) + result = EnsureChannelFirstD("img")(loader({"img": filename})) + self.assertTupleEqual(tuple(result["img_meta_dict"]["spatial_shape"]), (32, 64, 128)) + self.assertTupleEqual(result["img"].shape, (3, 32, 64, 128)) + + +class TestConsistency(unittest.TestCase): + def _cmp(self, filename, shape, ch_shape, reader_1, reader_2, outname, ext): + data_dict = {"img": filename} + keys = data_dict.keys() + xforms = Compose( + [ + LoadImaged(keys, reader=reader_1), + EnsureChannelFirstD(keys), + ] + ) + img_dict = xforms(data_dict) # load dicom with itk + self.assertTupleEqual(img_dict["img"].shape, ch_shape) + self.assertTupleEqual(tuple(img_dict["img_meta_dict"]["spatial_shape"]), shape) + + with tempfile.TemporaryDirectory() as tempdir: + save_xform = SaveImageD( + keys, meta_keys="img_meta_dict", output_dir=tempdir, squeeze_end_dims=False, output_ext=ext + ) + save_xform(img_dict) # save to nifti + + new_xforms = Compose( + [ + LoadImaged(keys, 
reader=reader_2), + EnsureChannelFirstD(keys), + ] + ) + out = new_xforms({"img": os.path.join(tempdir, outname)}) # load nifti with itk + self.assertTupleEqual(out["img"].shape, ch_shape) + self.assertTupleEqual(tuple(out["img_meta_dict"]["spatial_shape"]), shape) + if "affine" in img_dict["img_meta_dict"] and "affine" in out["img_meta_dict"]: + np.testing.assert_allclose( + img_dict["img_meta_dict"]["affine"], out["img_meta_dict"]["affine"], rtol=1e-3 + ) + np.testing.assert_allclose(out["img"], img_dict["img"], rtol=1e-3) + + def test_dicom(self): + img_dir = "tests/testing_data/CT_DICOM" + self._cmp( + img_dir, (16, 16, 4), (1, 16, 16, 4), "itkreader", "itkreader", "CT_DICOM/CT_DICOM_trans.nii.gz", ".nii.gz" + ) + output_name = "CT_DICOM/CT_DICOM_trans.nii.gz" + self._cmp(img_dir, (16, 16, 4), (1, 16, 16, 4), "nibabelreader", "itkreader", output_name, ".nii.gz") + self._cmp(img_dir, (16, 16, 4), (1, 16, 16, 4), "itkreader", "nibabelreader", output_name, ".nii.gz") + + def test_multi_dicom(self): + """multichannel dicom reading, saving to nifti, then load with itk or nibabel""" + + img_dir = ["tests/testing_data/CT_DICOM", "tests/testing_data/CT_DICOM"] + self._cmp( + img_dir, (16, 16, 4), (2, 16, 16, 4), "itkreader", "itkreader", "CT_DICOM/CT_DICOM_trans.nii.gz", ".nii.gz" + ) + output_name = "CT_DICOM/CT_DICOM_trans.nii.gz" + self._cmp(img_dir, (16, 16, 4), (2, 16, 16, 4), "nibabelreader", "itkreader", output_name, ".nii.gz") + self._cmp(img_dir, (16, 16, 4), (2, 16, 16, 4), "itkreader", "nibabelreader", output_name, ".nii.gz") + + def test_png(self): + """png reading with itk, saving to nifti, then load with itk or nibabel or PIL""" + + test_image = np.random.randint(0, 256, size=(256, 224, 3)).astype("uint8") + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, "test_image.png") + itk_np_view = itk.image_view_from_array(test_image, is_vector=True) + itk.imwrite(itk_np_view, filename) + output_name = 
"test_image/test_image_trans.png" + self._cmp(filename, (224, 256), (3, 224, 256), "itkreader", "itkreader", output_name, ".png") + self._cmp(filename, (224, 256), (3, 224, 256), "itkreader", "PILReader", output_name, ".png") + self._cmp(filename, (224, 256), (3, 224, 256), "itkreader", "nibabelreader", output_name, ".png") if __name__ == "__main__": diff --git a/tests/test_loader_semaphore.py b/tests/test_loader_semaphore.py new file mode 100644 index 0000000000..85c6d54f35 --- /dev/null +++ b/tests/test_loader_semaphore.py @@ -0,0 +1,46 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""this test should not generate errors or +UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores""" +import multiprocessing as mp +import unittest + +import monai # noqa + + +def w(): + pass + + +def _main(): + ps = mp.Process(target=w) + ps.start() + ps.join() + + +def _run_test(): + try: + tmp = mp.get_context("spawn") + except RuntimeError: + tmp = mp + p = tmp.Process(target=_main) + p.start() + p.join() + + +class TestImportLock(unittest.TestCase): + def test_start(self): + _run_test() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pil_reader.py b/tests/test_pil_reader.py index 554b9ac737..0a076b581e 100644 --- a/tests/test_pil_reader.py +++ b/tests/test_pil_reader.py @@ -49,6 +49,7 @@ def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape): self.assertTupleEqual(tuple(result[1]["spatial_shape"]), meta_shape) self.assertTupleEqual(result[0].shape, expected_shape) + test_image = np.moveaxis(test_image, 0, 1) if result[0].shape == test_image.shape: np.testing.assert_allclose(result[0], test_image) else: @@ -68,6 +69,7 @@ def test_converter(self, data_shape, filenames, expected_shape, meta_shape): self.assertTupleEqual(tuple(result[1]["spatial_shape"]), meta_shape) self.assertTupleEqual(result[0].shape, expected_shape) + test_image = np.moveaxis(test_image, 0, 1) np.testing.assert_allclose(result[0], test_image) diff --git a/tests/test_png_rw.py b/tests/test_png_rw.py index 815d0bcf2c..265b31b83b 100644 --- a/tests/test_png_rw.py +++ b/tests/test_png_rw.py @@ -27,6 +27,7 @@ def test_write_gray(self): img_save_val = (255 * img).astype(np.uint8) write_png(img, image_name, scale=255) out = np.asarray(Image.open(image_name)) + out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) def test_write_gray_1height(self): @@ -36,6 +37,7 @@ def test_write_gray_1height(self): img_save_val = (65535 * img).astype(np.uint16) write_png(img, image_name, scale=65535) out = 
np.asarray(Image.open(image_name)) + out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) def test_write_gray_1channel(self): @@ -45,6 +47,7 @@ def test_write_gray_1channel(self): img_save_val = (255 * img).astype(np.uint8).squeeze(2) write_png(img, image_name, scale=255) out = np.asarray(Image.open(image_name)) + out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) def test_write_rgb(self): @@ -54,6 +57,7 @@ def test_write_rgb(self): img_save_val = (255 * img).astype(np.uint8) write_png(img, image_name, scale=255) out = np.asarray(Image.open(image_name)) + out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) def test_write_2channels(self): @@ -63,6 +67,7 @@ def test_write_2channels(self): img_save_val = (255 * img).astype(np.uint8) write_png(img, image_name, scale=255) out = np.asarray(Image.open(image_name)) + out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) def test_write_output_shape(self): diff --git a/tests/testing_data/integration_answers.py b/tests/testing_data/integration_answers.py index 623b5e9503..ccb4293a40 100644 --- a/tests/testing_data/integration_answers.py +++ b/tests/testing_data/integration_answers.py @@ -420,15 +420,28 @@ 0.03577899932861328, ], }, + "integration_segmentation_3d": { # for the mixed readers + "losses": [ + 0.5645154356956482, + 0.4984356611967087, + 0.472334086894989, + 0.47419720590114595, + 0.45881829261779783, + 0.43097741305828097, + ], + "best_metric": 0.9325698614120483, + "infer_metric": 0.9326590299606323, + }, }, ] def test_integration_value(test_name, key, data, rtol=1e-2): - for expected in EXPECTED_ANSWERS: + for (idx, expected) in enumerate(EXPECTED_ANSWERS): if test_name not in expected: continue value = expected[test_name][key] if np.allclose(data, value, rtol=rtol): + print(f"matched {idx} result of {test_name}, {key}, {rtol}.") return True raise ValueError(f"no matched results for {test_name}, {key}. 
{data}.") diff --git a/tests/utils.py b/tests/utils.py index 68ae8e4ec9..ce280a13f0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -251,7 +251,6 @@ def __init__( self.timeout = datetime.timedelta(0, timeout) self.daemon = daemon self.method = method - self._original_method = torch.multiprocessing.get_start_method(allow_none=False) self.verbose = verbose def run_process(self, func, local_rank, args, kwargs, results): @@ -311,30 +310,19 @@ def __call__(self, obj): @functools.wraps(obj) def _wrapper(*args, **kwargs): - if self.method: - try: - torch.multiprocessing.set_start_method(self.method, force=True) - except (RuntimeError, ValueError): - pass + tmp = torch.multiprocessing.get_context(self.method) processes = [] - results = torch.multiprocessing.Queue() + results = tmp.Queue() func = _call_original_func args = [obj.__name__, obj.__module__] + list(args) for proc_rank in range(self.nproc_per_node): - p = torch.multiprocessing.Process( - target=self.run_process, args=(func, proc_rank, args, kwargs, results) + p = tmp.Process( + target=self.run_process, args=(func, proc_rank, args, kwargs, results), daemon=self.daemon ) - if self.daemon is not None: - p.daemon = self.daemon p.start() processes.append(p) for p in processes: p.join() - if self.method: - try: - torch.multiprocessing.set_start_method(self._original_method, force=True) - except (RuntimeError, ValueError): - pass assert results.get(), "Distributed call failed." 
return _wrapper @@ -372,7 +360,6 @@ def __init__( self.force_quit = force_quit self.skip_timing = skip_timing self.method = method - self._original_method = torch.multiprocessing.get_start_method(allow_none=False) # remember the default method @staticmethod def run_process(func, args, kwargs, results): @@ -392,18 +379,11 @@ def __call__(self, obj): @functools.wraps(obj) def _wrapper(*args, **kwargs): - - if self.method: - try: - torch.multiprocessing.set_start_method(self.method, force=True) - except (RuntimeError, ValueError): - pass + tmp = torch.multiprocessing.get_context(self.method) func = _call_original_func args = [obj.__name__, obj.__module__] + list(args) - results = torch.multiprocessing.Queue() - p = torch.multiprocessing.Process(target=TimedCall.run_process, args=(func, args, kwargs, results)) - if self.daemon is not None: - p.daemon = self.daemon + results = tmp.Queue() + p = tmp.Process(target=TimedCall.run_process, args=(func, args, kwargs, results), daemon=self.daemon) p.start() p.join(timeout=self.timeout_seconds) @@ -430,12 +410,6 @@ def _wrapper(*args, **kwargs): res = results.get(block=False) except queue.Empty: # no result returned, took too long pass - finally: - if self.method: - try: - torch.multiprocessing.set_start_method(self._original_method, force=True) - except (RuntimeError, ValueError): - pass if isinstance(res, Exception): # other errors from obj if hasattr(res, "traceback"): raise RuntimeError(res.traceback) from res @@ -585,7 +559,7 @@ def query_memory(n=2): ids = np.lexsort(free_memory)[:n] except (FileNotFoundError, TypeError, IndexError): ids = range(n) if isinstance(n, int) else [] - return ",".join([f"{int(x)}" for x in ids]) + return ",".join(f"{int(x)}" for x in ids) if __name__ == "__main__":