Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions monai/data/nifti_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import nibabel as nib

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data._utils.collate import np_str_obj_array_pattern
from monai.utils.module import export

from monai.data.utils import correct_nifti_header_if_necessary
from monai.transforms.compose import Randomizable
from monai.utils.module import export


def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dtype=None):
Expand All @@ -38,6 +39,7 @@ def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dty
"""

img = nib.load(filename_or_obj)
img = correct_nifti_header_if_necessary(img)

header = dict(img.header)
header['filename_or_obj'] = filename_or_obj
Expand Down
62 changes: 62 additions & 0 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import math
from itertools import starmap, product
from torch.utils.data._utils.collate import default_collate
Expand Down Expand Up @@ -191,3 +192,64 @@ def list_data_collate(batch):
elem = batch[0]
data = [i for k in batch for i in k] if isinstance(elem, list) else batch
return default_collate(data)


def correct_nifti_header_if_necessary(img_nii):
"""
check nifti object header's format, update the header if needed.
in the updated image pixdim matches the affine.

Args:
img (nifti image object)
"""
dim = img_nii.header['dim'][0]
if dim >= 5:
return img_nii # do nothing for high-dimensional array
# check that affine matches zooms
pixdim = np.asarray(img_nii.header.get_zooms())[:dim]
norm_affine = np.sqrt(np.sum(np.square(img_nii.affine[:dim, :dim]), 0))
if np.allclose(pixdim, norm_affine):
return img_nii
if hasattr(img_nii, 'get_sform'):
return rectify_header_sform_qform(img_nii)
return img_nii


def rectify_header_sform_qform(img_nii):
"""
Look at the sform and qform of the nifti object and correct it if any
incompatibilities with pixel dimensions

Adapted from https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/io/misc_io.py
"""
d = img_nii.header['dim'][0]
pixdim = np.asarray(img_nii.header.get_zooms())[:d]
sform, qform = img_nii.get_sform(), img_nii.get_qform()
norm_sform = np.sqrt(np.sum(np.square(sform[:d, :d]), 0))
norm_qform = np.sqrt(np.sum(np.square(qform[:d, :d]), 0))
sform_mismatch = not np.allclose(norm_sform, pixdim)
qform_mismatch = not np.allclose(norm_qform, pixdim)

if img_nii.header['sform_code'] != 0:
if not sform_mismatch:
return img_nii
if not qform_mismatch:
img_nii.set_sform(img_nii.get_qform())
return img_nii
if img_nii.header['qform_code'] != 0:
if not qform_mismatch:
return img_nii
if not sform_mismatch:
img_nii.set_qform(img_nii.get_sform())
return img_nii

norm_affine = np.sqrt(np.sum(np.square(img_nii.affine[:, :3]), 0))
to_divide = np.tile(np.expand_dims(np.append(norm_affine, 1), axis=1), [1, 4])
pixdim = np.append(pixdim, [1.] * (4 - len(pixdim)))
to_multiply = np.tile(np.expand_dims(pixdim, axis=1), [1, 4])
affine = img_nii.affine / to_divide.T * to_multiply.T
warnings.warn('Modifying image affine from {} to {}'.format(img_nii.affine, affine))

img_nii.set_sform(affine)
img_nii.set_qform(affine)
return img_nii
106 changes: 89 additions & 17 deletions monai/transforms/composables.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
import monai
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.transforms import LoadNifti, AsChannelFirst, AddChannel, Rotate90, SpatialCrop
from monai.transforms.transforms import (LoadNifti, AsChannelFirst, Orientation,
AddChannel, Spacing, Rotate90, SpatialCrop)
from monai.utils.misc import ensure_tuple
from monai.transforms.utils import generate_pos_neg_label_crop_centers
from monai.utils.aliases import alias

export = monai.utils.export("monai.transforms")


@export
class MapTransform(Transform):
"""
A subclass of ``monai.transforms.compose.Transform`` with an assumption
Expand Down Expand Up @@ -54,6 +57,85 @@ def __init__(self, keys):


@export
@alias('SpacingD', 'SpacingDict')
class Spacingd(MapTransform):
"""
dictionary-based wrapper of :class: `monai.transforms.transforms.Spacing`.
"""

def __init__(self, keys, affine_key, pixdim, interp_order=2, keep_shape=False, output_key='spacing'):
"""
Args:
affine_key (hashable): the key to the original affine.
The affine will be used to compute input data's pixdim.
pixdim (sequence of floats): output voxel spacing.
interp_order (int or sequence of ints): int: the same interpolation order
for all data indexed by `self,keys`; sequence of ints, should
correspond to an interpolation order for each data item indexed
by `self.keys` respectively.
keep_shape (bool): whether to maintain the original spatial shape
after resampling. Defaults to False.
output_key (hashable): key to be added to the output dictionary to track
the pixdim status.

"""
MapTransform.__init__(self, keys)
self.affine_key = affine_key
self.spacing_transform = Spacing(pixdim, keep_shape=keep_shape)
interp_order = ensure_tuple(interp_order)
self.interp_order = interp_order \
if len(interp_order) == len(self.keys) else interp_order * len(self.keys)
print(self.interp_order)
self.output_key = output_key

def __call__(self, data):
d = dict(data)
affine = d[self.affine_key]
original_pixdim, new_pixdim = None, None
for key, interp in zip(self.keys, self.interp_order):
d[key], original_pixdim, new_pixdim = self.spacing_transform(d[key], affine, interp_order=interp)
d[self.output_key] = {'original_pixdim': original_pixdim, 'current_pixdim': new_pixdim}
return d


@export
@alias('OrientationD', 'OrientationDict')
class Orientationd(MapTransform):
"""
dictionary-based wrapper of :class: `monai.transforms.transforms.Orientation`.
"""

def __init__(self, keys, affine_key, axcodes, labels=None, output_key='orientation'):
"""
Args:
affine_key (hashable): the key to the original affine.
The affine will be used to compute input data's orientation.
axcodes (N elements sequence): for spatial ND input's orientation.
e.g. axcodes='RAS' represents 3D orientation:
(Left, Right), (Posterior, Anterior), (Inferior, Superior).
default orientation labels options are: 'L' and 'R' for the first dimension,
'P' and 'A' for the second, 'I' and 'S' for the third.
labels : optional, None or sequence of (2,) sequences
(2,) sequences are labels for (beginning, end) of output axis.
see: ``nibabel.orientations.ornt2axcodes``.
"""
MapTransform.__init__(self, keys)
self.affine_key = affine_key
self.orientation_transform = Orientation(axcodes=axcodes, labels=labels)
self.output_key = output_key

def __call__(self, data):
d = dict(data)
affine = d[self.affine_key]
original_ornt, new_ornt = None, None
for key in self.keys:
d[key], original_ornt, new_ornt = self.orientation_transform(d[key], affine)
d[self.output_key] = {'original_ornt': original_ornt, 'current_ornt': new_ornt}
return d


@export
@alias('LoadNiftiD', 'LoadNiftiDict')
class LoadNiftid(MapTransform):
"""
dictionary-based wrapper of LoadNifti, must load image and metadata together.
Expand Down Expand Up @@ -92,6 +174,7 @@ def __call__(self, data):


@export
@alias('AsChannelFirstD', 'AsChannelFirstDict')
class AsChannelFirstd(MapTransform):
"""
dictionary-based wrapper of AsChannelFirst.
Expand All @@ -115,6 +198,7 @@ def __call__(self, data):


@export
@alias('AddChannelD', 'AddChannelDict')
class AddChanneld(MapTransform):
"""
dictionary-based wrapper of AddChannel.
Expand All @@ -137,6 +221,7 @@ def __call__(self, data):


@export
@alias('Rotate90D', 'Rotate90Dict')
class Rotate90d(MapTransform):
"""
dictionary-based wrapper of Rotate90.
Expand All @@ -162,6 +247,7 @@ def __call__(self, data):


@export
@alias('UniformRandomPatchD', 'UniformRandomPatchDict')
class UniformRandomPatchd(Randomizable, MapTransform):
"""
Selects a patch of the given size chosen at a uniformly random position in the image.
Expand Down Expand Up @@ -189,6 +275,7 @@ def __call__(self, data):


@export
@alias('RandRotate90D', 'RandRotate90Dict')
class RandRotate90d(Randomizable, MapTransform):
"""
With probability `prob`, input arrays are rotated by 90 degrees
Expand Down Expand Up @@ -233,6 +320,7 @@ def __call__(self, data):


@export
@alias('RandCropByPosNegLabelD', 'RandCropByPosNegLabelDict')
class RandCropByPosNegLabeld(Randomizable, MapTransform):
"""
Crop random fixed sized regions with the center being a foreground or background voxel
Expand Down Expand Up @@ -285,19 +373,3 @@ def __call__(self, data):
results[i][key] = data[key]

return results


# if __name__ == "__main__":
# import numpy as np
# data = {
# 'img': np.array((1, 2, 3, 4)).reshape((1, 2, 2)),
# 'seg': np.array((1, 2, 3, 4)).reshape((1, 2, 2)),
# 'affine': 3,
# 'dtype': 4,
# 'unused': 5,
# }
# rotator = RandRotate90d(keys=['img', 'seg'], prob=0.8)
# # rotator.set_random_state(1234)
# data_result = rotator(data)
# print(data_result.keys())
# print(data_result['img'], data_result['seg'])
Loading