diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py index 721781196b..c79434038b 100644 --- a/monai/apps/deepgrow/dataset.py +++ b/monai/apps/deepgrow/dataset.py @@ -15,8 +15,9 @@ import numpy as np -from monai.transforms import AsChannelFirstd, Compose, LoadImaged, Orientationd, Spacingd +from monai.transforms import AsChannelFirstd, Compose, FromMetaTensord, LoadImaged, Orientationd, Spacingd, ToNumpyd from monai.utils import GridSampleMode +from monai.utils.enums import PostFix def create_dataset( @@ -125,6 +126,8 @@ def _default_transforms(image_key, label_key, pixdim): return Compose( [ LoadImaged(keys=keys), + FromMetaTensord(keys=keys), + ToNumpyd(keys=keys + [PostFix.meta(k) for k in keys]), AsChannelFirstd(keys=keys), Orientationd(keys=keys, axcodes="RAS"), Spacingd(keys=keys, pixdim=pixdim, mode=mode), diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py index b447585d3e..fe9618737b 100644 --- a/monai/data/dataset_summary.py +++ b/monai/data/dataset_summary.py @@ -18,9 +18,9 @@ from monai.config import KeysCollection from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset +from monai.data.utils import affine_to_spacing from monai.transforms import concatenate -from monai.utils import convert_data_type -from monai.utils.enums import PostFix +from monai.utils import PostFix, convert_data_type DEFAULT_POST_FIX = PostFix.meta() @@ -84,7 +84,7 @@ def collect_meta_data(self): raise ValueError(f"To collect meta data for the dataset, key `{self.meta_key}` must exist in `data`.") self.all_meta_data.append(data[self.meta_key]) - def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: int = 3, percentile: float = 10.0): + def get_target_spacing(self, spacing_key: str = "affine", anisotropic_threshold: int = 3, percentile: float = 10.0): """ Calculate the target spacing according to all spacings. If the target spacing is very anisotropic, @@ -93,7 +93,7 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: After loading with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`. Args: - spacing_key: key of spacing in meta data (default: ``pixdim``). + spacing_key: key of the affine used to compute spacing in meta data (default: ``affine``). anisotropic_threshold: threshold to decide if the target spacing is anisotropic (default: ``3``). percentile: for anisotropic target spacing, use the percentile of all spacings of the anisotropic axis to replace that axis. @@ -103,7 +103,8 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: self.collect_meta_data() if spacing_key not in self.all_meta_data[0]: raise ValueError("The provided spacing_key is not in self.all_meta_data.") - all_spacings = concatenate(to_cat=[data[spacing_key][:, 1:4] for data in self.all_meta_data], axis=0) + spacings = [affine_to_spacing(data[spacing_key][0], 3)[None] for data in self.all_meta_data] + all_spacings = concatenate(to_cat=spacings, axis=0) all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True) target_spacing = np.median(all_spacings, axis=0) diff --git a/monai/data/image_dataset.py b/monai/data/image_dataset.py index 51f4e04959..61722c5490 100644 --- a/monai/data/image_dataset.py +++ b/monai/data/image_dataset.py @@ -86,7 +86,7 @@ def __init__( raise ValueError("transform_with_metadata=True requires image_only=False.") self.image_only = image_only self.transform_with_metadata = transform_with_metadata - self.loader = LoadImage(reader, image_only, dtype, *args, **kwargs) + self.loader = LoadImage(reader, dtype, *args, **kwargs) self.set_random_state(seed=get_seed()) self._seed = 0 # transform synchronization seed @@ -101,14 +101,13 @@ def __getitem__(self, index: int): meta_data, seg_meta_data, seg, label = None, None, None, None # load data and optionally meta - if self.image_only: - img = self.loader(self.image_files[index]) - if self.seg_files is not None: - seg = self.loader(self.seg_files[index]) - else: - img, meta_data = self.loader(self.image_files[index]) - if self.seg_files is not None: - seg, seg_meta_data = self.loader(self.seg_files[index]) + img = self.loader(self.image_files[index]) + if not self.image_only: + meta_data = img.meta + if self.seg_files is not None: + seg = self.loader(self.seg_files[index]) + if not self.image_only: + seg_meta_data = seg.meta # apply the transforms if self.transform is not None: diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index e38e009e96..7196ce31f1 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -153,7 +153,7 @@ def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Call Returns: Returns `None`, but `self` should be updated to have the copied attribute. """ - attributes = [getattr(i, attribute) for i in input_objs] + attributes = [getattr(i, attribute) for i in input_objs if hasattr(i, attribute)] if len(attributes) > 0: val = attributes[0] if deep_copy: @@ -189,9 +189,7 @@ def get_default_meta(self) -> dict: def __repr__(self) -> str: """String representation of class.""" - out: str = super().__repr__() - - out += "\nMetaData\n" + out: str = "\nMetaData\n" if self.meta is not None: out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items()) else: diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 9196f0186c..aae012fec0 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -17,8 +17,10 @@ import torch +from monai.config.type_definitions import NdarrayTensor from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms from monai.data.utils import decollate_batch, list_data_collate +from monai.transforms.utils import remove_extra_metadata from monai.utils.enums import PostFix __all__ = ["MetaTensor"] @@ -186,8 +188,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any: kwargs = {} ret = super().__torch_function__(func, types, args, kwargs) # if `out` has been used as argument, metadata is not copied, nothing to do. - if "out" in kwargs: - return ret + # if "out" in kwargs: + # return ret # we might have 1 or multiple outputs. Might be MetaTensor, might be something # else (e.g., `__repr__` returns a string). # Convert to list (if necessary), process, and at end remove list if one was added. @@ -232,3 +234,33 @@ def affine(self) -> torch.Tensor: def affine(self, d: torch.Tensor) -> None: """Set the affine.""" self.meta["affine"] = d + + @staticmethod + def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict): + """ + Convert the image to `torch.Tensor`. If `affine` is in the `meta` dictionary, + convert that to `torch.Tensor`, too. Remove any superfluous metadata. + + Args: + im: Input image (`np.ndarray` or `torch.Tensor`) + meta: Metadata dictionary. + + Returns: + By default, a `MetaTensor` is returned. + However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned. + """ + img = torch.as_tensor(im) + + # if not tracking metadata, return `torch.Tensor` + if not get_track_meta() or meta is None: + return img + + # ensure affine is of type `torch.Tensor` + if "affine" in meta: + meta["affine"] = torch.as_tensor(meta["affine"]) + + # remove any superfluous metadata. + remove_extra_metadata(meta) + + # return the `MetaTensor` + return MetaTensor(img, meta=meta) diff --git a/monai/data/utils.py b/monai/data/utils.py index 2bd7b49731..df40ca3af3 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -581,6 +581,8 @@ def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_z Returns: an `r` dimensional vector of spacing. """ + if len(affine.shape) != 2 or affine.shape[0] != affine.shape[1]: + raise ValueError(f"affine must be a square matrix, got {affine.shape}.") _affine, *_ = convert_to_dst_type(affine[:r, :r], dst=affine, dtype=dtype) if isinstance(_affine, torch.Tensor): spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0)) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index c2385499b3..75f95f4d5b 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -569,6 +569,7 @@ generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, + get_extra_metadata_keys, get_extreme_points, get_largest_connected_component_mask, get_number_image_type_conversions, @@ -582,6 +583,8 @@ map_spatial_axes, print_transform_backends, rand_choice, + remove_extra_metadata, + remove_keys, rescale_array, rescale_array_int_max, rescale_instance_array, diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 5bafd84eaf..3c4e8d59dd 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -29,11 +29,19 @@ from monai.data import image_writer from monai.data.folder_layout import FolderLayout from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader +from monai.data.meta_tensor import MetaTensor from monai.transforms.transform import Transform from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key -from monai.utils import InterpolateMode, OptionalImportError, ensure_tuple, look_up_option, optional_import +from monai.utils import ( + InterpolateMode, + OptionalImportError, + deprecated_arg, + ensure_tuple, + look_up_option, + optional_import, +) nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") @@ -93,14 +101,11 @@ class LoadImage(Transform): """ + @deprecated_arg( + name="image_only", since="0.8", msg_suffix="If necessary, please extract meta data with `MetaTensor.meta`" + ) def __init__( - self, - reader=None, - image_only: bool = False, - dtype: DtypeLike = np.float32, - ensure_channel_first: bool = False, - *args, - **kwargs, + self, reader=None, dtype: DtypeLike = np.float32, ensure_channel_first: bool = False, *args, **kwargs ) -> None: """ Args: @@ -111,7 +116,6 @@ def __init__( ``"ITKReader"``, ``"NibabelReader"``, ``"NumpyReader"``. a reader instance will be constructed with the `*args` and `**kwargs` parameters. - if `reader` is a reader class/instance, it will be registered to this loader accordingly. - image_only: if True return only the image volume, otherwise return image data array and header dict. dtype: if not None convert the loaded image to this data type. ensure_channel_first: if `True` and loaded both image array and meta data, automatically convert the image array shape to `channel first`. default to `False`. @@ -120,8 +124,8 @@ def __init__( Note: - - The transform returns an image data array if `image_only` is True, - or a tuple of two elements containing the data array, and the meta data in a dictionary format otherwise. + - The transform returns a MetaTensor, unless `set_track_meta(False)` has been used, in which case, a + `torch.Tensor` will be returned. - If `reader` is specified, the loader will attempt to use the specified readers and the default supported readers. This might introduce overheads when handling the exceptions of trying the incompatible loaders. In this case, it is therefore recommended setting the most appropriate reader as @@ -130,7 +134,6 @@ def __init__( """ self.auto_select = reader is None - self.image_only = image_only self.dtype = dtype self.ensure_channel_first = ensure_channel_first @@ -241,14 +244,12 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option raise ValueError("`meta_data` must be a dict.") # make sure all elements in metadata are little endian meta_data = switch_endianness(meta_data, "<") - if self.ensure_channel_first: - img_array = EnsureChannelFirst()(img_array, meta_data) - if self.image_only: - return img_array meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader - - return img_array, meta_data + img = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data) + if self.ensure_channel_first: + img = EnsureChannelFirst()(img) + return img class SaveImage(Transform): diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 30dedc7810..1aa6b934af 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -25,7 +25,7 @@ from monai.data.image_reader import ImageReader from monai.transforms.io.array import LoadImage, SaveImage from monai.transforms.transform import MapTransform -from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, ensure_tuple, ensure_tuple_rep +from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, deprecated_arg, ensure_tuple_rep from monai.utils.enums import PostFix __all__ = ["LoadImaged", "LoadImageD", "LoadImageDict", "SaveImaged", "SaveImageD", "SaveImageDict"] @@ -64,6 +64,10 @@ class LoadImaged(MapTransform): """ + @deprecated_arg(name="image_only", since="0.8") + @deprecated_arg(name="meta_keys", since="0.8") + @deprecated_arg(name="meta_key_postfix", since="0.8") + @deprecated_arg(name="overwriting", since="0.8") def __init__( self, keys: KeysCollection, @@ -90,17 +94,6 @@ def __init__( a reader instance will be constructed with the `*args` and `**kwargs` parameters. - if `reader` is a reader class/instance, it will be registered to this loader accordingly. dtype: if not None, convert the loaded image data to this data type. - meta_keys: explicitly indicate the key to store the corresponding meta data dictionary. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to store the metadata of the nifti image, - default is `meta_dict`. The meta data is a dictionary object. - For example, load nifti file for `image`, store the metadata into `image_meta_dict`. - overwriting: whether allow overwriting existing meta data of same key. - default is False, which will raise exception if encountering existing key. - image_only: if True return dictionary containing just only the image volumes, otherwise return - dictionary containing image data array and header dict per input key. ensure_channel_first: if `True` and loaded both image array and meta data, automatically convert the image array shape to `channel first`. default to `False`. allow_missing_keys: don't raise exception if key is missing. @@ -108,14 +101,7 @@ def __init__( kwargs: additional parameters for reader if providing a reader name. """ super().__init__(keys, allow_missing_keys) - self._loader = LoadImage(reader, image_only, dtype, ensure_channel_first, *args, **kwargs) - if not isinstance(meta_key_postfix, str): - raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.") - self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) - if len(self.keys) != len(self.meta_keys): - raise ValueError("meta_keys should have the same length as keys.") - self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self.overwriting = overwriting + self._loader = LoadImage(reader, dtype, ensure_channel_first, *args, **kwargs) def register(self, reader: ImageReader): self._loader.register(reader) @@ -127,22 +113,8 @@ def __call__(self, data, reader: Optional[ImageReader] = None): """ d = dict(data) - for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): - data = self._loader(d[key], reader) - if self._loader.image_only: - if not isinstance(data, np.ndarray): - raise ValueError("loader must return a numpy array (because image_only=True was used).") - d[key] = data - else: - if not isinstance(data, (tuple, list)): - raise ValueError("loader must return a tuple or list (because image_only=False was used).") - d[key] = data[0] - if not isinstance(data[1], dict): - raise ValueError("metadata must be a dict.") - meta_key = meta_key or f"{key}_{meta_key_postfix}" - if meta_key in d and not self.overwriting: - raise KeyError(f"Meta data with key {meta_key} already exists and overwriting=False.") - d[meta_key] = data[1] + for key in self.key_iterator(d): + d[key] = self._loader(d[key], reader) return d diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index bc0c09e949..f512c94dc4 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -24,6 +24,7 @@ from monai.config import DtypeLike from monai.config.type_definitions import NdarrayOrTensor +from monai.data.meta_tensor import MetaTensor from monai.transforms.transform import Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( extreme_points_to_image, @@ -210,6 +211,8 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> """ Apply the transform to `img`. """ + if isinstance(img, MetaTensor): + meta_dict = img.meta if not isinstance(meta_dict, Mapping): msg = "meta_dict not available, EnsureChannelFirst is not in use." if self.strict_check: diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 564b2993e7..1b10b6ee85 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -425,7 +425,13 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N d[split_meta_key] = deepcopy(orig_meta) dim = self.splitter.dim if dim > 0: # don't update affine if channel dim - shift = np.eye(len(d[split_meta_key]["affine"])) # type: ignore + affine = d[split_meta_key]["affine"] # type: ignore + ndim = len(affine) + shift: NdarrayOrTensor + if isinstance(affine, torch.Tensor): + shift = torch.eye(ndim, device=affine.device, dtype=affine.dtype) + else: + shift = np.eye(ndim) shift[dim - 1, -1] = i # type: ignore d[split_meta_key]["affine"] = d[split_meta_key]["affine"] @ shift # type: ignore diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 847614adfe..3be45b570d 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -105,6 +105,9 @@ "convert_pad_mode", "convert_to_contiguous", "get_unique_labels", + "remove_keys", + "remove_extra_metadata", + "get_extra_metadata_keys", ] @@ -1573,5 +1576,68 @@ def convert_to_contiguous(data, **kwargs): return data +def remove_keys(data: dict, keys: List[str]) -> None: + """ + Remove keys from a dictionary. Operates in-place so nothing is returned. + + Args: + data: dictionary to be modified. + keys: keys to be deleted from dictionary. + + Returns: + `None` + """ + for k in keys: + _ = data.pop(k, None) + + +def remove_extra_metadata(meta: dict) -> None: + """ + Remove extra metadata from the dictionary. Operates in-place so nothing is returned. + + Args: + meta: dictionary containing metadata to be modified. + + Returns: + `None` + """ + keys = get_extra_metadata_keys() + remove_keys(data=meta, keys=keys) + + +def get_extra_metadata_keys() -> List[str]: + """ + Get a list of unnecessary keys for metadata that can be removed. + + Returns: + List of keys to be removed. + """ + keys = [ + "srow_x", + "srow_y", + "srow_z", + "quatern_b", + "quatern_c", + "quatern_d", + "qoffset_x", + "qoffset_y", + "qoffset_z", + "dim", + "pixdim", + *[f"dim[{i}]" for i in range(8)], + *[f"pixdim[{i}]" for i in range(8)], + ] + + # TODO: it would be good to remove these, but they are currently being used in the + # codebase. + # keys += [ + # "original_affine", + # "spatial_shape", + # "spacing", + # ] + + return keys + + if __name__ == "__main__": print_transform_backends() diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py index ee1a92cf97..689fc0cb3d 100644 --- a/tests/test_arraydataset.py +++ b/tests/test_arraydataset.py @@ -23,15 +23,15 @@ from monai.transforms import AddChannel, Compose, LoadImage, RandAdjustContrast, RandGaussianNoise, Spacing TEST_CASE_1 = [ - Compose([LoadImage(image_only=True), AddChannel(), RandGaussianNoise(prob=1.0)]), - Compose([LoadImage(image_only=True), AddChannel(), RandGaussianNoise(prob=1.0)]), + Compose([LoadImage(), AddChannel(), RandGaussianNoise(prob=1.0)]), + Compose([LoadImage(), AddChannel(), RandGaussianNoise(prob=1.0)]), (0, 1), (1, 128, 128, 128), ] TEST_CASE_2 = [ - Compose([LoadImage(image_only=True), AddChannel(), RandAdjustContrast(prob=1.0)]), - Compose([LoadImage(image_only=True), AddChannel(), RandAdjustContrast(prob=1.0)]), + Compose([LoadImage(), AddChannel(), RandAdjustContrast(prob=1.0)]), + Compose([LoadImage(), AddChannel(), RandAdjustContrast(prob=1.0)]), (0, 1), (1, 128, 128, 128), ] @@ -39,20 +39,21 @@ class TestCompose(Compose): def __call__(self, input_): - img, metadata = self.transforms[0](input_) + img = self.transforms[0](input_) + metadata = img.meta img = self.transforms[1](img) img, _, _ = self.transforms[2](img, metadata["affine"]) return self.transforms[3](img), metadata TEST_CASE_3 = [ - TestCompose([LoadImage(image_only=False), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]), - TestCompose([LoadImage(image_only=False), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]), + TestCompose([LoadImage(), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]), + TestCompose([LoadImage(), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]), (0, 2), (1, 64, 64, 33), ] -TEST_CASE_4 = [Compose([LoadImage(image_only=True), AddChannel(), RandGaussianNoise(prob=1.0)]), (1, 128, 128, 128)] +TEST_CASE_4 = [Compose([LoadImage(), AddChannel(), RandGaussianNoise(prob=1.0)]), (1, 128, 128, 128)] class TestArrayDataset(unittest.TestCase): diff --git a/tests/test_dataset_summary.py b/tests/test_dataset_summary.py index 51840f77ea..d0531b28a0 100644 --- a/tests/test_dataset_summary.py +++ b/tests/test_dataset_summary.py @@ -19,6 +19,9 @@ from monai.data import Dataset, DatasetSummary, create_test_image_3d from monai.transforms import LoadImaged +from monai.transforms.compose import Compose +from monai.transforms.meta_utility.dictionary import FromMetaTensord +from monai.transforms.utility.dictionary import ToNumpyd from monai.utils import set_determinism from monai.utils.enums import PostFix @@ -50,12 +53,17 @@ def test_spacing_intensity(self): {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels) ] - dataset = Dataset( - data=data_dicts, transform=LoadImaged(keys=["image", "label"], meta_keys=["test1", "test2"]) + t = Compose( + [ + LoadImaged(keys=["image", "label"]), + FromMetaTensord(keys=["image", "label"]), + ToNumpyd(keys=["image", "label", "image_meta_dict", "label_meta_dict"]), + ] ) + dataset = Dataset(data=data_dicts, transform=t) # test **kwargs of `DatasetSummary` for `DataLoader` - calculator = DatasetSummary(dataset, num_workers=4, meta_key="test1", collate_fn=test_collate) + calculator = DatasetSummary(dataset, num_workers=4, meta_key="image_meta_dict", collate_fn=test_collate) target_spacing = calculator.get_target_spacing() self.assertEqual(target_spacing, (1.0, 1.0, 1.0)) @@ -85,7 +93,8 @@ def test_anisotropic_spacing(self): {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels) ] - dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"])) + t = Compose([LoadImaged(keys=["image", "label"]), FromMetaTensord(keys=["image", "label"])]) + dataset = Dataset(data=data_dicts, transform=t) calculator = DatasetSummary(dataset, num_workers=4, meta_key_postfix=PostFix.meta()) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index adeaa73337..60e3eb37e3 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -79,14 +79,6 @@ ] -class _ListCompose(Compose): - def __call__(self, input_): - img, metadata = self.transforms[0](input_) - for t in self.transforms[1:]: - img = t(img) - return img, metadata - - class TestDeCollate(unittest.TestCase): def setUp(self) -> None: set_determinism(seed=0) @@ -148,7 +140,7 @@ def test_decollation_tensor(self, *transforms): t_compose = Compose([AddChannel(), Compose(transforms), ToTensor()]) # If nibabel present, read from disk if has_nib: - t_compose = Compose([LoadImage(image_only=True), t_compose]) + t_compose = Compose([LoadImage(), t_compose]) dataset = Dataset(self.data_list, t_compose) self.check_decollate(dataset=dataset) @@ -158,7 +150,7 @@ def test_decollation_list(self, *transforms): t_compose = Compose([AddChannel(), Compose(transforms), ToTensor()]) # If nibabel present, read from disk if has_nib: - t_compose = _ListCompose([LoadImage(image_only=False), t_compose]) + t_compose = Compose([LoadImage(), t_compose]) dataset = Dataset(self.data_list, t_compose) self.check_decollate(dataset=dataset) diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py index dd6168ec75..4f22a16f11 100644 --- a/tests/test_ensure_channel_first.py +++ b/tests/test_ensure_channel_first.py @@ -23,23 +23,19 @@ from monai.transforms import EnsureChannelFirst, LoadImage from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"image_only": False}, ["test_image.nii.gz"], None] +TEST_CASE_1 = [{}, ["test_image.nii.gz"], None] -TEST_CASE_2 = [{"image_only": False}, ["test_image.nii.gz"], -1] +TEST_CASE_2 = [{}, ["test_image.nii.gz"], -1] -TEST_CASE_3 = [{"image_only": False}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], None] +TEST_CASE_3 = [{}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], None] -TEST_CASE_4 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], None] +TEST_CASE_4 = [{"reader": ITKReader()}, ["test_image.nii.gz"], None] -TEST_CASE_5 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], -1] +TEST_CASE_5 = [{"reader": ITKReader()}, ["test_image.nii.gz"], -1] -TEST_CASE_6 = [ - {"reader": ITKReader(), "image_only": False}, - ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], - None, -] +TEST_CASE_6 = [{"reader": ITKReader()}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], None] -TEST_CASE_7 = [{"image_only": False, "reader": ITKReader(pixel_type=itk.UC)}, "tests/testing_data/CT_DICOM", None] +TEST_CASE_7 = [{"reader": ITKReader(pixel_type=itk.UC)}, "tests/testing_data/CT_DICOM", None] class TestEnsureChannelFirst(unittest.TestCase): @@ -55,14 +51,15 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim): filenames[i] = os.path.join(tempdir, name) nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) for p in TEST_NDARRAYS: - result, header = LoadImage(**input_param)(filenames) + result = LoadImage(**input_param)(filenames) + header = result.meta result = EnsureChannelFirst()(p(result), header) self.assertEqual(result.shape[0], len(filenames)) @parameterized.expand([TEST_CASE_7]) def test_itk_dicom_series_reader(self, input_param, filenames, original_channel_dim): - result, header = LoadImage(**input_param)(filenames) - result = EnsureChannelFirst()(result, header) + result = LoadImage(**input_param)(filenames) + result = EnsureChannelFirst()(result) self.assertEqual(result.shape[0], 1) def test_load_png(self): @@ -71,8 +68,8 @@ def test_load_png(self): with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "test_image.png") Image.fromarray(test_image.astype("uint8")).save(filename) - result, header = LoadImage(image_only=False)(filename) - result = EnsureChannelFirst()(result, header) + result = LoadImage()(filename) + result = EnsureChannelFirst()(result) self.assertEqual(result.shape[0], 3) def test_check(self): diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py index 7f1a57a207..cb1694d4e9 100644 --- a/tests/test_ensure_channel_firstd.py +++ b/tests/test_ensure_channel_firstd.py @@ -18,7 +18,7 @@ from parameterized import parameterized from PIL import Image -from monai.transforms import EnsureChannelFirstd, LoadImaged +from monai.transforms import EnsureChannelFirstd, FromMetaTensord, LoadImaged from monai.utils.enums import PostFix from tests.utils import TEST_NDARRAYS @@ -43,6 +43,7 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim): nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) for p in TEST_NDARRAYS: result = LoadImaged(**input_param)({"img": filenames}) + result = FromMetaTensord("img")(result) result["img"] = p(result["img"]) result = EnsureChannelFirstd(**input_param)(result) self.assertEqual(result["img"].shape[0], len(filenames)) @@ -54,6 +55,7 @@ def test_load_png(self): filename = os.path.join(tempdir, "test_image.png") Image.fromarray(test_image.astype("uint8")).save(filename) result = LoadImaged(keys="img")({"img": filename}) + result = FromMetaTensord(keys="img")(result) result = EnsureChannelFirstd(keys="img")(result) self.assertEqual(result["img"].shape[0], 3) diff --git a/tests/test_image_dataset.py b/tests/test_image_dataset.py index 41eda803dc..fae6cedff9 100644 --- a/tests/test_image_dataset.py +++ b/tests/test_image_dataset.py @@ -15,6 +15,7 @@ import nibabel as nib import numpy as np +import torch from monai.data import ImageDataset from monai.transforms import ( @@ -93,7 +94,7 @@ def test_dataset(self): # loading no meta, int dataset = ImageDataset(full_names, dtype=np.float16) for d, _ in zip(dataset, ref_data): - self.assertEqual(d.dtype, np.float16) + self.assertEqual(d.dtype, torch.float16) # loading with meta, no transform dataset = ImageDataset(full_names, image_only=False) diff --git a/tests/test_image_rw.py b/tests/test_image_rw.py index 62b1147aa5..32f31a6af1 100644 --- a/tests/test_image_rw.py +++ b/tests/test_image_rw.py @@ -16,6 +16,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.data.image_reader import ITKReader, NibabelReader, PILReader @@ -52,14 +53,15 @@ def nifti_rw(self, test_data, reader, writer, dtype, resample=True): saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext) self.assertTrue(os.path.exists(saved_path)) loader = LoadImage(reader=reader, squeeze_non_spatial_dims=True) - data, meta = loader(saved_path) + data = loader(saved_path) + meta = data.meta if meta["original_channel_dim"] == -1: _test_data = moveaxis(test_data, 0, -1) else: _test_data = test_data[0] if resample: _test_data = moveaxis(_test_data, 0, 1) - assert_allclose(data, _test_data) + assert_allclose(data, torch.as_tensor(_test_data)) @parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, "ITKWriter"])) def test_2d(self, reader, writer): @@ -99,12 +101,13 @@ def png_rw(self, test_data, reader, writer, dtype, resample=True): saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext) self.assertTrue(os.path.exists(saved_path)) loader = LoadImage(reader=reader) - data, meta = loader(saved_path) + data = loader(saved_path) + meta = data.meta if meta["original_channel_dim"] == -1: _test_data = moveaxis(test_data, 0, -1) else: _test_data = test_data[0] - assert_allclose(data, _test_data) + assert_allclose(data, torch.as_tensor(_test_data)) @parameterized.expand(itertools.product([PILReader, ITKReader], [PILWriter, ITKWriter])) def test_2d(self, reader, writer): diff --git a/tests/test_invertd.py b/tests/test_invertd.py index 64c26c4012..183689113a 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -22,6 +22,7 @@ Compose, CopyItemsd, EnsureTyped, + FromMetaTensord, Invertd, LoadImaged, Orientationd, @@ -50,6 +51,7 @@ def test_invert(self): transform = Compose( [ LoadImaged(KEYS), + FromMetaTensord(KEYS), AddChanneld(KEYS), Orientationd(KEYS, "RPS"), Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32), @@ -156,7 +158,7 @@ def test_invert(self): reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32) original = LoadImaged(KEYS)(data[-1])["label"] n_good = np.sum(np.isclose(reverted, original, atol=1e-3)) - reverted_name = item[PostFix.meta("label_inverted")]["filename_or_obj"] + reverted_name = item["label_inverted"].meta["filename_or_obj"] original_name = data[-1]["label"] self.assertEqual(reverted_name, original_name) print("invert diff", reverted.size - n_good) diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 201fe2fd5b..9509d26283 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -10,6 +10,7 @@ # limitations under the License. import os +import shutil import tempfile import unittest from pathlib import Path @@ -17,11 +18,15 @@ import itk import nibabel as nib import numpy as np +import torch from parameterized import parameterized from PIL import Image from monai.data import ITKReader, NibabelReader +from monai.data.meta_obj import set_track_meta +from monai.data.meta_tensor import MetaTensor from monai.transforms import LoadImage +from tests.utils import assert_allclose class _MiniReader: @@ -40,75 +45,57 @@ def get_data(self, _obj): return np.zeros((1, 1, 1)), {"name": "my test"} -TEST_CASE_1 = [{"image_only": True}, ["test_image.nii.gz"], (128, 128, 128)] +TEST_CASE_1 = [{}, ["test_image.nii.gz"], (128, 128, 128)] -TEST_CASE_2 = [{"image_only": False}, ["test_image.nii.gz"], (128, 128, 128)] +TEST_CASE_2 = [{}, ["test_image.nii.gz"], (128, 128, 128)] -TEST_CASE_3 = [ - {"image_only": True}, - ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], - (3, 128, 128, 128), -] +TEST_CASE_3 = [{}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], (3, 128, 128, 128)] TEST_CASE_3_1 = [ # .mgz format - {"image_only": True, "reader": "nibabelreader"}, + {"reader": "nibabelreader"}, ["test_image.mgz", "test_image2.mgz", "test_image3.mgz"], (3, 128, 128, 128), ] -TEST_CASE_4 = [ - {"image_only": False}, - ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], - (3, 128, 128, 128), -] +TEST_CASE_4 = [{}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], (3, 128, 128, 128)] TEST_CASE_4_1 = [ # additional parameter - {"image_only": False, "mmap": False}, + {"mmap": False}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], (3, 128, 128, 128), ] -TEST_CASE_5 = [{"reader": NibabelReader(mmap=False), "image_only": False}, ["test_image.nii.gz"], (128, 128, 128)] +TEST_CASE_5 = [{"reader": NibabelReader(mmap=False)}, ["test_image.nii.gz"], (128, 128, 128)] -TEST_CASE_6 = [{"reader": ITKReader(), "image_only": True}, ["test_image.nii.gz"], (128, 128, 128)] +TEST_CASE_6 = [{"reader": ITKReader()}, ["test_image.nii.gz"], (128, 128, 128)] -TEST_CASE_7 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], (128, 128, 128)] +TEST_CASE_7 = [{"reader": ITKReader()}, ["test_image.nii.gz"], (128, 128, 128)] TEST_CASE_8 = [ - {"reader": ITKReader(), "image_only": True}, + {"reader": ITKReader()}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], (3, 128, 128, 128), ] TEST_CASE_8_1 = [ - {"reader": ITKReader(channel_dim=0), "image_only": True}, + {"reader": ITKReader(channel_dim=0)}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], (384, 128, 128), ] TEST_CASE_9 = [ - {"reader": ITKReader(), "image_only": False}, + {"reader": ITKReader()}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], (3, 128, 128, 128), ] -TEST_CASE_10 = [ - {"image_only": False, "reader": ITKReader(pixel_type=itk.UC)}, - "tests/testing_data/CT_DICOM", - (16, 16, 4), - (16, 16, 4), -] +TEST_CASE_10 = [{"reader": ITKReader(pixel_type=itk.UC)}, "tests/testing_data/CT_DICOM", (16, 16, 4), (16, 16, 4)] -TEST_CASE_11 = [ - {"image_only": False, "reader": "ITKReader", "pixel_type": itk.UC}, - "tests/testing_data/CT_DICOM", - (16, 16, 4), - (16, 16, 4), -] +TEST_CASE_11 = [{"reader": "ITKReader", "pixel_type": itk.UC}, "tests/testing_data/CT_DICOM", (16, 16, 4), (16, 16, 4)] TEST_CASE_12 = [ - {"image_only": False, "reader": "ITKReader", "pixel_type": itk.UC, "reverse_indexing": True}, + {"reader": "ITKReader", "pixel_type": itk.UC, "reverse_indexing": True}, "tests/testing_data/CT_DICOM", (16, 16, 4), (4, 16, 16), @@ -135,6 +122,12 @@ def get_data(self, _obj): ] +TESTS_META = [] +for track_meta in (False, True): + TESTS_META.append([{}, (128, 128, 128), track_meta]) + TESTS_META.append([{"reader": "ITKReader", "fallback_only": False}, (128, 128, 128), track_meta]) + + class TestLoadImage(unittest.TestCase): @parameterized.expand( [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_3_1, TEST_CASE_4, TEST_CASE_4_1, TEST_CASE_5] @@ -146,13 +139,9 @@ def test_nibabel_reader(self, input_param, filenames, expected_shape): filenames[i] = os.path.join(tempdir, name) nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) result = LoadImage(**input_param)(filenames) - - if isinstance(result, tuple): - result, header = result - self.assertTrue("affine" in header) - self.assertEqual(header["filename_or_obj"], os.path.join(tempdir, "test_image.nii.gz")) - np.testing.assert_allclose(header["affine"], np.eye(4)) - np.testing.assert_allclose(header["original_affine"], np.eye(4)) + ext = "".join(Path(name).suffixes) + self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image" + ext)) + assert_allclose(result.affine, torch.eye(4)) self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9]) @@ -164,24 +153,18 @@ def test_itk_reader(self, input_param, filenames, expected_shape): itk_np_view = itk.image_view_from_array(test_image) itk.imwrite(itk_np_view, filenames[i]) result = LoadImage(**input_param)(filenames) - - if isinstance(result, tuple): - result, header = result - self.assertTrue("affine" in header) - self.assertEqual(header["filename_or_obj"], os.path.join(tempdir, "test_image.nii.gz")) - np_diag = np.diag([-1, -1, 1, 1]) - np.testing.assert_allclose(header["affine"], np_diag) - np.testing.assert_allclose(header["original_affine"], np_diag) + self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image.nii.gz")) + diag = torch.as_tensor(np.diag([-1, -1, 1, 1])) + np.testing.assert_allclose(result.affine, diag) self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand([TEST_CASE_10, TEST_CASE_11, TEST_CASE_12]) def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape, expected_np_shape): - result, header = LoadImage(**input_param)(filenames) - self.assertTrue("affine" in header) - self.assertEqual(header["filename_or_obj"], f"{Path(filenames)}") - np.testing.assert_allclose( - header["affine"], - np.array( + result = LoadImage(**input_param)(filenames) + self.assertEqual(result.meta["filename_or_obj"], f"{Path(filenames)}") + assert_allclose( + result.affine, + torch.tensor( [ [-0.488281, 0.0, 0.0, 125.0], [0.0, -0.488281, 0.0, 128.100006], @@ -190,7 +173,6 @@ def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape, e ] ), ) - self.assertTupleEqual(tuple(header["spatial_shape"]), expected_shape) self.assertTupleEqual(result.shape, expected_np_shape) def test_itk_reader_multichannel(self): @@ -200,9 +182,7 @@ def test_itk_reader_multichannel(self): itk_np_view = itk.image_view_from_array(test_image, is_vector=True) itk.imwrite(itk_np_view, filename) for flag in (False, True): - result, header = LoadImage(reader=ITKReader(reverse_indexing=flag))(Path(filename)) - - self.assertTupleEqual(tuple(header["spatial_shape"]), (224, 256)) + result = LoadImage(reader=ITKReader(reverse_indexing=flag))(Path(filename)) test_image = test_image.transpose(1, 0, 2) np.testing.assert_allclose(result[:, :, 0], test_image[:, :, 0]) np.testing.assert_allclose(result[:, :, 1], test_image[:, :, 1]) @@ -215,12 +195,10 @@ def test_load_nifti_multichannel(self): itk_np_view = itk.image_view_from_array(test_image, is_vector=True) itk.imwrite(itk_np_view, filename) - itk_img, itk_header = LoadImage(reader=ITKReader())(Path(filename)) - self.assertTupleEqual(tuple(itk_header["spatial_shape"]), (16, 64, 31)) + itk_img = LoadImage(reader=ITKReader())(Path(filename)) self.assertTupleEqual(tuple(itk_img.shape), (16, 64, 31, 2)) - nib_image, nib_header = LoadImage(reader=NibabelReader(squeeze_non_spatial_dims=True))(Path(filename)) - self.assertTupleEqual(tuple(nib_header["spatial_shape"]), (16, 64, 31)) + nib_image = LoadImage(reader=NibabelReader(squeeze_non_spatial_dims=True))(Path(filename)) self.assertTupleEqual(tuple(nib_image.shape), (16, 64, 31, 2)) np.testing.assert_allclose(itk_img, nib_image, atol=1e-3, rtol=1e-3) @@ -231,8 +209,7 @@ def test_load_png(self): with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "test_image.png") Image.fromarray(test_image.astype("uint8")).save(filename) - result, header = LoadImage(image_only=False)(filename) - self.assertTupleEqual(tuple(header["spatial_shape"]), spatial_size[::-1]) + result = LoadImage()(filename) self.assertTupleEqual(result.shape, spatial_size[::-1]) np.testing.assert_allclose(result.T, test_image) @@ -244,10 +221,9 @@ def test_register(self): itk_np_view = itk.image_view_from_array(test_image) itk.imwrite(itk_np_view, filename) - loader = LoadImage(image_only=False) + loader = LoadImage() loader.register(ITKReader()) - result, header = loader(filename) - self.assertTupleEqual(tuple(header["spatial_shape"]), spatial_size[::-1]) + result = loader(filename) self.assertTupleEqual(result.shape, spatial_size[::-1]) def test_kwargs(self): @@ -258,35 +234,35 @@ def test_kwargs(self): itk_np_view = itk.image_view_from_array(test_image) itk.imwrite(itk_np_view, filename) - loader = LoadImage(image_only=False) + loader = LoadImage() reader = ITKReader(fallback_only=False) loader.register(reader) - result, header = loader(filename) + result = loader(filename) reader = ITKReader() img = reader.read(filename, fallback_only=False) - result_raw, header_raw = reader.get_data(img) - np.testing.assert_allclose(header["spatial_shape"], header_raw["spatial_shape"]) + result_raw = reader.get_data(img) + result_raw = MetaTensor.ensure_torch_and_prune_meta(*result_raw) self.assertTupleEqual(result.shape, result_raw.shape) def test_my_reader(self): """test customised readers""" out = LoadImage(reader=_MiniReader, is_compatible=True)("test") - self.assertEqual(out[1]["name"], "my test") + self.assertEqual(out.meta["name"], "my test") out = LoadImage(reader=_MiniReader, is_compatible=False)("test") - self.assertEqual(out[1]["name"], "my test") + self.assertEqual(out.meta["name"], "my test") for item in (_MiniReader, _MiniReader(is_compatible=False)): out = LoadImage(reader=item)("test") - self.assertEqual(out[1]["name"], "my test") + self.assertEqual(out.meta["name"], "my test") out = LoadImage()("test", reader=_MiniReader(is_compatible=False)) - self.assertEqual(out[1]["name"], "my test") + self.assertEqual(out.meta["name"], "my test") def test_itk_meta(self): """test metadata from a directory""" - out, meta = LoadImage(reader="ITKReader", pixel_type=itk.UC, series_meta=True)("tests/testing_data/CT_DICOM") + out = LoadImage(reader="ITKReader", pixel_type=itk.UC, series_meta=True)("tests/testing_data/CT_DICOM") idx = "0008|103e" label = itk.GDCMImageIO.GetLabelFromTag(idx, "")[1] - val = meta[idx] + val = out.meta[idx] expected = "Series Description=Routine Brain " self.assertEqual(f"{label}={val}", expected) @@ -299,10 +275,38 @@ def test_channel_dim(self, input_param, filename, expected_shape): result = LoadImage(**input_param)(filename) self.assertTupleEqual( - result[0].shape, (3, 128, 128, 128) if input_param.get("ensure_channel_first", False) else expected_shape + result.shape, (3, 128, 128, 128) if input_param.get("ensure_channel_first", False) else expected_shape ) - self.assertTupleEqual(tuple(result[1]["spatial_shape"]), (128, 128, 128)) - self.assertEqual(result[1]["original_channel_dim"], input_param["channel_dim"]) + self.assertEqual(result.meta["original_channel_dim"], input_param["channel_dim"]) + + +class TestLoadImageMeta(unittest.TestCase): + @classmethod + def setUpClass(cls): + super(__class__, cls).setUpClass() + cls.tmpdir = tempfile.mkdtemp() + test_image = nib.Nifti1Image(np.random.rand(128, 128, 128), np.eye(4)) + nib.save(test_image, os.path.join(cls.tmpdir, "im.nii.gz")) + cls.test_data = os.path.join(cls.tmpdir, "im.nii.gz") + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdir) + super(__class__, cls).tearDownClass() + + @parameterized.expand(TESTS_META) + def test_correct(self, input_param, expected_shape, track_meta): + set_track_meta(track_meta) + r = LoadImage(**input_param)(self.test_data) + self.assertTupleEqual(r.shape, expected_shape) + if track_meta: + self.assertIsInstance(r, MetaTensor) + self.assertTrue(hasattr(r, "affine")) + self.assertIsInstance(r.affine, torch.Tensor) + else: + self.assertIsInstance(r, torch.Tensor) + self.assertNotIsInstance(r, MetaTensor) + self.assertFalse(hasattr(r, "affine")) if __name__ == "__main__": diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py index bc001cf2fd..3b2fc4f58b 100644 --- a/tests/test_load_imaged.py +++ b/tests/test_load_imaged.py @@ -10,6 +10,7 @@ # limitations under the License. import os +import shutil import tempfile import unittest from pathlib import Path @@ -17,11 +18,16 @@ import itk import nibabel as nib import numpy as np +import torch from parameterized import parameterized from monai.data import ITKReader -from monai.transforms import Compose, EnsureChannelFirstD, LoadImaged, SaveImageD +from monai.data.meta_obj import set_track_meta +from monai.data.meta_tensor import MetaTensor +from monai.transforms import Compose, EnsureChannelFirstD, FromMetaTensord, LoadImaged, SaveImageD +from monai.transforms.meta_utility.dictionary import ToMetaTensord from monai.utils.enums import PostFix +from tests.utils import assert_allclose KEYS = ["image", "label", "extra"] @@ -29,6 +35,11 @@ TEST_CASE_2 = [{"keys": KEYS, "reader": "ITKReader", "fallback_only": False}, (128, 128, 128)] +TESTS_META = [] +for track_meta in (False, True): + TESTS_META.append([{"keys": KEYS}, (128, 128, 128), track_meta]) + TESTS_META.append([{"keys": KEYS, "reader": "ITKReader", "fallback_only": False}, (128, 128, 128), track_meta]) + class TestLoadImaged(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @@ -55,7 +66,6 @@ def test_register(self): loader = LoadImaged(keys="img") loader.register(ITKReader()) result = loader({"img": Path(filename)}) - self.assertTupleEqual(tuple(result[PostFix.meta("img")]["spatial_shape"]), spatial_size[::-1]) self.assertTupleEqual(result["img"].shape, spatial_size[::-1]) def test_channel_dim(self): @@ -67,8 +77,8 @@ def test_channel_dim(self): loader = LoadImaged(keys="img") loader.register(ITKReader(channel_dim=2)) - result = EnsureChannelFirstD("img")(loader({"img": filename})) - self.assertTupleEqual(tuple(result[PostFix.meta("img")]["spatial_shape"]), (32, 64, 128)) + t = Compose([FromMetaTensord("img"), EnsureChannelFirstD("img")]) + result = t(loader({"img": filename})) self.assertTupleEqual(result["img"].shape, (3, 32, 64, 128)) def test_no_file(self): @@ -79,49 +89,57 @@ def test_no_file(self): class TestConsistency(unittest.TestCase): - def _cmp(self, filename, shape, ch_shape, reader_1, reader_2, outname, ext): + def _cmp(self, filename, ch_shape, reader_1, reader_2, outname, ext): data_dict = {"img": filename} keys = data_dict.keys() xforms = Compose([LoadImaged(keys, reader=reader_1, ensure_channel_first=True)]) img_dict = xforms(data_dict) # load dicom with itk self.assertTupleEqual(img_dict["img"].shape, ch_shape) - self.assertTupleEqual(tuple(img_dict[PostFix.meta("img")]["spatial_shape"]), shape) with tempfile.TemporaryDirectory() as tempdir: - save_xform = SaveImageD( - keys, meta_keys=PostFix.meta("img"), output_dir=tempdir, squeeze_end_dims=False, output_ext=ext + save_xform = Compose( + [ + FromMetaTensord(keys), + SaveImageD( + keys, meta_keys=PostFix.meta("img"), output_dir=tempdir, squeeze_end_dims=False, output_ext=ext + ), + ] ) save_xform(img_dict) # save to nifti - new_xforms = Compose([LoadImaged(keys, reader=reader_2), EnsureChannelFirstD(keys)]) + new_xforms = Compose( + [ + LoadImaged(keys, reader=reader_2), + FromMetaTensord(keys), + EnsureChannelFirstD(keys), + ToMetaTensord(keys), + ] + ) out = new_xforms({"img": os.path.join(tempdir, outname)}) # load nifti with itk self.assertTupleEqual(out["img"].shape, ch_shape) - self.assertTupleEqual(tuple(out[PostFix.meta("img")]["spatial_shape"]), shape) - if "affine" in img_dict[PostFix.meta("img")] and "affine" in out[PostFix.meta("img")]: - np.testing.assert_allclose( - img_dict[PostFix.meta("img")]["affine"], out[PostFix.meta("img")]["affine"], rtol=1e-3 - ) - np.testing.assert_allclose(out["img"], img_dict["img"], rtol=1e-3) + + def is_identity(x): + return (x == torch.eye(x.shape[0])).all() + + if not is_identity(img_dict["img"].affine) and not is_identity(out["img"].affine): + assert_allclose(img_dict["img"].affine, out["img"].affine, rtol=1e-3) + assert_allclose(out["img"], img_dict["img"], rtol=1e-3) def test_dicom(self): img_dir = "tests/testing_data/CT_DICOM" - self._cmp( - img_dir, (16, 16, 4), (1, 16, 16, 4), "itkreader", "itkreader", "CT_DICOM/CT_DICOM_trans.nii.gz", ".nii.gz" - ) + self._cmp(img_dir, (1, 16, 16, 4), "itkreader", "itkreader", "CT_DICOM/CT_DICOM_trans.nii.gz", ".nii.gz") output_name = "CT_DICOM/CT_DICOM_trans.nii.gz" - self._cmp(img_dir, (16, 16, 4), (1, 16, 16, 4), "nibabelreader", "itkreader", output_name, ".nii.gz") - self._cmp(img_dir, (16, 16, 4), (1, 16, 16, 4), "itkreader", "nibabelreader", output_name, ".nii.gz") + self._cmp(img_dir, (1, 16, 16, 4), "nibabelreader", "itkreader", output_name, ".nii.gz") + self._cmp(img_dir, (1, 16, 16, 4), "itkreader", "nibabelreader", output_name, ".nii.gz") def test_multi_dicom(self): """multichannel dicom reading, saving to nifti, then load with itk or nibabel""" img_dir = ["tests/testing_data/CT_DICOM", "tests/testing_data/CT_DICOM"] - self._cmp( - img_dir, (16, 16, 4), (2, 16, 16, 4), "itkreader", "itkreader", "CT_DICOM/CT_DICOM_trans.nii.gz", ".nii.gz" - ) + self._cmp(img_dir, (2, 16, 16, 4), "itkreader", "itkreader", "CT_DICOM/CT_DICOM_trans.nii.gz", ".nii.gz") output_name = "CT_DICOM/CT_DICOM_trans.nii.gz" - self._cmp(img_dir, (16, 16, 4), (2, 16, 16, 4), "nibabelreader", "itkreader", output_name, ".nii.gz") - self._cmp(img_dir, (16, 16, 4), (2, 16, 16, 4), "itkreader", "nibabelreader", output_name, ".nii.gz") + self._cmp(img_dir, (2, 16, 16, 4), "nibabelreader", "itkreader", output_name, ".nii.gz") + self._cmp(img_dir, (2, 16, 16, 4), "itkreader", "nibabelreader", output_name, ".nii.gz") def test_png(self): """png reading with itk, saving to nifti, then load with itk or nibabel or PIL""" @@ -132,9 +150,45 @@ def test_png(self): itk_np_view = itk.image_view_from_array(test_image, is_vector=True) itk.imwrite(itk_np_view, filename) output_name = "test_image/test_image_trans.png" - self._cmp(filename, (224, 256), (3, 224, 256), "itkreader", "itkreader", output_name, ".png") - self._cmp(filename, (224, 256), (3, 224, 256), "itkreader", "PILReader", output_name, ".png") - self._cmp(filename, (224, 256), (3, 224, 256), "itkreader", "nibabelreader", output_name, ".png") + self._cmp(filename, (3, 224, 256), "itkreader", "itkreader", output_name, ".png") + self._cmp(filename, (3, 224, 256), "itkreader", "PILReader", output_name, ".png") + self._cmp(filename, (3, 224, 256), "itkreader", "nibabelreader", output_name, ".png") + + +class TestLoadImagedMeta(unittest.TestCase): + @classmethod + def setUpClass(cls): + super(__class__, cls).setUpClass() + cls.tmpdir = tempfile.mkdtemp() + test_image = nib.Nifti1Image(np.random.rand(128, 128, 128), np.eye(4)) + cls.test_data = {} + for key in KEYS: + nib.save(test_image, os.path.join(cls.tmpdir, key + ".nii.gz")) + cls.test_data.update({key: os.path.join(cls.tmpdir, key + ".nii.gz")}) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdir) + super(__class__, cls).tearDownClass() + + @parameterized.expand(TESTS_META) + def test_correct(self, input_param, expected_shape, track_meta): + set_track_meta(track_meta) + result = LoadImaged(**input_param)(self.test_data) + + # shouldn't have any extra meta data keys + self.assertEqual(len(result), len(KEYS)) + for key in KEYS: + r = result[key] + self.assertTupleEqual(r.shape, expected_shape) + if track_meta: + self.assertIsInstance(r, MetaTensor) + self.assertTrue(hasattr(r, "affine")) + self.assertIsInstance(r.affine, torch.Tensor) + else: + self.assertIsInstance(r, torch.Tensor) + self.assertNotIsInstance(r, MetaTensor) + self.assertFalse(hasattr(r, "affine")) if __name__ == "__main__": diff --git a/tests/test_load_spacing_orientation.py b/tests/test_load_spacing_orientation.py index 2792822c3d..b98a8c8627 100644 --- a/tests/test_load_spacing_orientation.py +++ b/tests/test_load_spacing_orientation.py @@ -18,7 +18,7 @@ from nibabel.processing import resample_to_output from parameterized import parameterized -from monai.transforms import AddChanneld, LoadImaged, Orientationd, Spacingd +from monai.transforms import AddChanneld, Compose, FromMetaTensord, LoadImaged, Orientationd, Spacingd, ToNumpyd from monai.utils.enums import PostFix FILES = tuple( @@ -31,8 +31,15 @@ class TestLoadSpacingOrientation(unittest.TestCase): @parameterized.expand(FILES) def test_load_spacingd(self, filename): data = {"image": filename} - data_dict = LoadImaged(keys="image")(data) - data_dict = AddChanneld(keys="image")(data_dict) + t = Compose( + [ + LoadImaged(keys="image"), + FromMetaTensord(keys="image"), + AddChanneld(keys="image"), + ToNumpyd(keys=["image", "image_meta_dict"]), + ] + ) + data_dict = t(data) t = time.time() res_dict = Spacingd(keys="image", pixdim=(1, 0.2, 1), diagonal=True, padding_mode="zeros")(data_dict) t1 = time.time() @@ -49,8 +56,15 @@ def test_load_spacingd(self, filename): @parameterized.expand(FILES) def test_load_spacingd_rotate(self, filename): data = {"image": filename} - data_dict = LoadImaged(keys="image")(data) - data_dict = AddChanneld(keys="image")(data_dict) + t = Compose( + [ + LoadImaged(keys="image"), + FromMetaTensord(keys="image"), + AddChanneld(keys="image"), + ToNumpyd(keys=["image", "image_meta_dict"]), + ] + ) + data_dict = t(data) affine = data_dict[PostFix.meta("image")]["affine"] data_dict[PostFix.meta("image")]["original_affine"] = data_dict[PostFix.meta("image")]["affine"] = ( np.array([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]]) @ affine @@ -75,8 +89,15 @@ def test_load_spacingd_rotate(self, filename): def test_load_spacingd_non_diag(self): data = {"image": FILES[1]} - data_dict = LoadImaged(keys="image")(data) - data_dict = AddChanneld(keys="image")(data_dict) + t = Compose( + [ + LoadImaged(keys="image"), + FromMetaTensord(keys="image"), + AddChanneld(keys="image"), + ToNumpyd(keys=["image", "image_meta_dict"]), + ] + ) + data_dict = t(data) affine = data_dict[PostFix.meta("image")]["affine"] data_dict[PostFix.meta("image")]["original_affine"] = data_dict[PostFix.meta("image")]["affine"] = ( np.array([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]]) @ affine @@ -96,8 +117,15 @@ def test_load_spacingd_non_diag(self): def test_load_spacingd_rotate_non_diag(self): data = {"image": FILES[0]} - data_dict = LoadImaged(keys="image")(data) - data_dict = AddChanneld(keys="image")(data_dict) + t = Compose( + [ + LoadImaged(keys="image"), + FromMetaTensord(keys="image"), + AddChanneld(keys="image"), + ToNumpyd(keys=["image", "image_meta_dict"]), + ] + ) + data_dict = t(data) res_dict = Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=False, padding_mode="border")(data_dict) np.testing.assert_allclose( res_dict[PostFix.meta("image")]["affine"], @@ -106,8 +134,15 @@ def test_load_spacingd_rotate_non_diag(self): def test_load_spacingd_rotate_non_diag_ornt(self): data = {"image": FILES[0]} - data_dict = LoadImaged(keys="image")(data) - data_dict = AddChanneld(keys="image")(data_dict) + t = Compose( + [ + LoadImaged(keys="image"), + FromMetaTensord(keys="image"), + AddChanneld(keys="image"), + ToNumpyd(keys=["image", "image_meta_dict"]), + ] + ) + data_dict = t(data) res_dict = Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=False, padding_mode="border")(data_dict) res_dict = Orientationd(keys="image", axcodes="LPI")(res_dict) np.testing.assert_allclose( @@ -117,8 +152,15 @@ def test_load_spacingd_rotate_non_diag_ornt(self): def test_load_spacingd_non_diag_ornt(self): data = {"image": FILES[1]} - data_dict = LoadImaged(keys="image")(data) - data_dict = AddChanneld(keys="image")(data_dict) + t = Compose( + [ + LoadImaged(keys="image"), + FromMetaTensord(keys="image"), + AddChanneld(keys="image"), + ToNumpyd(keys=["image", "image_meta_dict"]), + ] + ) + data_dict = t(data) affine = data_dict[PostFix.meta("image")]["affine"] data_dict[PostFix.meta("image")]["original_affine"] = data_dict[PostFix.meta("image")]["affine"] = ( np.array([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]]) @ affine diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 05356fcc84..17fbb3cb35 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -198,6 +198,7 @@ def test_conv(self, device): conv = torch.nn.Conv3d(im.shape[1], 5, 3) conv.to(device) out = conv(im) + self.assertTrue(str(out).startswith("\nMetaData")) self.check(out, im, shape=False, vals=False, ids=False) @parameterized.expand(TESTS) @@ -272,14 +273,13 @@ def test_amp(self): def test_out(self): """Test when `out` is given as an argument.""" m1, _ = self.get_im() - m1_orig = deepcopy(m1) m2, _ = self.get_im() m3, _ = self.get_im() torch.add(m2, m3, out=m1) m1_add = m2 + m3 assert_allclose(m1, m1_add) - self.check_meta(m1, m1_orig) + # self.check_meta(m1, m2) # meta is from first input tensor @parameterized.expand(TESTS) def test_collate(self, device, dtype): diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py index 2c0a8dc9a3..f9a987052b 100644 --- a/tests/test_nifti_rw.py +++ b/tests/test_nifti_rw.py @@ -85,11 +85,12 @@ def test_orientation(self, array, affine, reader_param, expected): # read test cases loader = LoadImage(**reader_param) load_result = loader(test_image) - if isinstance(load_result, tuple): - data_array, header = load_result - else: - data_array = load_result + data_array = load_result.numpy() + if reader_param.get("image_only", False): header = None + else: + header = load_result.meta + header["affine"] = header["affine"].numpy() if os.path.exists(test_image): os.remove(test_image) @@ -114,7 +115,8 @@ def test_orientation(self, array, affine, reader_param, expected): def test_consistency(self): np.set_printoptions(suppress=True, precision=3) test_image = make_nifti_image(np.arange(64).reshape(1, 8, 8), np.diag([1.5, 1.5, 1.5, 1])) - data, header = LoadImage(reader="NibabelReader", as_closest_canonical=False)(test_image) + data = LoadImage(reader="NibabelReader", as_closest_canonical=False)(test_image) + header = data.meta data, original_affine, new_affine = Spacing([0.8, 0.8, 0.8])(data[None], header["affine"], mode="nearest") data, _, new_affine = Orientation("ILP")(data, new_affine) if os.path.exists(test_image): diff --git a/tests/test_nifti_saver.py b/tests/test_nifti_saver.py index 6855a59041..bd1bf86207 100644 --- a/tests/test_nifti_saver.py +++ b/tests/test_nifti_saver.py @@ -80,7 +80,7 @@ def test_saved_3d_no_resize_content(self): saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) for i in range(8): filepath = os.path.join(tempdir, "testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - img, _ = LoadImage("nibabelreader")(filepath) + img = LoadImage("nibabelreader")(filepath) self.assertEqual(img.shape, (1, 2, 2, 8)) def test_squeeze_end_dims(self): @@ -102,9 +102,8 @@ def test_squeeze_end_dims(self): # 2d image w channel saver.save(torch.randint(0, 255, (1, 2, 2)), meta_data) - im, meta = LoadImage()(os.path.join(tempdir, fname, fname + ".nii.gz")) + im = LoadImage()(os.path.join(tempdir, fname, fname + ".nii.gz")) self.assertTrue(im.ndim == 2 if squeeze_end_dims else 4) - self.assertTrue(meta["dim"][0] == im.ndim) if __name__ == "__main__": diff --git a/tests/test_numpy_reader.py b/tests/test_numpy_reader.py index c2f3679e33..bb7686f67d 100644 --- a/tests/test_numpy_reader.py +++ b/tests/test_numpy_reader.py @@ -19,7 +19,6 @@ from monai.data import DataLoader, Dataset, NumpyReader from monai.transforms import LoadImaged -from monai.utils.enums import PostFix class TestNumpyReader(unittest.TestCase): @@ -110,8 +109,6 @@ def test_dataloader(self): num_workers=num_workers, ) for d in loader: - for s in d[PostFix.meta("image")]["spatial_shape"]: - torch.testing.assert_allclose(s, torch.as_tensor([3, 4, 5])) for c in d["image"]: torch.testing.assert_allclose(c, test_data) diff --git a/tests/test_resample_to_match.py b/tests/test_resample_to_match.py index e1f6a28998..f6aa58c191 100644 --- a/tests/test_resample_to_match.py +++ b/tests/test_resample_to_match.py @@ -21,7 +21,7 @@ from monai.data.image_reader import ITKReader, NibabelReader from monai.data.image_writer import ITKWriter -from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ResampleToMatch, SaveImaged +from monai.transforms import Compose, EnsureChannelFirstd, FromMetaTensord, LoadImaged, ResampleToMatch, SaveImaged from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config TEST_CASES = ["itkreader", "nibabelreader"] @@ -41,7 +41,13 @@ def setUp(self): @parameterized.expand(itertools.product([NibabelReader, ITKReader], ["monai.data.NibabelWriter", ITKWriter])) def test_correct(self, reader, writer): with tempfile.TemporaryDirectory() as temp_dir: - loader = Compose([LoadImaged(("im1", "im2"), reader=reader), EnsureChannelFirstd(("im1", "im2"))]) + loader = Compose( + [ + LoadImaged(("im1", "im2"), reader=reader), + FromMetaTensord(("im1", "im2")), + EnsureChannelFirstd(("im1", "im2")), + ] + ) data = loader({"im1": self.fnames[0], "im2": self.fnames[1]}) im_mod, meta = ResampleToMatch()(data["im2"], data["im2_meta_dict"], data["im1_meta_dict"]) diff --git a/tests/test_resample_to_matchd.py b/tests/test_resample_to_matchd.py index d9dbeee133..14536891df 100644 --- a/tests/test_resample_to_matchd.py +++ b/tests/test_resample_to_matchd.py @@ -17,6 +17,7 @@ Compose, CopyItemsd, EnsureChannelFirstd, + FromMetaTensord, Invertd, Lambda, LoadImaged, @@ -47,6 +48,7 @@ def test_correct(self): transforms = Compose( [ LoadImaged(("im1", "im2")), + FromMetaTensord(("im1", "im2")), EnsureChannelFirstd(("im1", "im2")), CopyItemsd(("im2", "im2_meta_dict"), names=("im3", "im3_meta_dict")), ResampleToMatchd("im3", "im1_meta_dict"), diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index 6eca6113f0..9f9043d19e 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -17,10 +17,12 @@ import nibabel as nib import numpy as np +import torch from parameterized import parameterized from monai.data import DataLoader, SmartCacheDataset from monai.transforms import Compose, Lambda, LoadImaged +from tests.utils import assert_allclose TEST_CASE_1 = [0.1, 0, Compose([LoadImaged(keys=["image", "label", "extra"])])] @@ -77,8 +79,8 @@ def test_shape(self, replace_rate, num_replace_workers, transform): for _ in range(3): dataset.update_cache() self.assertIsNotNone(dataset[15]) - if isinstance(dataset[15]["image"], np.ndarray): - np.testing.assert_allclose(dataset[15]["image"], dataset[15]["label"]) + if isinstance(dataset[15]["image"], (np.ndarray, torch.Tensor)): + assert_allclose(dataset[15]["image"], dataset[15]["label"]) else: self.assertIsInstance(dataset[15]["image"], str) dataset.shutdown() diff --git a/tests/test_splitdimd.py b/tests/test_splitdimd.py index 6b164a3cb8..f204ec277d 100644 --- a/tests/test_splitdimd.py +++ b/tests/test_splitdimd.py @@ -13,9 +13,10 @@ from copy import deepcopy import numpy as np +import torch from parameterized import parameterized -from monai.transforms import LoadImaged +from monai.transforms import Compose, FromMetaTensord, LoadImaged from monai.transforms.utility.dictionary import SplitDimd from tests.utils import TEST_NDARRAYS, assert_allclose, make_nifti_image, make_rand_affine @@ -33,7 +34,8 @@ def setUpClass(cls): affine = make_rand_affine() data = {"i": make_nifti_image(arr, affine)} - cls.data = LoadImaged("i")(data) + loader = Compose([LoadImaged("i"), FromMetaTensord("i")]) + cls.data = loader(data) @parameterized.expand(TESTS) def test_correct(self, keepdim, im_type, update_meta): @@ -54,8 +56,10 @@ def test_correct(self, keepdim, im_type, update_meta): split_idx = deepcopy(idx) split_idx[dim] = 0 # idx[1:] to remove channel and then add 1 for 4th element - real_world = data["i_meta_dict"]["affine"] @ (idx[1:] + [1]) - real_world2 = out[f"i_{split_im_idx}_meta_dict"]["affine"] @ (split_idx[1:] + [1]) + real_world = data["i_meta_dict"]["affine"] @ torch.tensor(idx[1:] + [1]).double() + real_world2 = ( + out[f"i_{split_im_idx}_meta_dict"]["affine"] @ torch.tensor(split_idx[1:] + [1]).double() + ) assert_allclose(real_world, real_world2) out = out["i_0"] diff --git a/tests/test_warp.py b/tests/test_warp.py index c039b57211..56f1de23f2 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -153,6 +153,7 @@ def test_grad(self): def load_img_and_sample_ddf(): # load image img = LoadImaged(keys="img")({"img": FILE_PATH})["img"] + img = img.detach().numpy() # W, H, D -> D, H, W img = img.transpose((2, 1, 0)) diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py index 5d092c4ce5..afce957469 100644 --- a/tests/test_wsireader.py +++ b/tests/test_wsireader.py @@ -20,7 +20,7 @@ from monai.data import DataLoader, Dataset from monai.data.image_reader import WSIReader -from monai.transforms import Compose, LoadImaged, ToTensord +from monai.transforms import Compose, FromMetaTensord, LoadImaged, ToTensord from monai.utils import first, optional_import from monai.utils.enums import PostFix from tests.utils import download_url_or_skip_test, testing_data_config @@ -193,6 +193,7 @@ def test_with_dataloader(self, file_path, level, expected_spatial_shape, expecte train_transform = Compose( [ LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level), + FromMetaTensord(keys=["image"]), ToTensord(keys=["image"]), ] ) diff --git a/tests/test_wsireader_new.py b/tests/test_wsireader_new.py index 2ac4125f97..12f1879adf 100644 --- a/tests/test_wsireader_new.py +++ b/tests/test_wsireader_new.py @@ -20,7 +20,7 @@ from monai.data import DataLoader, Dataset from monai.data.wsi_reader import WSIReader -from monai.transforms import Compose, LoadImaged, ToTensord +from monai.transforms import Compose, FromMetaTensord, LoadImaged from monai.utils import first, optional_import from monai.utils.enums import PostFix from tests.utils import download_url_or_skip_test, testing_data_config @@ -200,7 +200,7 @@ def test_with_dataloader(self, file_path, level, expected_spatial_shape, expecte train_transform = Compose( [ LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level), - ToTensord(keys=["image"]), + FromMetaTensord(keys=["image"]), ] ) dataset = Dataset([{"image": file_path}], transform=train_transform) diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json index 8ef12c0d85..1f89811764 100644 --- a/tests/testing_data/inference.json +++ b/tests/testing_data/inference.json @@ -33,6 +33,14 @@ "_target_": "LoadImaged", "keys": "image" }, + { + "_target_": "FromMetaTensord", + "keys": "image" + }, + { + "_target_": "ToNumpyd", + "keys": "image" + }, { "_target_": "EnsureChannelFirstd", "keys": "image" diff --git a/tests/testing_data/inference.yaml b/tests/testing_data/inference.yaml index c63c10517f..623607ef2a 100644 --- a/tests/testing_data/inference.yaml +++ b/tests/testing_data/inference.yaml @@ -26,6 +26,10 @@ preprocessing: transforms: - _target_: LoadImaged keys: image + - _target_: FromMetaTensord + keys: image + - _target_: ToNumpyd + keys: image - _target_: EnsureChannelFirstd keys: image - _target_: ScaleIntensityd