Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3495958
add swin_unetr model (#4074)
ahatamiz May 3, 2022
06ccf35
4217 Update PyTorch docker to 22.04 (#4218)
Nic-Ma May 5, 2022
2cf927c
Add InstanceNorm3dNVFuser support (#4194)
yiheng-wang-nv May 5, 2022
c9aee92
Update dice.py (#4234)
ryancinsight May 6, 2022
84c9c15
Bug fix and improvement in WSI (#4216)
bhashemian May 6, 2022
63ce977
Replace module (#4245)
rijobro May 9, 2022
a0b4100
Add GaussianSmooth as antialiasing filter in Resize (#4249)
Can-Zhao May 10, 2022
1cdff36
4235 fix 2204 nvfuser issue (#4241)
yiheng-wang-nv May 11, 2022
591931f
Update to Bundle Specifiation (#4250)
ericspod May 11, 2022
90e2ac9
Implement NrrdReader and NrrdImage classes
May 11, 2022
9535e36
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2022
f4139f2
run auto style fixes on image_reader.py
kbressem May 12, 2022
6a7b9e8
add NrrdReader to monai/data/__init__.py
kbressem May 12, 2022
59aed3e
Change the way spatial information is handled in NrrdReader
kbressem May 12, 2022
9c59784
add tests for NrrdReader
kbressem May 12, 2022
76329cb
Add NrrdReader to list of possible readers for LoadImage
kbressem May 12, 2022
46169c0
autofix formating
kbressem May 12, 2022
01e495a
autofix formating
kbressem May 12, 2022
c051993
change NrrdImage class to namedtuple and make flake8 happy
kbressem May 12, 2022
b7c3efb
Add pynrrd to requirements
kbressem May 13, 2022
03b59ea
correct typing for namedtumple
kbressem May 13, 2022
7f8a4c1
Add pynrrd info to `get_optional_config_values`
kbressem May 13, 2022
8f46bc5
Merge branch 'dev' into 4238-nrrd-reader
wyli May 13, 2022
b885c23
exclude test_nrrd_reader.py from min tests
kbressem May 13, 2022
6ad883b
Merge branch '4238-nrrd-reader' of http://github.com/kbressem/MONAI i…
kbressem May 13, 2022
7294d34
add pynrrd to config files
kbressem May 13, 2022
a5aa7e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 13, 2022
4330fd8
Change the way space is handled in the header. Now, if space is not i…
kbressem May 13, 2022
ce9c883
add `TestLoadSaveNrrd` where it is tested if a nrrd file, created by …
kbressem May 13, 2022
7417d89
autofix format
kbressem May 13, 2022
544394d
Merge branch '4238-nrrd-reader' of http://github.com/kbressem/MONAI i…
kbressem May 13, 2022
37a3f11
Merge branch 'dev' into 4238-nrrd-reader
wyli May 13, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ tifffile; platform_system == "Linux"
pyyaml
fire
jsonschema
pynrrd
4 changes: 2 additions & 2 deletions docs/source/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is

- The options are
```
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema]
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, pynrrd]
```
which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, respectively.
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `pynrrd`, respectively.

- `pip install 'monai[all]'` installs all the optional dependencies.
1 change: 1 addition & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies:
- pyyaml
- fire
- jsonschema
- pynrrd
- pip
- pip:
# pip for itk as conda-forge version only up to v5.1
Expand Down
1 change: 1 addition & 0 deletions monai/config/deviceconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def get_optional_config_values():
output["einops"] = get_package_version("einops")
output["transformers"] = get_package_version("transformers")
output["mlflow"] = get_package_version("mlflow")
output["pynrrd"] = get_package_version("nrrd")

return output

Expand Down
2 changes: 1 addition & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .folder_layout import FolderLayout
from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd
from .image_dataset import ImageDataset
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader
from .image_reader import ImageReader, ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader
from .image_writer import (
SUPPORTED_WRITERS,
ImageWriter,
Expand Down
162 changes: 160 additions & 2 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

Expand All @@ -25,21 +26,23 @@
if TYPE_CHECKING:
import itk
import nibabel as nib
import nrrd
from nibabel.nifti1 import Nifti1Image
from PIL import Image as PILImage

has_itk = has_nib = has_pil = True
has_nrrd = has_itk = has_nib = has_pil = True
else:
itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
nib, has_nib = optional_import("nibabel")
Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image")
PILImage, has_pil = optional_import("PIL.Image")
nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True)

OpenSlide, _ = optional_import("openslide", name="OpenSlide")
CuImage, _ = optional_import("cucim", name="CuImage")
TiffFile, _ = optional_import("tifffile", name="TiffFile")

__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "WSIReader"]
__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "WSIReader", "NrrdReader"]


class ImageReader(ABC):
Expand Down Expand Up @@ -976,3 +979,158 @@ def _extract_patches(
idx += 1

return flat_patch_grid


@dataclass
class NrrdImage:
"""Class to wrap nrrd image array and metadata header"""

array: np.ndarray
header: dict


@require_pkg(pkg_name="nrrd")
class NrrdReader(ImageReader):
"""
Load NRRD format images based on pynrrd library.

Args:
channel_dim: the channel dimension of the input image, default is None.
This is used to set original_channel_dim in the meta data, EnsureChannelFirstD reads this field.
If None, `original_channel_dim` will be either `no_channel` or `0`.
NRRD files are usually "channel first".
dtype: dtype of the data array when loading image.
index_order: Specify whether the returned data array should be in C-order (‘C’) or Fortran-order (‘F’).
Numpy is usually in C-order, but default on the NRRD header is F
kwargs: additional args for `nrrd.read` API. more details about available args:
https://github.com/mhe/pynrrd/blob/master/nrrd/reader.py

"""

def __init__(
self,
channel_dim: Optional[int] = None,
dtype: Union[np.dtype, type, str, None] = np.float32,
index_order: str = "F",
**kwargs,
):
self.channel_dim = channel_dim
self.dtype = dtype
self.index_order = index_order
self.kwargs = kwargs

def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool:
"""
Verify whether the specified `filename` is supported by pynrrd reader.

Args:
filename: file name or a list of file names to read.
if a list of files, verify all the suffixes.

"""
suffixes: Sequence[str] = ["nrrd", "seg.nrrd"]
return has_nrrd and is_supported_format(filename, suffixes)

def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Sequence[Any], Any]:
"""
Read image data from specified file or files.
Note that it returns a data object or a sequence of data objects.

Args:
data: file name or a list of file names to read.
kwargs: additional args for actual `read` API of 3rd party libs.

"""
img_: List = []
filenames: Sequence[PathLike] = ensure_tuple(data)
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, *kwargs_))
img_.append(nrrd_image)
return img_ if len(filenames) > 1 else img_[0]

def get_data(self, img: Union[NrrdImage, List[NrrdImage]]) -> Tuple[np.ndarray, Dict]:
"""
Extract data array and meta data from loaded image and return them.
This function must return two objects, the first is a numpy array of image data,
the second is a dictionary of meta data.

Args:
img: a `NrrdImage` loaded from an image file or a list of image objects.

"""
img_array: List[np.ndarray] = []
compatible_meta: Dict = {}

for i in ensure_tuple(img):
data = i.array.astype(self.dtype)
img_array.append(data)
header = dict(i.header)
if self.index_order == "C":
header = self._convert_f_to_c_order(header)
header["original_affine"] = self._get_affine(i)
header = self._switch_lps_ras(header)
header["affine"] = header["original_affine"].copy()
header["spatial_shape"] = header["sizes"]
[header.pop(k) for k in ("sizes", "space origin", "space directions")] # rm duplicated data in header

if self.channel_dim is None: # default to "no_channel" or -1
header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else 0
else:
header["original_channel_dim"] = self.channel_dim
_copy_compatible_dict(header, compatible_meta)

return _stack_images(img_array, compatible_meta), compatible_meta

def _get_affine(self, img: NrrdImage) -> np.ndarray:
"""
Get the affine matrix of the image, it can be used to correct
spacing, orientation or execute spatial transforms.

Args:
img: A `NrrdImage` loaded from image file

"""
direction = img.header["space directions"]
origin = img.header["space origin"]

x, y = direction.shape
affine_diam = min(x, y) + 1
affine: np.ndarray = np.eye(affine_diam)
affine[:x, :y] = direction
affine[: (affine_diam - 1), -1] = origin # len origin is always affine_diam - 1
return affine

def _switch_lps_ras(self, header: dict) -> dict:
"""
For compatibility with nibabel, switch from LPS to RAS. Adapt affine matrix and
`space` argument in header accordingly. If no information of space is given in the header,
LPS is assumed and thus converted to RAS. If information about space is given,
but is not LPS, the unchanged header is returned.

Args:
header: The image meta data as dict

"""
if "space" not in header or header["space"] == "left-posterior-superior":
header["space"] = "right-anterior-superior"
header["original_affine"] = orientation_ras_lps(header["original_affine"])
return header

def _convert_f_to_c_order(self, header: dict) -> dict:
"""
All header fields of a NRRD are specified in `F` (Fortran) order, even if the image was read as C-ordered array.
1D arrays of header['space origin'] and header['sizes'] become inverted, e.g, [1,2,3] -> [3,2,1]
The 2D Array for header['space directions'] is transposed: [[1,0,0],[0,2,0],[0,0,3]] -> [[3,0,0],[0,2,0],[0,0,1]]
For more details refer to: https://pynrrd.readthedocs.io/en/latest/user-guide.html#index-ordering

Args:
header: The image meta data as dict

"""

header["space directions"] = np.rot90(np.flip(header["space directions"], 0))
header["space origin"] = header["space origin"][::-1]
header["sizes"] = header["sizes"][::-1]
return header
6 changes: 4 additions & 2 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from monai.config import DtypeLike, NdarrayOrTensor, PathLike
from monai.data import image_writer
from monai.data.folder_layout import FolderLayout
from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader
from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader
from monai.transforms.transform import Transform
from monai.transforms.utility.array import EnsureChannelFirst
from monai.utils import GridSampleMode, GridSamplePadMode
Expand All @@ -37,11 +37,13 @@

nib, _ = optional_import("nibabel")
Image, _ = optional_import("PIL.Image")
nrrd, _ = optional_import("nrrd")

__all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"]

SUPPORTED_READERS = {
"itkreader": ITKReader,
"nrrdreader": NrrdReader,
"numpyreader": NumpyReader,
"pilreader": PILReader,
"nibabelreader": NibabelReader,
Expand Down Expand Up @@ -85,7 +87,7 @@ class LoadImage(Transform):
- User-specified reader in the constructor of `LoadImage`.
- Readers from the last to the first in the registered list.
- Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader),
(npz, npy -> NumpyReader), (DICOM file -> ITKReader).
(npz, npy -> NumpyReader), (nrrd -> NrrdReader), (DICOM file -> ITKReader).

See also:

Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,4 @@ types-PyYAML
pyyaml
fire
jsonschema
pynrrd
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ all =
pyyaml
fire
jsonschema
pynrrd
nibabel =
nibabel
skimage =
Expand Down Expand Up @@ -101,6 +102,8 @@ fire =
fire
jsonschema =
jsonschema
pynrrd =
pynrrd

[flake8]
select = B,C,E,F,N,P,T4,W,B9
Expand Down
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def run_testsuit():
"test_nifti_header_revise",
"test_nifti_rw",
"test_nifti_saver",
"test_nrrd_reader",
"test_occlusion_sensitivity",
"test_orientation",
"test_orientationd",
Expand Down
35 changes: 34 additions & 1 deletion tests/test_image_rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
from parameterized import parameterized

from monai.data.image_reader import ITKReader, NibabelReader, PILReader
from monai.data.image_reader import ITKReader, NibabelReader, NrrdReader, PILReader
from monai.data.image_writer import ITKWriter, NibabelWriter, PILWriter, register_writer, resolve_writer
from monai.transforms import LoadImage, SaveImage, moveaxis
from monai.utils import OptionalImportError
Expand Down Expand Up @@ -132,5 +132,38 @@ def test_1_new(self):
self.assertEqual(resolve_writer("new")[0](0), 1)


class TestLoadSaveNrrd(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()

def tearDown(self):
shutil.rmtree(self.test_dir, ignore_errors=True)

def nrrd_rw(self, test_data, reader, writer, dtype, resample=True):
test_data = test_data.astype(dtype)
ndim = len(test_data.shape)
for p in TEST_NDARRAYS:
output_ext = ".nrrd"
filepath = f"testfile_{ndim}d"
saver = SaveImage(
output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer
)
saver(p(test_data), {"filename_or_obj": f"{filepath}{output_ext}", "spatial_shape": test_data.shape})
saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext)
loader = LoadImage(reader=reader)
data, meta = loader(saved_path)
assert_allclose(data, test_data)

@parameterized.expand(itertools.product([NrrdReader, ITKReader], [ITKWriter, ITKWriter]))
def test_2d(self, reader, writer):
test_data = np.random.randn(8, 8).astype(np.float32)
self.nrrd_rw(test_data, reader, writer, np.float32)

@parameterized.expand(itertools.product([NrrdReader, ITKReader], [ITKWriter, ITKWriter]))
def test_3d(self, reader, writer):
test_data = np.random.randn(8, 8, 8).astype(np.float32)
self.nrrd_rw(test_data, reader, writer, np.float32)


if __name__ == "__main__":
unittest.main()
8 changes: 6 additions & 2 deletions tests/test_init_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import unittest

from monai.data import ITKReader, NibabelReader, NumpyReader, PILReader
from monai.data import ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader
from monai.transforms import LoadImage, LoadImaged
from tests.utils import SkipIfNoModule

Expand All @@ -23,13 +23,14 @@ def test_load_image(self):
self.assertIsInstance(instance1, LoadImage)
self.assertIsInstance(instance2, LoadImage)

for r in ["NibabelReader", "PILReader", "ITKReader", "NumpyReader", None]:
for r in ["NibabelReader", "PILReader", "ITKReader", "NumpyReader", "NrrdReader", None]:
inst = LoadImaged("image", reader=r)
self.assertIsInstance(inst, LoadImaged)

@SkipIfNoModule("itk")
@SkipIfNoModule("nibabel")
@SkipIfNoModule("PIL")
@SkipIfNoModule("nrrd")
def test_readers(self):
inst = ITKReader()
self.assertIsInstance(inst, ITKReader)
Expand All @@ -47,6 +48,9 @@ def test_readers(self):
inst = PILReader()
self.assertIsInstance(inst, PILReader)

inst = NrrdReader()
self.assertIsInstance(inst, NrrdReader)


if __name__ == "__main__":
unittest.main()
Loading