diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 0b6a4b7b19..5911e218ee 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -18,6 +18,7 @@ import numpy as np import torch +import monai from monai.config.type_definitions import NdarrayTensor from monai.data.meta_obj import MetaObj, get_track_meta from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata @@ -461,7 +462,9 @@ def clone(self): return new_inst @staticmethod - def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict, simple_keys: bool = False): + def ensure_torch_and_prune_meta( + im: NdarrayTensor, meta: dict, simple_keys: bool = False, pattern: str | None = None, sep: str = "." + ): """ Convert the image to `torch.Tensor`. If `affine` is in the `meta` dictionary, convert that to `torch.Tensor`, too. Remove any superfluous metadata. @@ -470,6 +473,11 @@ def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict, simple_keys: bool im: Input image (`np.ndarray` or `torch.Tensor`) meta: Metadata dictionary. simple_keys: whether to keep only a simple subset of metadata keys. + pattern: combined with `sep`, a regular expression used to match and prune keys + in the metadata (nested dictionary), default to None, no key deletion. + sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary). + default is ".", see also :py:class:`monai.transforms.DeleteItemsd`. + e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``. Returns: By default, a `MetaTensor` is returned. @@ -488,6 +496,9 @@ def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict, simple_keys: bool meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE]) # bc-breaking remove_extra_metadata(meta) # bc-breaking + if pattern is not None: + meta = monai.transforms.DeleteItemsd(keys=pattern, sep=sep, use_re=True)(meta) + # return the `MetaTensor` return MetaTensor(img, meta=meta) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 8dd849d33e..dc43475b63 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -112,6 +112,8 @@ def __init__( dtype: DtypeLike = np.float32, ensure_channel_first: bool = False, simple_keys: bool = False, + prune_meta_pattern: Optional[str] = None, + prune_meta_sep: str = ".", *args, **kwargs, ) -> None: @@ -129,6 +131,11 @@ def __init__( ensure_channel_first: if `True` and loaded both image array and metadata, automatically convert the image array shape to `channel first`. default to `False`. simple_keys: whether to remove redundant metadata keys, default to False for backward compatibility. + prune_meta_pattern: combined with `prune_meta_sep`, a regular expression used to match and prune keys + in the metadata (nested dictionary), default to None, no key deletion. + prune_meta_sep: combined with `prune_meta_pattern`, used to match and prune keys + in the metadata (nested dictionary). default is ".", see also :py:class:`monai.transforms.DeleteItemsd`. + e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. @@ -148,6 +155,8 @@ def __init__( self.dtype = dtype self.ensure_channel_first = ensure_channel_first self.simple_keys = simple_keys + self.pattern = prune_meta_pattern + self.sep = prune_meta_sep self.readers: List[ImageReader] = [] for r in SUPPORTED_READERS: # set predefined readers as default @@ -258,7 +267,9 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option meta_data = switch_endianness(meta_data, "<") meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader - img = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data, self.simple_keys) + img = MetaTensor.ensure_torch_and_prune_meta( + img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep + ) if self.ensure_channel_first: img = EnsureChannelFirst()(img) if self.image_only: diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 42918f5e63..761c891e85 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -75,6 +75,8 @@ def __init__( image_only: bool = False, ensure_channel_first: bool = False, simple_keys: bool = False, + prune_meta_pattern: Optional[str] = None, + prune_meta_sep: str = ".", allow_missing_keys: bool = False, *args, **kwargs, @@ -105,12 +107,27 @@ def __init__( ensure_channel_first: if `True` and loaded both image array and metadata, automatically convert the image array shape to `channel first`. default to `False`. simple_keys: whether to remove redundant metadata keys, default to False for backward compatibility. + prune_meta_pattern: combined with `prune_meta_sep`, a regular expression used to match and prune keys + in the metadata (nested dictionary), default to None, no key deletion. + prune_meta_sep: combined with `prune_meta_pattern`, used to match and prune keys + in the metadata (nested dictionary). default is ".", see also :py:class:`monai.transforms.DeleteItemsd`. + e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``. allow_missing_keys: don't raise exception if key is missing. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. """ super().__init__(keys, allow_missing_keys) - self._loader = LoadImage(reader, image_only, dtype, ensure_channel_first, simple_keys, *args, **kwargs) + self._loader = LoadImage( + reader, + image_only, + dtype, + ensure_channel_first, + simple_keys, + prune_meta_pattern, + prune_meta_sep, + *args, + **kwargs, + ) if not isinstance(meta_key_postfix, str): raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.") self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 073e50a3be..2c4394a7da 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -718,7 +718,7 @@ def __init__(self, keys: KeysCollection, sep: str = ".", use_re: Union[Sequence[ See also: :py:class:`monai.transforms.compose.MapTransform` sep: the separator tag to define nested dictionary keys, default to ".". use_re: whether the specified key is a regular expression, it also can be - a list of bool values, map the to keys. + a list of bool values, mapping them to `keys`. """ super().__init__(keys) self.sep = sep @@ -730,7 +730,7 @@ def _delete_item(keys, d, use_re: bool = False): if len(keys) > 1: d[key] = _delete_item(keys[1:], d[key], use_re) return d - return {k: v for k, v in d.items() if (use_re and not re.search(key, k)) or (not use_re and k != key)} + return {k: v for k, v in d.items() if (use_re and not re.search(key, f"{k}")) or (not use_re and k != key)} d = dict(data) for key, use_re in zip(self.keys, self.use_re): diff --git a/tests/test_load_image.py b/tests/test_load_image.py index e6dfae0901..cc227021a2 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -332,12 +332,13 @@ def tearDownClass(cls): @parameterized.expand(TESTS_META) def test_correct(self, input_param, expected_shape, track_meta): set_track_meta(track_meta) - r = LoadImage(image_only=True, **input_param)(self.test_data) + r = LoadImage(image_only=True, prune_meta_pattern="glmax", prune_meta_sep="%", **input_param)(self.test_data) self.assertTupleEqual(r.shape, expected_shape) if track_meta: self.assertIsInstance(r, MetaTensor) self.assertTrue(hasattr(r, "affine")) self.assertIsInstance(r.affine, torch.Tensor) + self.assertTrue("glmax" not in r.meta) else: self.assertIsInstance(r, torch.Tensor) self.assertNotIsInstance(r, MetaTensor) diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py index 096fbafa06..cd8b476a58 100644 --- a/tests/test_load_imaged.py +++ b/tests/test_load_imaged.py @@ -164,9 +164,11 @@ def tearDownClass(cls): super(__class__, cls).tearDownClass() @parameterized.expand(TESTS_META) - def test_correct(self, input_param, expected_shape, track_meta): + def test_correct(self, input_p, expected_shape, track_meta): set_track_meta(track_meta) - result = LoadImaged(image_only=True, **input_param)(self.test_data) + result = LoadImaged(image_only=True, prune_meta_pattern=".*_code$", prune_meta_sep=" ", **input_p)( + self.test_data + ) # shouldn't have any extra meta data keys self.assertEqual(len(result), len(KEYS)) @@ -178,6 +180,7 @@ def test_correct(self, input_param, expected_shape, track_meta): self.assertTrue(hasattr(r, "affine")) self.assertIsInstance(r.affine, torch.Tensor) self.assertEqual(r.meta["space"], "RAS") + self.assertTrue("qform_code" not in r.meta) else: self.assertIsInstance(r, torch.Tensor) self.assertNotIsInstance(r, MetaTensor)