Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
import torch

import monai
from monai.config.type_definitions import NdarrayTensor
from monai.data.meta_obj import MetaObj, get_track_meta
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
Expand Down Expand Up @@ -461,7 +462,9 @@ def clone(self):
return new_inst

@staticmethod
def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict, simple_keys: bool = False):
def ensure_torch_and_prune_meta(
im: NdarrayTensor, meta: dict, simple_keys: bool = False, pattern: str | None = None, sep: str = "."
):
"""
Convert the image to `torch.Tensor`. If `affine` is in the `meta` dictionary,
convert that to `torch.Tensor`, too. Remove any superfluous metadata.
Expand All @@ -470,6 +473,11 @@ def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict, simple_keys: bool
im: Input image (`np.ndarray` or `torch.Tensor`)
meta: Metadata dictionary.
simple_keys: whether to keep only a simple subset of metadata keys.
pattern: combined with `sep`, a regular expression used to match and prune keys
in the metadata (nested dictionary), default to None, no key deletion.
sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary).
default is ".", see also :py:class:`monai.transforms.DeleteItemsd`.
e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``.

Returns:
By default, a `MetaTensor` is returned.
Expand All @@ -488,6 +496,9 @@ def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict, simple_keys: bool
meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE]) # bc-breaking
remove_extra_metadata(meta) # bc-breaking

if pattern is not None:
meta = monai.transforms.DeleteItemsd(keys=pattern, sep=sep, use_re=True)(meta)

# return the `MetaTensor`
return MetaTensor(img, meta=meta)

Expand Down
13 changes: 12 additions & 1 deletion monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def __init__(
dtype: DtypeLike = np.float32,
ensure_channel_first: bool = False,
simple_keys: bool = False,
prune_meta_pattern: Optional[str] = None,
prune_meta_sep: str = ".",
*args,
**kwargs,
) -> None:
Expand All @@ -129,6 +131,11 @@ def __init__(
ensure_channel_first: if `True` and loaded both image array and metadata, automatically convert
the image array shape to `channel first`. default to `False`.
simple_keys: whether to remove redundant metadata keys, default to False for backward compatibility.
prune_meta_pattern: combined with `prune_meta_sep`, a regular expression used to match and prune keys
in the metadata (nested dictionary), default to None, no key deletion.
prune_meta_sep: combined with `prune_meta_pattern`, used to match and prune keys
in the metadata (nested dictionary). default is ".", see also :py:class:`monai.transforms.DeleteItemsd`.
e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``.
args: additional parameters for reader if providing a reader name.
kwargs: additional parameters for reader if providing a reader name.

Expand All @@ -148,6 +155,8 @@ def __init__(
self.dtype = dtype
self.ensure_channel_first = ensure_channel_first
self.simple_keys = simple_keys
self.pattern = prune_meta_pattern
self.sep = prune_meta_sep

self.readers: List[ImageReader] = []
for r in SUPPORTED_READERS: # set predefined readers as default
Expand Down Expand Up @@ -258,7 +267,9 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option
meta_data = switch_endianness(meta_data, "<")

meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader
img = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data, self.simple_keys)
img = MetaTensor.ensure_torch_and_prune_meta(
img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep
)
if self.ensure_channel_first:
img = EnsureChannelFirst()(img)
if self.image_only:
Expand Down
19 changes: 18 additions & 1 deletion monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def __init__(
image_only: bool = False,
ensure_channel_first: bool = False,
simple_keys: bool = False,
prune_meta_pattern: Optional[str] = None,
prune_meta_sep: str = ".",
allow_missing_keys: bool = False,
*args,
**kwargs,
Expand Down Expand Up @@ -105,12 +107,27 @@ def __init__(
ensure_channel_first: if `True` and loaded both image array and metadata, automatically convert
the image array shape to `channel first`. default to `False`.
simple_keys: whether to remove redundant metadata keys, default to False for backward compatibility.
prune_meta_pattern: combined with `prune_meta_sep`, a regular expression used to match and prune keys
in the metadata (nested dictionary), default to None, no key deletion.
prune_meta_sep: combined with `prune_meta_pattern`, used to match and prune keys
in the metadata (nested dictionary). default is ".", see also :py:class:`monai.transforms.DeleteItemsd`.
e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``.
allow_missing_keys: don't raise exception if key is missing.
args: additional parameters for reader if providing a reader name.
kwargs: additional parameters for reader if providing a reader name.
"""
super().__init__(keys, allow_missing_keys)
self._loader = LoadImage(reader, image_only, dtype, ensure_channel_first, simple_keys, *args, **kwargs)
self._loader = LoadImage(
reader,
image_only,
dtype,
ensure_channel_first,
simple_keys,
prune_meta_pattern,
prune_meta_sep,
*args,
**kwargs,
)
if not isinstance(meta_key_postfix, str):
raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.")
self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
Expand Down
4 changes: 2 additions & 2 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ def __init__(self, keys: KeysCollection, sep: str = ".", use_re: Union[Sequence[
See also: :py:class:`monai.transforms.compose.MapTransform`
sep: the separator tag to define nested dictionary keys, default to ".".
use_re: whether the specified key is a regular expression, it also can be
a list of bool values, map the to keys.
a list of bool values, mapping them to `keys`.
"""
super().__init__(keys)
self.sep = sep
Expand All @@ -730,7 +730,7 @@ def _delete_item(keys, d, use_re: bool = False):
if len(keys) > 1:
d[key] = _delete_item(keys[1:], d[key], use_re)
return d
return {k: v for k, v in d.items() if (use_re and not re.search(key, k)) or (not use_re and k != key)}
return {k: v for k, v in d.items() if (use_re and not re.search(key, f"{k}")) or (not use_re and k != key)}

d = dict(data)
for key, use_re in zip(self.keys, self.use_re):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_load_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,13 @@ def tearDownClass(cls):
@parameterized.expand(TESTS_META)
def test_correct(self, input_param, expected_shape, track_meta):
set_track_meta(track_meta)
r = LoadImage(image_only=True, **input_param)(self.test_data)
r = LoadImage(image_only=True, prune_meta_pattern="glmax", prune_meta_sep="%", **input_param)(self.test_data)
self.assertTupleEqual(r.shape, expected_shape)
if track_meta:
self.assertIsInstance(r, MetaTensor)
self.assertTrue(hasattr(r, "affine"))
self.assertIsInstance(r.affine, torch.Tensor)
self.assertTrue("glmax" not in r.meta)
else:
self.assertIsInstance(r, torch.Tensor)
self.assertNotIsInstance(r, MetaTensor)
Expand Down
7 changes: 5 additions & 2 deletions tests/test_load_imaged.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,11 @@ def tearDownClass(cls):
super(__class__, cls).tearDownClass()

@parameterized.expand(TESTS_META)
def test_correct(self, input_param, expected_shape, track_meta):
def test_correct(self, input_p, expected_shape, track_meta):
set_track_meta(track_meta)
result = LoadImaged(image_only=True, **input_param)(self.test_data)
result = LoadImaged(image_only=True, prune_meta_pattern=".*_code$", prune_meta_sep=" ", **input_p)(
self.test_data
)

# shouldn't have any extra meta data keys
self.assertEqual(len(result), len(KEYS))
Expand All @@ -178,6 +180,7 @@ def test_correct(self, input_param, expected_shape, track_meta):
self.assertTrue(hasattr(r, "affine"))
self.assertIsInstance(r.affine, torch.Tensor)
self.assertEqual(r.meta["space"], "RAS")
self.assertTrue("qform_code" not in r.meta)
else:
self.assertIsInstance(r, torch.Tensor)
self.assertNotIsInstance(r, MetaTensor)
Expand Down