Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
2622402
collate , decollate, dataset, dataloader, out=
rijobro Apr 14, 2022
19e68c9
mypy
rijobro Apr 14, 2022
d017918
skip decollation for pytorch 1.7
rijobro Apr 14, 2022
b36cd10
fix
rijobro Apr 14, 2022
a8f0373
fix
rijobro Apr 14, 2022
c6ae90d
Merge branch 'dev' into MetaTensor_collate_decollate
rijobro Apr 19, 2022
12afd4a
add batch index testing
rijobro Apr 20, 2022
f7a583e
Merge remote-tracking branch 'MONAI/dev' into MetaTensor_collate_deco…
rijobro Apr 20, 2022
fb9b10f
fixes
rijobro Apr 20, 2022
4875784
fix
rijobro Apr 20, 2022
b307d46
fix
rijobro Apr 20, 2022
e40553a
fix
rijobro Apr 20, 2022
f2c2548
fix
rijobro Apr 20, 2022
f9fd14a
load image meta tensor
rijobro Apr 20, 2022
9a4bbb4
Merge branch 'MetaTensor_collate_decollate' into MetaTensorLoadImage
rijobro Apr 20, 2022
b227fdd
splitdims fix
rijobro Apr 20, 2022
603cdb5
flake8
rijobro Apr 20, 2022
226b104
Merge branch 'MetaTensor' into MetaTensorLoadImage
rijobro Apr 20, 2022
91eff3a
fix test_nifti_rw
rijobro Apr 21, 2022
0f548c6
test_smartcachedataset
rijobro Apr 21, 2022
7e3b50c
test fixes
rijobro Apr 21, 2022
549c2ee
test fixes
rijobro Apr 21, 2022
3563bc3
Merge remote-tracking branch 'MONAI/dev' into MetaTensorLoadImage
rijobro Apr 21, 2022
55f483a
fix
rijobro Apr 21, 2022
ef01504
fixes
rijobro Apr 21, 2022
752ed8b
fix
rijobro Apr 21, 2022
76c6459
fixes
rijobro Apr 21, 2022
5df8709
Merge remote-tracking branch 'MONAI/dev' into MetaTensorLoadImage
rijobro Apr 22, 2022
b1e19eb
Merge branch 'MetaTensor' into MetaTensorLoadImage
wyli Apr 27, 2022
ad1d780
fix wsi test
rijobro Apr 29, 2022
9e7b3d6
changes after code review
rijobro Apr 29, 2022
8cb4634
Merge branch 'MetaTensor' into MetaTensorLoadImage
wyli May 2, 2022
ad58d79
fixes unit tests
wyli May 3, 2022
2fb16f9
fixes: TypeError: __str__ returned non-string (type list)
wyli May 3, 2022
c74f0b3
fixes unit test tests.test_lr_finder
wyli May 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion monai/apps/deepgrow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@

import numpy as np

from monai.transforms import AsChannelFirstd, Compose, LoadImaged, Orientationd, Spacingd
from monai.transforms import AsChannelFirstd, Compose, FromMetaTensord, LoadImaged, Orientationd, Spacingd, ToNumpyd
from monai.utils import GridSampleMode
from monai.utils.enums import PostFix


def create_dataset(
Expand Down Expand Up @@ -125,6 +126,8 @@ def _default_transforms(image_key, label_key, pixdim):
return Compose(
[
LoadImaged(keys=keys),
FromMetaTensord(keys=keys),
ToNumpyd(keys=keys + [PostFix.meta(k) for k in keys]),
AsChannelFirstd(keys=keys),
Orientationd(keys=keys, axcodes="RAS"),
Spacingd(keys=keys, pixdim=pixdim, mode=mode),
Expand Down
11 changes: 6 additions & 5 deletions monai/data/dataset_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from monai.config import KeysCollection
from monai.data.dataloader import DataLoader
from monai.data.dataset import Dataset
from monai.data.utils import affine_to_spacing
from monai.transforms import concatenate
from monai.utils import convert_data_type
from monai.utils.enums import PostFix
from monai.utils import PostFix, convert_data_type

DEFAULT_POST_FIX = PostFix.meta()

Expand Down Expand Up @@ -84,7 +84,7 @@ def collect_meta_data(self):
raise ValueError(f"To collect meta data for the dataset, key `{self.meta_key}` must exist in `data`.")
self.all_meta_data.append(data[self.meta_key])

def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: int = 3, percentile: float = 10.0):
def get_target_spacing(self, spacing_key: str = "affine", anisotropic_threshold: int = 3, percentile: float = 10.0):
"""
Calculate the target spacing according to all spacings.
If the target spacing is very anisotropic,
Expand All @@ -93,7 +93,7 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold:
After loading with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`.

Args:
spacing_key: key of spacing in meta data (default: ``pixdim``).
spacing_key: key of the affine used to compute spacing in meta data (default: ``affine``).
anisotropic_threshold: threshold to decide if the target spacing is anisotropic (default: ``3``).
percentile: for anisotropic target spacing, use the percentile of all spacings of the anisotropic axis to
replace that axis.
Expand All @@ -103,7 +103,8 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold:
self.collect_meta_data()
if spacing_key not in self.all_meta_data[0]:
raise ValueError("The provided spacing_key is not in self.all_meta_data.")
all_spacings = concatenate(to_cat=[data[spacing_key][:, 1:4] for data in self.all_meta_data], axis=0)
spacings = [affine_to_spacing(data[spacing_key][0], 3)[None] for data in self.all_meta_data]
all_spacings = concatenate(to_cat=spacings, axis=0)
all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True)

target_spacing = np.median(all_spacings, axis=0)
Expand Down
17 changes: 8 additions & 9 deletions monai/data/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
raise ValueError("transform_with_metadata=True requires image_only=False.")
self.image_only = image_only
self.transform_with_metadata = transform_with_metadata
self.loader = LoadImage(reader, image_only, dtype, *args, **kwargs)
self.loader = LoadImage(reader, dtype, *args, **kwargs)
self.set_random_state(seed=get_seed())
self._seed = 0 # transform synchronization seed

Expand All @@ -101,14 +101,13 @@ def __getitem__(self, index: int):
meta_data, seg_meta_data, seg, label = None, None, None, None

# load data and optionally meta
if self.image_only:
img = self.loader(self.image_files[index])
if self.seg_files is not None:
seg = self.loader(self.seg_files[index])
else:
img, meta_data = self.loader(self.image_files[index])
if self.seg_files is not None:
seg, seg_meta_data = self.loader(self.seg_files[index])
img = self.loader(self.image_files[index])
if not self.image_only:
meta_data = img.meta
if self.seg_files is not None:
seg = self.loader(self.seg_files[index])
if not self.image_only:
seg_meta_data = seg.meta

# apply the transforms
if self.transform is not None:
Expand Down
6 changes: 2 additions & 4 deletions monai/data/meta_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Call
Returns:
Returns `None`, but `self` should be updated to have the copied attribute.
"""
attributes = [getattr(i, attribute) for i in input_objs]
attributes = [getattr(i, attribute) for i in input_objs if hasattr(i, attribute)]
if len(attributes) > 0:
val = attributes[0]
if deep_copy:
Expand Down Expand Up @@ -189,9 +189,7 @@ def get_default_meta(self) -> dict:

def __repr__(self) -> str:
"""String representation of class."""
out: str = super().__repr__()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this line should be reverted, I think the output should be the same as torch.Tensor, but with the extra metadata (now we have metadata instead of voxel data). @wyli what do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, please feel free to change, the corresponding test case is this line:

self.assertTrue(str(out).startswith("\nMetaData"))


out += "\nMetaData\n"
out: str = "\nMetaData\n"
if self.meta is not None:
out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items())
else:
Expand Down
36 changes: 34 additions & 2 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

import torch

from monai.config.type_definitions import NdarrayTensor
from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms
from monai.data.utils import decollate_batch, list_data_collate
from monai.transforms.utils import remove_extra_metadata
from monai.utils.enums import PostFix

__all__ = ["MetaTensor"]
Expand Down Expand Up @@ -186,8 +188,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
kwargs = {}
ret = super().__torch_function__(func, types, args, kwargs)
# if `out` has been used as argument, metadata is not copied, nothing to do.
if "out" in kwargs:
return ret
# if "out" in kwargs:
# return ret
# we might have 1 or multiple outputs. Might be MetaTensor, might be something
# else (e.g., `__repr__` returns a string).
# Convert to list (if necessary), process, and at end remove list if one was added.
Expand Down Expand Up @@ -232,3 +234,33 @@ def affine(self) -> torch.Tensor:
def affine(self, d: torch.Tensor) -> None:
"""Set the affine."""
self.meta["affine"] = d

@staticmethod
def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict):
"""
Convert the image to `torch.Tensor`. If `affine` is in the `meta` dictionary,
convert that to `torch.Tensor`, too. Remove any superfluous metadata.

Args:
im: Input image (`np.ndarray` or `torch.Tensor`)
meta: Metadata dictionary.

Returns:
By default, a `MetaTensor` is returned.
However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned.
"""
img = torch.as_tensor(im)

# if not tracking metadata, return `torch.Tensor`
if not get_track_meta() or meta is None:
return img

# ensure affine is of type `torch.Tensor`
if "affine" in meta:
meta["affine"] = torch.as_tensor(meta["affine"])

# remove any superfluous metadata.
remove_extra_metadata(meta)

# return the `MetaTensor`
return MetaTensor(img, meta=meta)
2 changes: 2 additions & 0 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,8 @@ def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_z
Returns:
an `r` dimensional vector of spacing.
"""
if len(affine.shape) != 2 or affine.shape[0] != affine.shape[1]:
raise ValueError(f"affine must be a square matrix, got {affine.shape}.")
_affine, *_ = convert_to_dst_type(affine[:r, :r], dst=affine, dtype=dtype)
if isinstance(_affine, torch.Tensor):
spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0))
Expand Down
3 changes: 3 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@
generate_label_classes_crop_centers,
generate_pos_neg_label_crop_centers,
generate_spatial_bounding_box,
get_extra_metadata_keys,
get_extreme_points,
get_largest_connected_component_mask,
get_number_image_type_conversions,
Expand All @@ -582,6 +583,8 @@
map_spatial_axes,
print_transform_backends,
rand_choice,
remove_extra_metadata,
remove_keys,
rescale_array,
rescale_array_int_max,
rescale_instance_array,
Expand Down
37 changes: 19 additions & 18 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,19 @@
from monai.data import image_writer
from monai.data.folder_layout import FolderLayout
from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader
from monai.data.meta_tensor import MetaTensor
from monai.transforms.transform import Transform
from monai.transforms.utility.array import EnsureChannelFirst
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import ImageMetaKey as Key
from monai.utils import InterpolateMode, OptionalImportError, ensure_tuple, look_up_option, optional_import
from monai.utils import (
InterpolateMode,
OptionalImportError,
deprecated_arg,
ensure_tuple,
look_up_option,
optional_import,
)

nib, _ = optional_import("nibabel")
Image, _ = optional_import("PIL.Image")
Expand Down Expand Up @@ -93,14 +101,11 @@ class LoadImage(Transform):

"""

@deprecated_arg(
name="image_only", since="0.8", msg_suffix="If necessary, please extract meta data with `MetaTensor.meta`"
)
def __init__(
self,
reader=None,
image_only: bool = False,
dtype: DtypeLike = np.float32,
ensure_channel_first: bool = False,
*args,
**kwargs,
self, reader=None, dtype: DtypeLike = np.float32, ensure_channel_first: bool = False, *args, **kwargs
) -> None:
"""
Args:
Expand All @@ -111,7 +116,6 @@ def __init__(
``"ITKReader"``, ``"NibabelReader"``, ``"NumpyReader"``.
a reader instance will be constructed with the `*args` and `**kwargs` parameters.
- if `reader` is a reader class/instance, it will be registered to this loader accordingly.
image_only: if True return only the image volume, otherwise return image data array and header dict.
dtype: if not None convert the loaded image to this data type.
ensure_channel_first: if `True` and loaded both image array and meta data, automatically convert
the image array shape to `channel first`. default to `False`.
Expand All @@ -120,8 +124,8 @@ def __init__(

Note:

- The transform returns an image data array if `image_only` is True,
or a tuple of two elements containing the data array, and the meta data in a dictionary format otherwise.
- The transform returns a MetaTensor, unless `set_track_meta(False)` has been used, in which case, a
`torch.Tensor` will be returned.
- If `reader` is specified, the loader will attempt to use the specified readers and the default supported
readers. This might introduce overheads when handling the exceptions of trying the incompatible loaders.
In this case, it is therefore recommended setting the most appropriate reader as
Expand All @@ -130,7 +134,6 @@ def __init__(
"""

self.auto_select = reader is None
self.image_only = image_only
self.dtype = dtype
self.ensure_channel_first = ensure_channel_first

Expand Down Expand Up @@ -241,14 +244,12 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option
raise ValueError("`meta_data` must be a dict.")
# make sure all elements in metadata are little endian
meta_data = switch_endianness(meta_data, "<")
if self.ensure_channel_first:
img_array = EnsureChannelFirst()(img_array, meta_data)

if self.image_only:
return img_array
meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader

return img_array, meta_data
img = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data)
if self.ensure_channel_first:
img = EnsureChannelFirst()(img)
return img


class SaveImage(Transform):
Expand Down
44 changes: 8 additions & 36 deletions monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from monai.data.image_reader import ImageReader
from monai.transforms.io.array import LoadImage, SaveImage
from monai.transforms.transform import MapTransform
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, ensure_tuple, ensure_tuple_rep
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, deprecated_arg, ensure_tuple_rep
from monai.utils.enums import PostFix

__all__ = ["LoadImaged", "LoadImageD", "LoadImageDict", "SaveImaged", "SaveImageD", "SaveImageDict"]
Expand Down Expand Up @@ -64,6 +64,10 @@ class LoadImaged(MapTransform):

"""

@deprecated_arg(name="image_only", since="0.8")
@deprecated_arg(name="meta_keys", since="0.8")
@deprecated_arg(name="meta_key_postfix", since="0.8")
@deprecated_arg(name="overwriting", since="0.8")
def __init__(
self,
keys: KeysCollection,
Expand All @@ -90,32 +94,14 @@ def __init__(
a reader instance will be constructed with the `*args` and `**kwargs` parameters.
- if `reader` is a reader class/instance, it will be registered to this loader accordingly.
dtype: if not None, convert the loaded image data to this data type.
meta_keys: explicitly indicate the key to store the corresponding meta data dictionary.
the meta data is a dictionary object which contains: filename, original_shape, etc.
it can be a sequence of string, map to the `keys`.
if None, will try to construct meta_keys by `key_{meta_key_postfix}`.
meta_key_postfix: if meta_keys is None, use `key_{postfix}` to store the metadata of the nifti image,
default is `meta_dict`. The meta data is a dictionary object.
For example, load nifti file for `image`, store the metadata into `image_meta_dict`.
overwriting: whether allow overwriting existing meta data of same key.
default is False, which will raise exception if encountering existing key.
image_only: if True return dictionary containing just only the image volumes, otherwise return
dictionary containing image data array and header dict per input key.
ensure_channel_first: if `True` and loaded both image array and meta data, automatically convert
the image array shape to `channel first`. default to `False`.
allow_missing_keys: don't raise exception if key is missing.
args: additional parameters for reader if providing a reader name.
kwargs: additional parameters for reader if providing a reader name.
"""
super().__init__(keys, allow_missing_keys)
self._loader = LoadImage(reader, image_only, dtype, ensure_channel_first, *args, **kwargs)
if not isinstance(meta_key_postfix, str):
raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.")
self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
if len(self.keys) != len(self.meta_keys):
raise ValueError("meta_keys should have the same length as keys.")
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
self.overwriting = overwriting
self._loader = LoadImage(reader, dtype, ensure_channel_first, *args, **kwargs)

def register(self, reader: ImageReader):
self._loader.register(reader)
Expand All @@ -127,22 +113,8 @@ def __call__(self, data, reader: Optional[ImageReader] = None):

"""
d = dict(data)
for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
data = self._loader(d[key], reader)
if self._loader.image_only:
if not isinstance(data, np.ndarray):
raise ValueError("loader must return a numpy array (because image_only=True was used).")
d[key] = data
else:
if not isinstance(data, (tuple, list)):
raise ValueError("loader must return a tuple or list (because image_only=False was used).")
d[key] = data[0]
if not isinstance(data[1], dict):
raise ValueError("metadata must be a dict.")
meta_key = meta_key or f"{key}_{meta_key_postfix}"
if meta_key in d and not self.overwriting:
raise KeyError(f"Meta data with key {meta_key} already exists and overwriting=False.")
d[meta_key] = data[1]
for key in self.key_iterator(d):
d[key] = self._loader(d[key], reader)
return d


Expand Down
3 changes: 3 additions & 0 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from monai.config import DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_tensor import MetaTensor
from monai.transforms.transform import Randomizable, RandomizableTransform, Transform
from monai.transforms.utils import (
extreme_points_to_image,
Expand Down Expand Up @@ -210,6 +211,8 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) ->
"""
Apply the transform to `img`.
"""
if isinstance(img, MetaTensor):
meta_dict = img.meta
if not isinstance(meta_dict, Mapping):
msg = "meta_dict not available, EnsureChannelFirst is not in use."
if self.strict_check:
Expand Down
8 changes: 7 additions & 1 deletion monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,13 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
d[split_meta_key] = deepcopy(orig_meta)
dim = self.splitter.dim
if dim > 0: # don't update affine if channel dim
shift = np.eye(len(d[split_meta_key]["affine"])) # type: ignore
affine = d[split_meta_key]["affine"] # type: ignore
ndim = len(affine)
shift: NdarrayOrTensor
if isinstance(affine, torch.Tensor):
shift = torch.eye(ndim, device=affine.device, dtype=affine.dtype)
else:
shift = np.eye(ndim)
shift[dim - 1, -1] = i # type: ignore
d[split_meta_key]["affine"] = d[split_meta_key]["affine"] @ shift # type: ignore

Expand Down
Loading