diff --git a/docs/source/data.rst b/docs/source/data.rst index 022f7877d1..6e7f5e2773 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -74,7 +74,7 @@ Generic Interfaces .. autoclass:: ImageDataset :members: :special-members: __getitem__ - + `NPZDictItemDataset` ~~~~~~~~~~~~~~~~~~~~ .. autoclass:: NPZDictItemDataset @@ -108,6 +108,11 @@ Patch-based dataset Image reader ------------ +ImageReader +~~~~~~~~~~~ +.. autoclass:: ImageReader + :members: + ITKReader ~~~~~~~~~ .. autoclass:: ITKReader diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 0c736a548d..1e3a89eb31 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -45,16 +45,30 @@ class ImageReader(ABC): - """Abstract class to define interface APIs to load image files. - users need to call `read` to load image and then use `get_data` - to get the image data and properties from meta data. + """ + An abstract class defines APIs to load image files. + + Typical usage of an implementation of this class is: + + .. code-block:: python + + image_reader = MyImageReader() + img_obj = image_reader.read(path_to_image) + img_data, meta_data = image_reader.get_data(img_obj) + + - The `read` call converts image filenames into image objects, + - The `get_data` call fetches the image data, as well as meta data. + - A reader should implement `verify_suffix` with the logic of checking the input filename + by the filename extensions. """ @abstractmethod def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: """ - Verify whether the specified file or files format is supported by current reader. + Verify whether the specified `filename` is supported by the current reader. + This method should return True if the reader is able to read the format suggested by the + `filename`. Args: filename: file name or a list of file names to read. @@ -67,7 +81,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: def read(self, data: Union[Sequence[str], str], **kwargs) -> Union[Sequence[Any], Any]: """ Read image data from specified file or files. - Note that it returns the raw data, so different readers return different image data type. + Note that it returns a data object or a sequence of data objects. Args: data: file name or a list of file names to read. @@ -80,7 +94,8 @@ def read(self, data: Union[Sequence[str], str], **kwargs) -> Union[Sequence[Any] def get_data(self, img) -> Tuple[np.ndarray, Dict]: """ Extract data array and meta data from loaded image and return them. - This function must return 2 objects, first is numpy array of image data, second is dict of meta data. + This function must return two objects, the first is a numpy array of image data, + the second is a dictionary of meta data. Args: img: an image object loaded from an image file or a list of image objects. @@ -124,7 +139,7 @@ def _stack_images(image_list: List, meta_dict: Dict): class ITKReader(ImageReader): """ Load medical images based on ITK library. - All the supported image formats can be found: + All the supported image formats can be found at: https://github.com/InsightSoftwareConsortium/ITK/tree/master/Modules/IO The loaded data array will be in C order, for example, a 3D image NumPy array index order will be `CDWH`. @@ -396,7 +411,10 @@ def _get_meta_dict(self, img) -> Dict: """ # swap to little endian as PyTorch doesn't support big endian - header = img.header.as_byteswapped("<") + try: + header = img.header.as_byteswapped("<") + except ValueError: + header = img.header return dict(header) def _get_affine(self, img): @@ -419,11 +437,18 @@ def _get_spatial_shape(self, img): """ # swap to little endian as PyTorch doesn't support big endian - header = img.header.as_byteswapped("<") - ndim = header["dim"][0] + try: + header = img.header.as_byteswapped("<") + except ValueError: + header = img.header + dim = header.get("dim", None) + if dim is None: + dim = header.get("dims") # mgh format? + dim = np.insert(dim, 0, 3) + ndim = dim[0] spatial_rank = min(ndim, 3) # the img data should have no channel dim or the last dim is channel - return np.asarray(header["dim"][1 : spatial_rank + 1]) + return np.asarray(dim[1 : spatial_rank + 1]) def _get_array_data(self, img): """ diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index 2aa9b44058..b7067def73 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path from typing import Dict, Optional, Union import numpy as np @@ -36,7 +37,7 @@ class NiftiSaver: def __init__( self, - output_dir: str = "./", + output_dir: Union[Path, str] = "./", output_postfix: str = "seg", output_ext: str = ".nii.gz", resample: bool = True, diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index d0aa787850..e6fb641cca 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path from typing import Dict, Optional, Union import numpy as np @@ -33,7 +34,7 @@ class PNGSaver: def __init__( self, - output_dir: str = "./", + output_dir: Union[Path, str] = "./", output_postfix: str = "seg", output_ext: str = ".png", resample: bool = True, diff --git a/monai/data/utils.py b/monai/data/utils.py index 737b2f84b5..25b3c24e4a 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -19,7 +19,7 @@ from copy import deepcopy from functools import reduce from itertools import product, starmap -from pathlib import PurePath +from pathlib import Path, PurePath from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ -492,6 +492,8 @@ def correct_nifti_header_if_necessary(img_nii): Args: img_nii: nifti image object """ + if img_nii.header.get("dim") is None: + return img_nii # not nifti? dim = img_nii.header["dim"][0] if dim >= 5: return img_nii # do nothing for high-dimensional array @@ -677,7 +679,7 @@ def to_affine_nd(r: Union[np.ndarray, int], affine: np.ndarray) -> np.ndarray: def create_file_basename( postfix: str, input_file_name: str, - folder_path: str, + folder_path: Union[Path, str], data_root_dir: str = "", separate_folder: bool = True, patch_index: Optional[int] = None, diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 99d9c2b8b8..d15b8866e5 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -193,7 +193,7 @@ ) from .inverse import InvertibleTransform from .inverse_batch_transform import BatchInverseTransform, Decollated -from .io.array import LoadImage, SaveImage +from .io.array import SUPPORTED_READERS, LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .nvtx import ( Mark, diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index a8e9ed1e7c..b8e2f75508 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -13,7 +13,11 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ +import inspect +import logging import sys +import warnings +from pathlib import Path from typing import Dict, List, Optional, Sequence, Union import numpy as np @@ -32,7 +36,14 @@ nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") -__all__ = ["LoadImage", "SaveImage"] +__all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"] + +SUPPORTED_READERS = { + "itkreader": ITKReader, + "numpyreader": NumpyReader, + "pilreader": PILReader, + "nibabelreader": NibabelReader, +} def switch_endianness(data, new="<"): @@ -57,87 +68,104 @@ def switch_endianness(data, new="<"): data = [switch_endianness(x, new) for x in data] elif isinstance(data, dict): data = {k: switch_endianness(v, new) for k, v in data.items()} - elif isinstance(data, (bool, str, float, int, type(None))): - pass - else: - raise AssertionError(f"Unknown type: {type(data).__name__}") + elif not isinstance(data, (bool, str, float, int, type(None))): + raise RuntimeError(f"Unknown type: {type(data).__name__}") return data class LoadImage(Transform): """ Load image file or files from provided path based on reader. - Automatically choose readers based on the supported suffixes and in below order: - - User specified reader at runtime when call this loader. - - Registered readers from the latest to the first in list. - - Default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), - (npz, npy -> NumpyReader), (others -> ITKReader). + If reader is not specified, this class automatically chooses readers + based on the supported suffixes and in the following order: + + - User-specified reader at runtime when calling this loader. + - User-specified reader in the constructor of `LoadImage`. + - Readers from the last to the first in the registered list. + - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), + (npz, npy -> NumpyReader), (others -> ITKReader). + + See also: + + - tutorial: https://github.com/Project-MONAI/tutorials/blob/master/modules/load_medical_images.ipynb """ - def __init__( - self, - reader: Optional[Union[ImageReader, str]] = None, - image_only: bool = False, - dtype: DtypeLike = np.float32, - *args, - **kwargs, - ) -> None: + def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np.float32, *args, **kwargs) -> None: """ Args: - reader: register reader to load image file and meta data, if None, still can register readers - at runtime or use the default readers. If a string of reader name provided, will construct - a reader object with the `*args` and `**kwargs` parameters, supported reader name: "NibabelReader", - "PILReader", "ITKReader", "NumpyReader". + reader: reader to load image file and meta data + + - if `reader` is None, a default set of `SUPPORTED_READERS` will be used. + - if `reader` is a string, the corresponding item in `SUPPORTED_READERS` will be used, + and a reader instance will be constructed with the `*args` and `**kwargs` parameters. + the supported reader names are: "nibabelreader", "pilreader", "itkreader", "numpyreader". + - if `reader` is a reader class/instance, it will be registered to this loader accordingly. + image_only: if True return only the image volume, otherwise return image data array and header dict. dtype: if not None convert the loaded image to this data type. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. Note: - The transform returns image data array if `image_only` is True, - or a tuple of two elements containing the data array, and the meta data in a dict format otherwise. + + - The transform returns an image data array if `image_only` is True, + or a tuple of two elements containing the data array, and the meta data in a dictionary format otherwise. + - If `reader` is specified, the loader will attempt to use the specified readers and the default supported + readers. This might introduce overheads when handling the exceptions of trying the incompatible loaders. + In this case, it is therefore recommended to set the most appropriate reader as + the last item of the `reader` parameter. """ - # set predefined readers as default - self.readers: List[ImageReader] = [ITKReader(), NumpyReader(), PILReader(), NibabelReader()] - if reader is not None: - if isinstance(reader, str): - supported_readers = { - "nibabelreader": NibabelReader, - "pilreader": PILReader, - "itkreader": ITKReader, - "numpyreader": NumpyReader, - } - the_reader = look_up_option(reader.lower(), supported_readers) - self.register(the_reader(*args, **kwargs)) - else: - self.register(reader) + self.auto_select = reader is None self.image_only = image_only self.dtype = dtype - def register(self, reader: ImageReader) -> List[ImageReader]: + self.readers: List[ImageReader] = [] + for r in SUPPORTED_READERS: # set predefined readers as default + try: + self.register(SUPPORTED_READERS[r](*args, **kwargs)) + except TypeError: # the reader doesn't have the corresponding args/kwargs + logging.getLogger(self.__class__.__name__).debug( + f"{r} is not supported with the given parameters {args} {kwargs}." + ) + self.register(SUPPORTED_READERS[r]()) + if reader is None: + return # no user-specified reader, no need to register + + for _r in ensure_tuple(reader): + if isinstance(_r, str): + the_reader = look_up_option(_r.lower(), SUPPORTED_READERS) + try: + self.register(the_reader(*args, **kwargs)) + except TypeError: # the reader doesn't have the corresponding args/kwargs + warnings.warn(f"{r} is not supported with the given parameters {args} {kwargs}.") + self.register(the_reader()) + elif inspect.isclass(_r): + self.register(_r(*args, **kwargs)) + else: + self.register(_r) # reader instance, ignoring the constructor args/kwargs + return + + def register(self, reader: ImageReader): """ - Register image reader to load image file and meta data, latest registered reader has higher priority. - Return all the registered image readers. + Register image reader to load image file and meta data. Args: - reader: registered reader to load image file and meta data based on suffix, - if all registered readers can't match suffix at runtime, use the default readers. + reader: reader instance to be registered with this loader. """ if not isinstance(reader, ImageReader): - raise ValueError(f"reader must be ImageReader object, but got {type(reader)}.") + warnings.warn(f"Preferably the reader should inherit ImageReader, but got {type(reader)}.") self.readers.append(reader) - return self.readers - def __call__( - self, - filename: Union[Sequence[str], str], - reader: Optional[ImageReader] = None, - ): + def __call__(self, filename: Union[Sequence[str], str, Path, Sequence[Path]], reader: Optional[ImageReader] = None): """ + Load image file and meta data from the given filename(s). + If `reader` is not specified, this class automatically chooses readers based on the + reversed order of registered readers `self.readers`. + Args: filename: path file or file-like object or a list of files. will save the filename to meta_data with key `filename_or_obj`. @@ -145,21 +173,34 @@ def __call__( reader: runtime reader to load image file and meta data. """ - if reader is None or not reader.verify_suffix(filename): - for r in reversed(self.readers): - if r.verify_suffix(filename): - reader = r - break - - if reader is None: + filename = tuple(str(s) for s in ensure_tuple(filename)) # allow Path objects + img = None + if reader is not None: + img = reader.read(filename) # runtime specified reader + else: + for reader in self.readers[::-1]: + if self.auto_select: # rely on the filename extension to choose the reader + if reader.verify_suffix(filename): + img = reader.read(filename) + break + else: # try the user designated readers + try: + img = reader.read(filename) + except Exception as e: + logging.getLogger(self.__class__.__name__).debug( + f"{reader.__class__.__name__}: unable to load {filename}.\n" f"Error: {e}" + ) + else: + break + + if img is None or reader is None: raise RuntimeError( - f"can not find suitable reader for this file: {filename}. \ - Please install dependency libraries: (nii, nii.gz) -> Nibabel, (png, jpg, bmp) -> PIL, \ - (npz, npy) -> Numpy, others -> ITK. Refer to the installation instruction: \ - https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies." + f"can not find a suitable reader for file: {filename}.\n" + " Please install the reader libraries, see also the installation instructions:\n" + " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" + f" The current registered: {self.readers}.\n" ) - img = reader.read(filename) img_array, meta_data = reader.get_data(img) img_array = img_array.astype(self.dtype) @@ -241,7 +282,7 @@ class SaveImage(Transform): def __init__( self, - output_dir: str = "./", + output_dir: Union[Path, str] = "./", output_postfix: str = "trans", output_ext: str = ".nii.gz", resample: bool = True, @@ -256,7 +297,7 @@ def __init__( print_log: bool = True, ) -> None: self.saver: Union[NiftiSaver, PNGSaver] - if output_ext in (".nii.gz", ".nii"): + if output_ext in {".nii.gz", ".nii"}: self.saver = NiftiSaver( output_dir=output_dir, output_postfix=output_postfix, diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index db043848c7..764e20f838 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -15,6 +15,7 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ +from pathlib import Path from typing import Optional, Union import numpy as np @@ -38,17 +39,31 @@ class LoadImaged(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.LoadImage`, - must load image and metadata together. If loading a list of files in one key, - stack them together and add a new dimension as the first dimension, and use the - meta data of the first image to represent the stacked result. Note that the affine - transform of all the stacked images should be same. The output metadata field will - be created as ``meta_keys`` or ``key_{meta_key_postfix}``. + It can load both image data and metadata. When loading a list of files in one key, + the arrays will be stacked and a new dimension will be added as the first dimension + In this case, the meta data of the first image will be used to represent the stacked result. + The affine transform of all the stacked images should be same. + The output metadata field will be created as ``meta_keys`` or ``key_{meta_key_postfix}``. - It can automatically choose readers based on the supported suffixes and in below order: - - User specified reader at runtime when call this loader. - - Registered readers from the latest to the first in list. - - Default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), - (npz, npy -> NumpyReader), (others -> ITKReader). + If reader is not specified, this class automatically chooses readers + based on the supported suffixes and in the following order: + + - User-specified reader at runtime when calling this loader. + - User-specified reader in the constructor of `LoadImage`. + - Readers from the last to the first in the registered list. + - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), + (npz, npy -> NumpyReader), (others -> ITKReader). + + Note: + + - If `reader` is specified, the loader will attempt to use the specified readers and the default supported + readers. This might introduce overheads when handling the exceptions of trying the incompatible loaders. + In this case, it is therefore recommended to set the most appropriate reader as + the last item of the `reader` parameter. + + See also: + + - tutorial: https://github.com/Project-MONAI/tutorials/blob/master/modules/load_medical_images.ipynb """ @@ -209,7 +224,7 @@ def __init__( keys: KeysCollection, meta_keys: Optional[KeysCollection] = None, meta_key_postfix: str = "meta_dict", - output_dir: str = "./", + output_dir: Union[Path, str] = "./", output_postfix: str = "trans", output_ext: str = ".nii.gz", resample: bool = True, diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 7b325e7565..2aa6eced65 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -12,6 +12,7 @@ import os import tempfile import unittest +from pathlib import Path import itk import nibabel as nib @@ -22,6 +23,23 @@ from monai.data import ITKReader, NibabelReader from monai.transforms import LoadImage + +class _MiniReader: + """a test case customised reader""" + + def __init__(self, is_compatible=False): + self.is_compatible = is_compatible + + def verify_suffix(self, _name): + return self.is_compatible + + def read(self, name): + return name + + def get_data(self, _obj): + return np.zeros((1, 1, 1)), {"name": "my test"} + + TEST_CASE_1 = [{"image_only": True}, ["test_image.nii.gz"], (128, 128, 128)] TEST_CASE_2 = [{"image_only": False}, ["test_image.nii.gz"], (128, 128, 128)] @@ -32,12 +50,24 @@ (3, 128, 128, 128), ] +TEST_CASE_3_1 = [ # .mgz format + {"image_only": True, "reader": "nibabelreader"}, + ["test_image.mgz", "test_image2.mgz", "test_image3.mgz"], + (3, 128, 128, 128), +] + TEST_CASE_4 = [ {"image_only": False}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], (3, 128, 128, 128), ] +TEST_CASE_4_1 = [ # additional parameter + {"image_only": False, "mmap": False}, + ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], + (3, 128, 128, 128), +] + TEST_CASE_5 = [ {"reader": NibabelReader(mmap=False), "image_only": False}, ["test_image.nii.gz"], @@ -74,7 +104,9 @@ class TestLoadImage(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand( + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_3_1, TEST_CASE_4, TEST_CASE_4_1, TEST_CASE_5] + ) def test_nibabel_reader(self, input_param, filenames, expected_shape): test_image = np.random.rand(128, 128, 128) with tempfile.TemporaryDirectory() as tempdir: @@ -135,7 +167,7 @@ def test_itk_reader_multichannel(self): filename = os.path.join(tempdir, "test_image.png") itk_np_view = itk.image_view_from_array(test_image, is_vector=True) itk.imwrite(itk_np_view, filename) - result, header = LoadImage(reader=ITKReader())(filename) + result, header = LoadImage(reader=ITKReader())(Path(filename)) self.assertTupleEqual(tuple(header["spatial_shape"]), (224, 256)) np.testing.assert_allclose(result[:, :, 0], test_image[:, :, 0].T) @@ -169,7 +201,6 @@ def test_register(self): def test_kwargs(self): spatial_size = (32, 64, 128) - expected_shape = (128, 64, 32) test_image = np.random.rand(*spatial_size) with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "test_image.nii.gz") @@ -187,6 +218,18 @@ def test_kwargs(self): np.testing.assert_allclose(header["spatial_shape"], header_raw["spatial_shape"]) self.assertTupleEqual(result.shape, result_raw.shape) + def test_my_reader(self): + """test customised readers""" + out = LoadImage(reader=_MiniReader, is_compatible=True)("test") + self.assertEqual(out[1]["name"], "my test") + out = LoadImage(reader=_MiniReader, is_compatible=False)("test") + self.assertEqual(out[1]["name"], "my test") + for item in (_MiniReader, _MiniReader(is_compatible=False)): + out = LoadImage(reader=item)("test") + self.assertEqual(out[1]["name"], "my test") + out = LoadImage()("test", reader=_MiniReader(is_compatible=False)) + self.assertEqual(out[1]["name"], "my test") + if __name__ == "__main__": unittest.main() diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py index 2877b1cd57..ca5b56a7d9 100644 --- a/tests/test_load_imaged.py +++ b/tests/test_load_imaged.py @@ -12,6 +12,7 @@ import os import tempfile import unittest +from pathlib import Path import itk import nibabel as nib @@ -52,7 +53,7 @@ def test_register(self): loader = LoadImaged(keys="img") loader.register(ITKReader()) - result = loader({"img": filename}) + result = loader({"img": Path(filename)}) self.assertTupleEqual(tuple(result["img_meta_dict"]["spatial_shape"]), spatial_size[::-1]) self.assertTupleEqual(result["img"].shape, spatial_size[::-1]) @@ -69,6 +70,12 @@ def test_channel_dim(self): self.assertTupleEqual(tuple(result["img_meta_dict"]["spatial_shape"]), (32, 64, 128)) self.assertTupleEqual(result["img"].shape, (3, 32, 64, 128)) + def test_no_file(self): + with self.assertRaises(RuntimeError): + LoadImaged(keys="img")({"img": "unknown"}) + with self.assertRaises(RuntimeError): + LoadImaged(keys="img", reader="nibabelreader")({"img": "unknown"}) + class TestConsistency(unittest.TestCase): def _cmp(self, filename, shape, ch_shape, reader_1, reader_2, outname, ext): diff --git a/tests/test_nifti_endianness.py b/tests/test_nifti_endianness.py index 39cbed7795..bf0f27b9ca 100644 --- a/tests/test_nifti_endianness.py +++ b/tests/test_nifti_endianness.py @@ -12,6 +12,7 @@ import os import tempfile import unittest +from pathlib import Path from typing import TYPE_CHECKING, List, Tuple from unittest.case import skipUnless @@ -85,6 +86,9 @@ def test_switch(self): # verify data types with self.assertRaises(NotImplementedError): switch_endianness(np.zeros((2, 1)), "=") + with self.assertRaises(RuntimeError): + switch_endianness(Path("test"), "<") + @skipUnless(has_pil, "Requires PIL") def test_pil(self): tempdir = tempfile.mkdtemp() diff --git a/tests/test_nifti_saver.py b/tests/test_nifti_saver.py index f48374a61c..c07084172f 100644 --- a/tests/test_nifti_saver.py +++ b/tests/test_nifti_saver.py @@ -12,6 +12,7 @@ import os import tempfile import unittest +from pathlib import Path import numpy as np import torch @@ -24,7 +25,7 @@ class TestNiftiSaver(unittest.TestCase): def test_saved_content(self): with tempfile.TemporaryDirectory() as tempdir: - saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz") + saver = NiftiSaver(output_dir=Path(tempdir), output_postfix="seg", output_ext=".nii.gz") meta_data = {"filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)]} saver.save_batch(torch.zeros(8, 1, 2, 2), meta_data) diff --git a/tests/test_png_saver.py b/tests/test_png_saver.py index dbc41dfd75..f8ea1df54b 100644 --- a/tests/test_png_saver.py +++ b/tests/test_png_saver.py @@ -12,6 +12,7 @@ import os import tempfile import unittest +from pathlib import Path import torch @@ -33,7 +34,7 @@ def test_saved_content(self): def test_saved_content_three_channel(self): with tempfile.TemporaryDirectory() as tempdir: - saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255) + saver = PNGSaver(output_dir=Path(tempdir), output_postfix="seg", output_ext=".png", scale=255) meta_data = {"filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)]} saver.save_batch(torch.randint(1, 200, (8, 3, 2, 2)), meta_data)