From 5eed4e7b1aec2156bf59a8b63c1e9890da399dd8 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 1 Mar 2021 17:06:50 +0000 Subject: [PATCH 1/6] progress Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/apps/deepgrow/transforms.py | 19 +++++++----- monai/transforms/croppad/dictionary.py | 37 +++++++++++++++--------- monai/transforms/intensity/dictionary.py | 7 +++-- monai/transforms/transform.py | 23 +++++++++++++-- 4 files changed, 60 insertions(+), 26 deletions(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index cc01a717ad..2998f9c31f 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -437,6 +437,7 @@ class SpatialCropForegroundd(MapTransform): end_coord_key: key to record the end coordinate of spatial bounding box for foreground. original_shape_key: key to record original shape for foreground. cropped_shape_key: key to record cropped shape for foreground. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( @@ -452,8 +453,9 @@ def __init__( end_coord_key: str = "foreground_end_coord", original_shape_key: str = "foreground_original_shape", cropped_shape_key: str = "foreground_cropped_shape", + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.source_key = source_key self.spatial_size = list(spatial_size) @@ -482,7 +484,7 @@ def __call__(self, data): else: cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) - for key in self.keys: + for key in self.generator(d): meta_key = f"{key}_{self.meta_key_postfix}" d[meta_key][self.start_coord_key] = box_start d[meta_key][self.end_coord_key] = box_end @@ -629,6 +631,7 @@ class SpatialCropGuidanced(MapTransform): end_coord_key: key to record the end coordinate of spatial bounding box for foreground. original_shape_key: key to record original shape for foreground. cropped_shape_key: key to record cropped shape for foreground. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( @@ -642,8 +645,9 @@ def __init__( end_coord_key: str = "foreground_end_coord", original_shape_key: str = "foreground_original_shape", cropped_shape_key: str = "foreground_cropped_shape", + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.guidance = guidance self.spatial_size = list(spatial_size) @@ -697,7 +701,7 @@ def __call__(self, data): cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) box_start, box_end = cropper.roi_start, cropper.roi_end - for key in self.keys: + for key in self.generator(d): if not np.array_equal(d[key].shape[1:], original_spatial_shape): raise RuntimeError("All the image specified in keys should have same spatial shape") meta_key = f"{key}_{self.meta_key_postfix}" @@ -898,10 +902,11 @@ class Fetch2DSliced(MapTransform): default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. + allow_missing_keys: don't raise exception if key is missing. 
""" - def __init__(self, keys, guidance="guidance", axis: int = 0, meta_key_postfix: str = "meta_dict"): - super().__init__(keys) + def __init__(self, keys, guidance="guidance", axis: int = 0, meta_key_postfix: str = "meta_dict", allow_missing_keys: bool = False): + super().__init__(keys, allow_missing_keys) self.guidance = guidance self.axis = axis self.meta_key_postfix = meta_key_postfix @@ -920,7 +925,7 @@ def __call__(self, data): guidance = d[self.guidance] if len(guidance) < 3: raise RuntimeError("Guidance does not container slice_idx!") - for key in self.keys: + for key in self.generator(d): img_slice, idx = self._apply(d[key], guidance) d[key] = img_slice d[f"{key}_{self.meta_key_postfix}"]["slice_idx"] = idx diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 9739c6322f..4462e599eb 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -94,6 +94,7 @@ def __init__( spatial_size: Union[Sequence[int], int], method: Union[Method, str] = Method.SYMMETRIC, mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -108,15 +109,16 @@ def __init__( One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padder = SpatialPad(spatial_size, method) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key, m in zip(self.keys, self.mode): + for key, m in self.generator(d, self.mode): d[key] = self.padder(d[key], mode=m) return d @@ -216,6 +218,7 @@ def __init__( roi_size: Optional[Sequence[int]] = None, roi_start: Optional[Sequence[int]] = None, roi_end: Optional[Sequence[int]] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -225,13 +228,14 @@ def __init__( roi_size: size of the crop ROI. roi_start: voxel coordinates for start of the crop ROI. roi_end: voxel coordinates for end of the crop ROI. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.generator(d): d[key] = self.cropper(d[key]) return d @@ -245,15 +249,16 @@ class CenterSpatialCropd(MapTransform): See also: monai.transforms.MapTransform roi_size: the size of the crop region e.g. [224,224,128] If its components have non-positive values, the corresponding size of input image will be used. + allow_missing_keys: don't raise exception if key is missing. 
""" - def __init__(self, keys: KeysCollection, roi_size: Union[Sequence[int], int]) -> None: - super().__init__(keys) + def __init__(self, keys: KeysCollection, roi_size: Union[Sequence[int], int], allow_missing_keys: bool = False) -> None: + super().__init__(keys, allow_missing_keys) self.cropper = CenterSpatialCrop(roi_size) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.generator(d): d[key] = self.cropper(d[key]) return d @@ -274,6 +279,7 @@ class RandSpatialCropd(RandomizableTransform, MapTransform): random_center: crop at random position as center or the image center. random_size: crop with random size or specific size ROI. The actual size is sampled from `randint(roi_size, img_size)`. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( @@ -282,9 +288,10 @@ def __init__( roi_size: Union[Sequence[int], int], random_center: bool = True, random_size: bool = True, + allow_missing_keys: bool = False, ) -> None: RandomizableTransform.__init__(self) - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) self.roi_size = roi_size self.random_center = random_center self.random_size = random_size @@ -304,7 +311,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize(d[self.keys[0]].shape[1:]) # image shape from the first data key if self._size is None: raise AssertionError - for key in self.keys: + for key in self.generator(d): if self.random_center: d[key] = d[key][self._slices] else: @@ -388,6 +395,7 @@ def __init__( margin: int = 0, start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -400,8 +408,9 @@ def __init__( margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. start_coord_key: key to record the start coordinate of spatial bounding box for foreground. end_coord_key: key to record the end coordinate of spatial bounding box for foreground. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.source_key = source_key self.select_fn = select_fn self.channel_indices = ensure_tuple(channel_indices) if channel_indices is not None else None @@ -417,7 +426,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[self.start_coord_key] = np.asarray(box_start) d[self.end_coord_key] = np.asarray(box_end) cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) - for key in self.keys: + for key in self.generator(d): d[key] = cropper(d[key]) return d @@ -435,6 +444,7 @@ class RandWeightedCropd(RandomizableTransform, MapTransform): If its components have non-positive values, the corresponding size of `img` will be used. num_samples: number of samples (image patches) to take in the returned list. center_coord_key: if specified, the actual sampling location will be stored with the corresponding key. + allow_missing_keys: don't raise exception if key is missing. 
See Also:
         :py:class:`monai.transforms.RandWeightedCrop`
 
@@ -447,9 +457,10 @@ def __init__(
         spatial_size: Union[Sequence[int], int],
         num_samples: int = 1,
         center_coord_key: Optional[str] = None,
+        allow_missing_keys: bool = False,
     ):
         RandomizableTransform.__init__(self)
-        MapTransform.__init__(self, keys)
+        MapTransform.__init__(self, keys, allow_missing_keys)
         self.spatial_size = ensure_tuple(spatial_size)
         self.w_key = w_key
         self.num_samples = int(num_samples)
@@ -468,7 +479,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
         results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)]
 
         for key in data.keys():
-            if key in self.keys:
+            for key in self.generator(d):
                 img = d[key]
                 if img.shape[1:] != d[self.w_key].shape[1:]:
                     raise ValueError(
diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py
index 7d0d66d2ba..e2485d843a 100644
--- a/monai/transforms/intensity/dictionary.py
+++ b/monai/transforms/intensity/dictionary.py
@@ -140,19 +140,20 @@ class ShiftIntensityd(MapTransform):
     Dictionary-based wrapper of :py:class:`monai.transforms.ShiftIntensity`.
     """
 
-    def __init__(self, keys: KeysCollection, offset: float) -> None:
+    def __init__(self, keys: KeysCollection, offset: float, allow_missing_keys: bool = False) -> None:
         """
         Args:
             keys: keys of the corresponding items to be transformed.
                 See also: :py:class:`monai.transforms.compose.MapTransform`
             offset: offset value to shift the intensity of image.
+            allow_missing_keys: don't raise exception if key is missing.
 
         """
-        super().__init__(keys)
+        super().__init__(keys, allow_missing_keys)
         self.shifter = ShiftIntensity(offset)
 
     def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
         d = dict(data)
-        for key in self.keys:
+        for key in self.generator(d):
             d[key] = self.shifter(d[key])
         return d
 
diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py
index 9c9729d250..1693e08176 100644
--- a/monai/transforms/transform.py
+++ b/monai/transforms/transform.py
@@ -13,7 +13,7 @@
 """
 
 from abc import ABC, abstractmethod
-from typing import Any, Hashable, Optional, Tuple
+from typing import Any, Dict, Hashable, Iterable, Optional, Tuple
 
 import numpy as np
 
@@ -178,7 +178,7 @@ def __call__(self, data):
                 if key in data:
                     # update output data with some_transform_function(data[key]).
                 else:
-                    # do nothing or some exceptions handling.
+                    # raise exception unless allow_missing_keys==True.
                 return data
 
     Raises:
 
     """
 
-    def __init__(self, keys: KeysCollection) -> None:
+    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
         self.keys: Tuple[Hashable, ...] 
= ensure_tuple(keys) + self.allow_missing_keys = allow_missing_keys if not self.keys: raise ValueError("keys must be non empty.") for key in self.keys: @@ -224,3 +225,19 @@ def __call__(self, data): """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + def generator( + self, + data: Dict[str, Any], + extra_iterables: Optional[Iterable] = None, + ): + if extra_iterables is None: + for key in self.keys: + if key not in data.keys() and self.allow_missing_keys: + continue + yield key + else: + for key, *extra_iterables in zip(self.keys, extra_iterables): + if key not in data.keys() and self.allow_missing_keys: + continue + yield [key] + extra_iterables From 3f6671b62d650a7f108d5ba3364a2e626bc3f858 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Mar 2021 12:57:02 +0000 Subject: [PATCH 2/6] add all classes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/apps/deepgrow/transforms.py | 20 +-- monai/transforms/croppad/dictionary.py | 45 +++--- monai/transforms/intensity/dictionary.py | 97 +++++++----- monai/transforms/io/dictionary.py | 12 +- monai/transforms/post/dictionary.py | 43 +++--- monai/transforms/spatial/dictionary.py | 135 ++++++++++------- monai/transforms/transform.py | 24 ++- monai/transforms/utility/dictionary.py | 180 +++++++++++++---------- 8 files changed, 332 insertions(+), 224 deletions(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 2998f9c31f..9a7f0f7458 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -484,7 +484,7 @@ def __call__(self, data): else: cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) - for key in self.generator(d): + for key in self.key_iterator(d): meta_key = f"{key}_{self.meta_key_postfix}" d[meta_key][self.start_coord_key] = box_start d[meta_key][self.end_coord_key] = box_end @@ -701,7 +701,7 @@ def __call__(self, data): cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) box_start, box_end = cropper.roi_start, cropper.roi_end - for key in self.generator(d): + for key in self.key_iterator(d): if not np.array_equal(d[key].shape[1:], original_spatial_shape): raise RuntimeError("All the image specified in keys should have same spatial shape") meta_key = f"{key}_{self.meta_key_postfix}" @@ -808,6 +808,7 @@ class RestoreLabeld(MapTransform): end_coord_key: key that records the end coordinate of spatial bounding box for foreground. original_shape_key: key that records original shape for foreground. cropped_shape_key: key that records cropped shape for foreground. + allow_missing_keys: don't raise exception if key is missing. 
""" def __init__( @@ -822,8 +823,9 @@ def __init__( end_coord_key: str = "foreground_end_coord", original_shape_key: str = "foreground_original_shape", cropped_shape_key: str = "foreground_cropped_shape", + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.ref_image = ref_image self.slice_only = slice_only self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -838,15 +840,15 @@ def __call__(self, data): d = dict(data) meta_dict: Dict = d[f"{self.ref_image}_{self.meta_key_postfix}"] - for idx, key in enumerate(self.keys): + for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): image = d[key] # Undo Resize current_shape = image.shape cropped_shape = meta_dict[self.cropped_shape_key] if np.any(np.not_equal(current_shape, cropped_shape)): - resizer = Resize(spatial_size=cropped_shape[1:], mode=self.mode[idx]) - image = resizer(image, mode=self.mode[idx], align_corners=self.align_corners[idx]) + resizer = Resize(spatial_size=cropped_shape[1:], mode=mode) + image = resizer(image, mode=mode, align_corners=align_corners) # Undo Crop original_shape = meta_dict[self.original_shape_key] @@ -866,8 +868,8 @@ def __call__(self, data): spatial_size = spatial_shape[-len(current_size) :] if np.any(np.not_equal(current_size, spatial_size)): - resizer = Resize(spatial_size=spatial_size, mode=self.mode[idx]) - result = resizer(result, mode=self.mode[idx], align_corners=self.align_corners[idx]) + resizer = Resize(spatial_size=spatial_size, mode=mode) + result = resizer(result, mode=mode, align_corners=align_corners) # Undo Slicing slice_idx = meta_dict.get("slice_idx") @@ -925,7 +927,7 @@ def __call__(self, data): guidance = d[self.guidance] if len(guidance) < 3: raise RuntimeError("Guidance does not container slice_idx!") - for key in self.generator(d): + for key in self.key_iterator(d): img_slice, idx = self._apply(d[key], guidance) d[key] = img_slice d[f"{key}_{self.meta_key_postfix}"]["slice_idx"] = idx diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 4462e599eb..30652d1be8 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -118,7 +118,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key, m in self.generator(d, self.mode): + for key, m in self.key_iterator(d, self.mode): d[key] = self.padder(d[key], mode=m) return d @@ -235,7 +235,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.generator(d): + for key in self.key_iterator(d): d[key] = self.cropper(d[key]) return d @@ -258,7 +258,7 @@ def __init__(self, keys: KeysCollection, roi_size: Union[Sequence[int], int], al def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.generator(d): + for key in self.key_iterator(d): d[key] = self.cropper(d[key]) return d @@ -311,7 +311,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize(d[self.keys[0]].shape[1:]) # image shape from the first data key if self._size is None: raise AssertionError - for key in self.generator(d): + for key in self.key_iterator(d): if self.random_center: d[key] = d[key][self._slices] else: @@ -426,7 +426,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[self.start_coord_key] = 
np.asarray(box_start) d[self.end_coord_key] = np.asarray(box_end) cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) - for key in self.generator(d): + for key in self.key_iterator(d): d[key] = cropper(d[key]) return d @@ -479,7 +479,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)] for key in data.keys(): - for key in self.generator(d): + for key in self.key_iterator(d): img = d[key] if img.shape[1:] != d[self.w_key].shape[1:]: raise ValueError( @@ -528,6 +528,7 @@ class RandCropByPosNegLabeld(RandomizableTransform, MapTransform): `image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices_key` and `bg_indices_key` together, expect to be 1 dim array of spatial indices after flattening. a typical usage is to call `FgBgToIndicesd` transform first and cache the results. + allow_missing_keys: don't raise exception if key is missing. Raises: ValueError: When ``pos`` or ``neg`` are negative. @@ -547,9 +548,10 @@ def __init__( image_threshold: float = 0.0, fg_indices_key: Optional[str] = None, bg_indices_key: Optional[str] = None, + allow_missing_keys: bool = False, ) -> None: RandomizableTransform.__init__(self) - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) self.label_key = label_key self.spatial_size: Union[Tuple[int, ...], Sequence[int], int] = spatial_size if pos < 0 or neg < 0: @@ -594,15 +596,15 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n if self.centers is None: raise AssertionError results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)] - for key in data.keys(): - if key in self.keys: + + for i, center in enumerate(self.centers): + for key in self.key_iterator(d): img = d[key] - for i, center in enumerate(self.centers): - cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore - results[i][key] = cropper(img) - else: - for i in range(self.num_samples): - results[i][key] = data[key] + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + results[i][key] = cropper(img) + # fill in the extra keys with unmodified data + for key in set(data.keys()).difference(set(self.keys)): + results[i][key] = data[key] return results @@ -620,6 +622,7 @@ class ResizeWithPadOrCropd(MapTransform): ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} One of the listed string values or a user supplied function for padding. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + allow_missing_keys: don't raise exception if key is missing. """ @@ -628,13 +631,14 @@ def __init__( keys: KeysCollection, spatial_size: Union[Sequence[int], int], mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, mode=mode) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.padcropper(d[key]) return d @@ -649,10 +653,11 @@ class BoundingRectd(MapTransform): bbox_key_postfix: the output bounding box coordinates will be written to the value of `{key}_{bbox_key_postfix}`. 
select_fn: function to select expected foreground, default is to select values > 0. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, bbox_key_postfix: str = "bbox", select_fn: Callable = lambda x: x > 0): - super().__init__(keys=keys) + def __init__(self, keys: KeysCollection, bbox_key_postfix: str = "bbox", select_fn: Callable = lambda x: x > 0, allow_missing_keys: bool = False): + super().__init__(keys, allow_missing_keys) self.bbox = BoundingRect(select_fn=select_fn) self.bbox_key_postfix = bbox_key_postfix @@ -661,7 +666,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. """ d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): bbox = self.bbox(d[key]) key_to_add = f"{key}_{self.bbox_key_postfix}" if key_to_add in d: diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index e2485d843a..d2bca28678 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -153,7 +153,7 @@ def __init__(self, keys: KeysCollection, offset: float, allow_missing_keys: bool def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.generator(d): + for key in self.key_iterator(d): d[key] = self.shifter(d[key]) return d @@ -163,7 +163,7 @@ class RandShiftIntensityd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandShiftIntensity`. """ - def __init__(self, keys: KeysCollection, offsets: Union[Tuple[float, float], float], prob: float = 0.1) -> None: + def __init__(self, keys: KeysCollection, offsets: Union[Tuple[float, float], float], prob: float = 0.1, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. @@ -172,8 +172,9 @@ def __init__(self, keys: KeysCollection, offsets: Union[Tuple[float, float], flo if single number, offset value is picked from (-offsets, offsets). prob: probability of rotating. (Default 0.1, with 10% probability it returns a rotated array.) + allow_missing_keys: don't raise exception if key is missing. """ - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) if isinstance(offsets, (int, float)): @@ -193,7 +194,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda if not self._do_transform: return d shifter = ShiftIntensity(self._offset) - for key in self.keys: + for key in self.key_iterator(d): d[key] = shifter(d[key]) return d @@ -206,7 +207,7 @@ class ScaleIntensityd(MapTransform): """ def __init__( - self, keys: KeysCollection, minv: float = 0.0, maxv: float = 1.0, factor: Optional[float] = None + self, keys: KeysCollection, minv: float = 0.0, maxv: float = 1.0, factor: Optional[float] = None, allow_missing_keys: bool = False ) -> None: """ Args: @@ -215,14 +216,15 @@ def __init__( minv: minimum value of output data. maxv: maximum value of output data. factor: factor scale by ``v = v * (1 + factor)``. + allow_missing_keys: don't raise exception if key is missing. 
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensity(minv, maxv, factor) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.scaler(d[key]) return d @@ -232,7 +234,7 @@ class RandScaleIntensityd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandScaleIntensity`. """ - def __init__(self, keys: KeysCollection, factors: Union[Tuple[float, float], float], prob: float = 0.1) -> None: + def __init__(self, keys: KeysCollection, factors: Union[Tuple[float, float], float], prob: float = 0.1, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. @@ -241,9 +243,10 @@ def __init__(self, keys: KeysCollection, factors: Union[Tuple[float, float], flo if single number, factor value is picked from (-factors, factors). prob: probability of rotating. (Default 0.1, with 10% probability it returns a rotated array.) + allow_missing_keys: don't raise exception if key is missing. """ - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) if isinstance(factors, (int, float)): @@ -263,7 +266,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda if not self._do_transform: return d scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor) - for key in self.keys: + for key in self.key_iterator(d): d[key] = scaler(d[key]) return d @@ -283,6 +286,7 @@ class NormalizeIntensityd(MapTransform): channel_wise: if using calculated mean and std, calculate on each channel separately or calculate on the entire image directly. dtype: output data type, defaut to float32. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( @@ -293,13 +297,14 @@ def __init__( nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32, + allow_missing_keys: bool = False, ) -> None: super().__init__(keys) self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.normalizer(d[key]) return d @@ -314,15 +319,16 @@ class ThresholdIntensityd(MapTransform): threshold: the threshold to filter intensity values. above: filter values above the threshold or below the threshold, default is True. cval: value to fill the remaining parts of the image, default is 0. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, threshold: float, above: bool = True, cval: float = 0.0) -> None: + def __init__(self, keys: KeysCollection, threshold: float, above: bool = True, cval: float = 0.0, allow_missing_keys: bool = False) -> None: super().__init__(keys) self.filter = ThresholdIntensity(threshold, above, cval) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.filter(d[key]) return d @@ -339,17 +345,18 @@ class ScaleIntensityRanged(MapTransform): b_min: intensity target range min. b_max: intensity target range max. clip: whether to perform clip after scaling. + allow_missing_keys: don't raise exception if key is missing. 
""" def __init__( - self, keys: KeysCollection, a_min: float, a_max: float, b_min: float, b_max: float, clip: bool = False + self, keys: KeysCollection, a_min: float, a_max: float, b_min: float, b_max: float, clip: bool = False, allow_missing_keys: bool = False ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.scaler(d[key]) return d @@ -365,15 +372,16 @@ class AdjustContrastd(MapTransform): keys: keys of the corresponding items to be transformed. See also: monai.transforms.MapTransform gamma: gamma value to adjust the contrast as function. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, gamma: float) -> None: - super().__init__(keys) + def __init__(self, keys: KeysCollection, gamma: float, allow_missing_keys: bool = False) -> None: + super().__init__(keys, allow_missing_keys) self.adjuster = AdjustContrast(gamma) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.adjuster(d[key]) return d @@ -391,12 +399,13 @@ class RandAdjustContrastd(RandomizableTransform, MapTransform): prob: Probability of adjustment. gamma: Range of gamma values. If single number, value is picked from (0.5, gamma), default is (0.5, 4.5). + allow_missing_keys: don't raise exception if key is missing. """ def __init__( - self, keys: KeysCollection, prob: float = 0.1, gamma: Union[Tuple[float, float], float] = (0.5, 4.5) + self, keys: KeysCollection, prob: float = 0.1, gamma: Union[Tuple[float, float], float] = (0.5, 4.5), allow_missing_keys: bool = False ) -> None: - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) if isinstance(gamma, (int, float)): @@ -424,7 +433,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda if not self._do_transform: return d adjuster = AdjustContrast(self.gamma_value) - for key in self.keys: + for key in self.key_iterator(d): d[key] = adjuster(d[key]) return d @@ -442,6 +451,7 @@ class ScaleIntensityRangePercentilesd(MapTransform): b_max: intensity target range max. clip: whether to perform clip after scaling. relative: whether to scale to the corresponding percentiles of [b_min, b_max] + allow_missing_keys: don't raise exception if key is missing. """ def __init__( @@ -453,13 +463,14 @@ def __init__( b_max: float, clip: bool = False, relative: bool = False, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.scaler(d[key]) return d @@ -478,6 +489,7 @@ class MaskIntensityd(MapTransform): if None, will extract the mask data from input data based on `mask_key`. mask_key: the key to extract mask data from input dictionary, only works when `mask_data` is None. + allow_missing_keys: don't raise exception if key is missing. 
""" @@ -486,14 +498,15 @@ def __init__( keys: KeysCollection, mask_data: Optional[np.ndarray] = None, mask_key: Optional[str] = None, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = MaskIntensity(mask_data) self.mask_key = mask_key if mask_data is None else None def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key], d[self.mask_key]) if self.mask_key is not None else self.converter(d[key]) return d @@ -510,16 +523,17 @@ class GaussianSmoothd(MapTransform): use it for all spatial dimensions. approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace". see also :py:meth:`monai.networks.layers.GaussianFilter`. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, sigma: Union[Sequence[float], float], approx: str = "erf") -> None: - super().__init__(keys) + def __init__(self, keys: KeysCollection, sigma: Union[Sequence[float], float], approx: str = "erf", allow_missing_keys: bool = False) -> None: + super().__init__(keys, allow_missing_keys) self.converter = GaussianSmooth(sigma, approx=approx) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -537,6 +551,7 @@ class RandGaussianSmoothd(RandomizableTransform, MapTransform): approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace". see also :py:meth:`monai.networks.layers.GaussianFilter`. prob: probability of Gaussian smooth. + allow_missing_keys: don't raise exception if key is missing. """ @@ -548,8 +563,9 @@ def __init__( sigma_z: Tuple[float, float] = (0.25, 1.5), approx: str = "erf", prob: float = 0.1, + allow_missing_keys: bool = False, ) -> None: - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.sigma_x = sigma_x self.sigma_y = sigma_y @@ -567,7 +583,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize() if not self._do_transform: return d - for key in self.keys: + for key in self.key_iterator(d): sigma = ensure_tuple_size(tup=(self.x, self.y, self.z), dim=d[key].ndim - 1) d[key] = GaussianSmooth(sigma=sigma, approx=self.approx)(d[key]) return d @@ -589,6 +605,7 @@ class GaussianSharpend(MapTransform): alpha: weight parameter to compute the final result. approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace". see also :py:meth:`monai.networks.layers.GaussianFilter`. + allow_missing_keys: don't raise exception if key is missing. 
""" @@ -599,13 +616,14 @@ def __init__( sigma2: Union[Sequence[float], float] = 1.0, alpha: float = 30.0, approx: str = "erf", + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = GaussianSharpen(sigma1, sigma2, alpha, approx=approx) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -630,6 +648,7 @@ class RandGaussianSharpend(RandomizableTransform, MapTransform): approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace". see also :py:meth:`monai.networks.layers.GaussianFilter`. prob: probability of Gaussian sharpen. + allow_missing_keys: don't raise exception if key is missing. """ @@ -645,8 +664,9 @@ def __init__( alpha: Tuple[float, float] = (10.0, 30.0), approx: str = "erf", prob: float = 0.1, + allow_missing_keys: bool = False, ): - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.sigma1_x = sigma1_x self.sigma1_y = sigma1_y @@ -675,7 +695,7 @@ def __call__(self, data): self.randomize() if not self._do_transform: return d - for key in self.keys: + for key in self.key_iterator(d): sigma1 = ensure_tuple_size(tup=(self.x1, self.y1, self.z1), dim=d[key].ndim - 1) sigma2 = ensure_tuple_size(tup=(self.x2, self.y2, self.z2), dim=d[key].ndim - 1) d[key] = GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(d[key]) @@ -694,12 +714,13 @@ class RandHistogramShiftd(RandomizableTransform, MapTransform): a smaller number of control points allows for larger intensity shifts. if two values provided, number of control points selecting from range (min_value, max_value). prob: probability of histogram shift. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( - self, keys: KeysCollection, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1 + self, keys: KeysCollection, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1, allow_missing_keys: bool = False ) -> None: - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) if isinstance(num_control_points, int): if num_control_points <= 2: @@ -727,7 +748,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize() if not self._do_transform: return d - for key in self.keys: + for key in self.key_iterator(d): img_min, img_max = d[key].min(), d[key].max() reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 55707f750e..799aca9bea 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -59,6 +59,7 @@ def __init__( dtype: DtypeLike = np.float32, meta_key_postfix: str = "meta_dict", overwriting: bool = False, + allow_missing_keys: bool = False, *args, **kwargs, ) -> None: @@ -76,10 +77,11 @@ def __init__( For example, load nifti file for `image`, store the metadata into `image_meta_dict`. overwriting: whether allow to overwrite existing meta data of same key. default is False, which will raise exception if encountering existing key. 
+ allow_missing_keys: don't raise exception if key is missing. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self._loader = LoadImage(reader, False, dtype, *args, **kwargs) if not isinstance(meta_key_postfix, str): raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.") @@ -96,7 +98,7 @@ def __call__(self, data, reader: Optional[ImageReader] = None): """ d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): data = self._loader(d[key], reader) if not isinstance(data, (tuple, list)): raise ValueError("loader must return a tuple or list.") @@ -155,6 +157,7 @@ class SaveImaged(MapTransform): it's used for NIfTI format only. save_batch: whether the import image is a batch data, default to `False`. usually pre-transforms run for channel first data, while post-transforms run for batch data. + allow_missing_keys: don't raise exception if key is missing. """ @@ -172,8 +175,9 @@ def __init__( dtype: DtypeLike = np.float64, output_dtype: DtypeLike = np.float32, save_batch: bool = False, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.meta_key_postfix = meta_key_postfix self._saver = SaveImage( output_dir=output_dir, @@ -190,7 +194,7 @@ def __init__( def __call__(self, data): d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): meta_data = d[f"{key}_{self.meta_key_postfix}"] if self.meta_key_postfix is not None else None self._saver(img=d[key], meta_data=meta_data) return d diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 85abdac0ac..b1b3dfbf00 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -71,6 +71,7 @@ def __init__( sigmoid: Union[Sequence[bool], bool] = False, softmax: Union[Sequence[bool], bool] = False, other: Optional[Union[Sequence[Callable], Callable]] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -83,9 +84,10 @@ def __init__( other: callable function to execute other activation layers, for example: `other = lambda x: torch.tanh(x)`. it also can be a sequence of Callable, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.sigmoid = ensure_tuple_rep(sigmoid, len(self.keys)) self.softmax = ensure_tuple_rep(softmax, len(self.keys)) self.other = ensure_tuple_rep(other, len(self.keys)) @@ -93,8 +95,8 @@ def __init__( def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for idx, key in enumerate(self.keys): - d[key] = self.converter(d[key], self.sigmoid[idx], self.softmax[idx], self.other[idx]) + for key, sigmoid, softmax, other in self.key_iterator(d, self.sigmoid, self.softmax, self.other): + d[key] = self.converter(d[key], sigmoid, softmax, other) return d @@ -111,6 +113,7 @@ def __init__( n_classes: Optional[Union[Sequence[int], int]] = None, threshold_values: Union[Sequence[bool], bool] = False, logit_thresh: Union[Sequence[float], float] = 0.5, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -126,9 +129,10 @@ def __init__( it also can be a sequence of bool, each element corresponds to a key in ``keys``. 
logit_thresh: the threshold value for thresholding operation, default is 0.5. it also can be a sequence of float, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.argmax = ensure_tuple_rep(argmax, len(self.keys)) self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys)) self.n_classes = ensure_tuple_rep(n_classes, len(self.keys)) @@ -138,14 +142,14 @@ def __init__( def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for idx, key in enumerate(self.keys): + for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh in self.key_iterator(d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh): d[key] = self.converter( d[key], - self.argmax[idx], - self.to_onehot[idx], - self.n_classes[idx], - self.threshold_values[idx], - self.logit_thresh[idx], + argmax, + to_onehot, + n_classes, + threshold_values, + logit_thresh, ) return d @@ -161,6 +165,7 @@ def __init__( applied_labels: Union[Sequence[int], int], independent: bool = True, connectivity: Optional[int] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -175,14 +180,15 @@ def __init__( connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. Accepted values are ranging from 1 to input.ndim. If ``None``, a full connectivity of ``input.ndim`` is used. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = KeepLargestConnectedComponent(applied_labels, independent, connectivity) def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -192,20 +198,21 @@ class LabelToContourd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.LabelToContour`. """ - def __init__(self, keys: KeysCollection, kernel_type: str = "Laplace") -> None: + def __init__(self, keys: KeysCollection, kernel_type: str = "Laplace", allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` kernel_type: the method applied to do edge detection, default is "Laplace". + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = LabelToContour(kernel_type=kernel_type) def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -221,6 +228,7 @@ def __init__( keys: KeysCollection, ensemble: Callable[[Union[Sequence[torch.Tensor], torch.Tensor]], torch.Tensor], output_key: Optional[str] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -229,13 +237,14 @@ def __init__( output_key: the key to store ensemble result in the dictionary. ensemble: callable method to execute ensemble on specified data. if only 1 key provided in `keys`, `output_key` can be None and use `keys` as default. + allow_missing_keys: don't raise exception if key is missing. Raises: TypeError: When ``ensemble`` is not ``callable``. ValueError: When ``len(keys) > 1`` and ``output_key=None``. 
Incompatible values. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) if not callable(ensemble): raise TypeError(f"ensemble must be callable but is {type(ensemble).__name__}.") self.ensemble = ensemble @@ -249,7 +258,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc if len(self.keys) == 1: items = d[self.keys[0]] else: - items = [d[key] for key in self.keys] + items = [d[key] for key in self.key_iterator(d)] d[self.output_key] = self.ensemble(items) return d diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index f29258bf28..6c33ae00a4 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -125,6 +125,7 @@ def __init__( align_corners: Union[Sequence[bool], bool] = False, dtype: Optional[Union[Sequence[DtypeLike], DtypeLike]] = np.float64, meta_key_postfix: str = "meta_dict", + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -160,12 +161,13 @@ def __init__( default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. + allow_missing_keys: don't raise exception if key is missing. Raises: TypeError: When ``meta_key_postfix`` is not a ``str``. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.spacing_transform = Spacing(pixdim, diagonal=diagonal) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) @@ -179,17 +181,17 @@ def __call__( self, data: Mapping[Union[Hashable, str], Dict[str, np.ndarray]] ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: d: Dict = dict(data) - for idx, key in enumerate(self.keys): + for key, mode, padding_mode, align_corners, dtype in self.key_iterator(d, self.mode, self.padding_mode, self.align_corners, self.dtype): meta_data = d[f"{key}_{self.meta_key_postfix}"] # resample array of each corresponding key # using affine fetched from d[affine_key] d[key], _, new_affine = self.spacing_transform( data_array=np.asarray(d[key]), affine=meta_data["affine"], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], - dtype=self.dtype[idx], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, ) # set the 'affine' key meta_data["affine"] = new_affine @@ -214,6 +216,7 @@ def __init__( as_closest_canonical: bool = False, labels: Optional[Sequence[Tuple[str, str]]] = tuple(zip("LPI", "RAS")), meta_key_postfix: str = "meta_dict", + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -230,6 +233,7 @@ def __init__( default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. + allow_missing_keys: don't raise exception if key is missing. Raises: TypeError: When ``meta_key_postfix`` is not a ``str``. @@ -238,7 +242,7 @@ def __init__( `nibabel.orientations.ornt2axcodes`. 
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.ornt_transform = Orientation(axcodes=axcodes, as_closest_canonical=as_closest_canonical, labels=labels) if not isinstance(meta_key_postfix, str): raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.") @@ -248,7 +252,7 @@ def __call__( self, data: Mapping[Union[Hashable, str], Dict[str, np.ndarray]] ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: d: Dict = dict(data) - for key in self.keys: + for key in self.key_iterator(d): meta_data = d[f"{key}_{self.meta_key_postfix}"] d[key], _, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) meta_data["affine"] = new_affine @@ -260,19 +264,20 @@ class Rotate90d(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. """ - def __init__(self, keys: KeysCollection, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: + def __init__(self, keys: KeysCollection, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1), allow_missing_keys: bool = False) -> None: """ Args: k: number of times to rotate by 90 degrees. spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.rotator = Rotate90(k, spatial_axes) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.rotator(d[key]) return d @@ -290,6 +295,7 @@ def __init__( prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, int] = (0, 1), + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -301,8 +307,9 @@ def __init__( (Default 3) spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. + allow_missing_keys: don't raise exception if key is missing. """ - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.max_k = max_k @@ -319,7 +326,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. d = dict(data) rotator = Rotate90(self._rand_k, self.spatial_axes) - for key in self.keys: + for key in self.key_iterator(d): if self._do_transform: d[key] = rotator(d[key]) return d @@ -344,6 +351,7 @@ class Resized(MapTransform): 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. 
""" def __init__( @@ -352,16 +360,17 @@ def __init__( spatial_size: Union[Sequence[int], int], mode: InterpolateModeSequence = InterpolateMode.AREA, align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.resizer = Resize(spatial_size=spatial_size) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for idx, key in enumerate(self.keys): - d[key] = self.resizer(d[key], mode=self.mode[idx], align_corners=self.align_corners[idx]) + for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): + d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners) return d @@ -383,6 +392,7 @@ def __init__( padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, as_tensor_output: bool = True, device: Optional[torch.device] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -416,12 +426,13 @@ def __init__( as_tensor_output: the computation is implemented using pytorch tensors, this option specifies whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. + allow_missing_keys: don't raise exception if key is missing. See also: - :py:class:`monai.transforms.compose.MapTransform` - :py:class:`RandAffineGrid` for the random affine parameters configurations. """ - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.rand_affine = RandAffine( prob=1.0, # because probability handled in this class @@ -459,8 +470,8 @@ def __call__( else: grid = create_grid(spatial_size=sp_size) - for idx, key in enumerate(self.keys): - d[key] = self.rand_affine.resampler(d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx]) + for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) return d @@ -484,6 +495,7 @@ def __init__( padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, as_tensor_output: bool = False, device: Optional[torch.device] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -521,12 +533,13 @@ def __init__( as_tensor_output: the computation is implemented using pytorch tensors, this option specifies whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. + allow_missing_keys: don't raise exception if key is missing. See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. 
""" - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.rand_2d_elastic = Rand2DElastic( spacing=spacing, @@ -576,9 +589,9 @@ def __call__( else: grid = create_grid(spatial_size=sp_size) - for idx, key in enumerate(self.keys): + for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.rand_2d_elastic.resampler( - d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx] + d[key], grid, mode=mode, padding_mode=padding_mode ) return d @@ -603,6 +616,7 @@ def __init__( padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, as_tensor_output: bool = False, device: Optional[torch.device] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -641,12 +655,13 @@ def __init__( as_tensor_output: the computation is implemented using pytorch tensors, this option specifies whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. + allow_missing_keys: don't raise exception if key is missing. See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. """ - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys=) RandomizableTransform.__init__(self, prob) self.rand_3d_elastic = Rand3DElastic( sigma_range=sigma_range, @@ -690,9 +705,9 @@ def __call__( grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude grid = self.rand_3d_elastic.rand_affine_grid(grid=grid) - for idx, key in enumerate(self.keys): + for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.rand_3d_elastic.resampler( - d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx] + d[key], grid, mode=mode, padding_mode=padding_mode ) return d @@ -707,15 +722,16 @@ class Flipd(MapTransform): Args: keys: Keys to pick data for transformation. spatial_axis: Spatial axes along which to flip over. Default is None. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: - super().__init__(keys) + def __init__(self, keys: KeysCollection, spatial_axis: Optional[Union[Sequence[int], int]] = None, allow_missing_keys: bool = False) -> None: + super().__init__(keys, allow_missing_keys) self.flipper = Flip(spatial_axis=spatial_axis) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.flipper(d[key]) return d @@ -731,6 +747,7 @@ class RandFlipd(RandomizableTransform, MapTransform): keys: Keys to pick data for transformation. prob: Probability of flipping. spatial_axis: Spatial axes along which to flip over. Default is None. + allow_missing_keys: don't raise exception if key is missing. 
""" def __init__( @@ -738,6 +755,7 @@ def __init__( keys: KeysCollection, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None, + allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys) RandomizableTransform.__init__(self, prob) @@ -748,7 +766,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: self.randomize(None) d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): if self._do_transform: d[key] = self.flipper(d[key]) return d @@ -764,11 +782,12 @@ class RandAxisFlipd(RandomizableTransform, MapTransform): Args: keys: Keys to pick data for transformation. prob: Probability of flipping. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, prob: float = 0.1) -> None: - MapTransform.__init__(self, keys) + def __init__(self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: bool = False) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self._axis: Optional[int] = None @@ -781,7 +800,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda flipper = Flip(spatial_axis=self._axis) d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): if self._do_transform: d[key] = flipper(d[key]) return d @@ -812,6 +831,7 @@ class Rotated(MapTransform): If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( @@ -823,8 +843,9 @@ def __init__( padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.rotator = Rotate(angle=angle, keep_size=keep_size) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -834,13 +855,13 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for idx, key in enumerate(self.keys): + for key, mode, padding_mode, align_corners, dtype in self.key_iterator(d, self.mode, self.padding_mode, self.align_corners, self.dtype): d[key] = self.rotator( d[key], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], - dtype=self.dtype[idx], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, ) return d @@ -877,6 +898,7 @@ class RandRotated(RandomizableTransform, MapTransform): If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. 
""" def __init__( @@ -891,8 +913,9 @@ def __init__( padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + allow_missing_keys: bool = False, ) -> None: - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.range_x = ensure_tuple(range_x) if len(self.range_x) == 1: @@ -929,13 +952,13 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda angle=self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z), keep_size=self.keep_size, ) - for idx, key in enumerate(self.keys): + for key, mode, padding_mode, align_corners, dtype in self.key_iterator(d, self.mode, self.padding_mode, self.align_corners, self.dtype): d[key] = rotator( d[key], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], - dtype=self.dtype[idx], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, ) return d @@ -962,6 +985,7 @@ class Zoomd(MapTransform): See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. keep_size: Should keep original size (pad if needed), default is True. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( @@ -972,8 +996,9 @@ def __init__( padding_mode: NumpyPadModeSequence = NumpyPadMode.EDGE, align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, keep_size: bool = True, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) @@ -981,12 +1006,12 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for idx, key in enumerate(self.keys): + for key, mode, padding_mode, align_corners in self.key_iterator(d, self.mode, self.padding_mode, self.align_corners): d[key] = self.zoomer( d[key], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, ) return d @@ -1021,6 +1046,7 @@ class RandZoomd(RandomizableTransform, MapTransform): See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. keep_size: Should keep original size (pad if needed), default is True. + allow_missing_keys: don't raise exception if key is missing. 
""" def __init__( @@ -1033,8 +1059,9 @@ def __init__( padding_mode: NumpyPadModeSequence = NumpyPadMode.EDGE, align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, keep_size: bool = True, + allow_missing_keys: bool = False, ) -> None: - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.min_zoom = ensure_tuple(min_zoom) self.max_zoom = ensure_tuple(max_zoom) @@ -1067,12 +1094,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim self._zoom = ensure_tuple_rep(self._zoom[0], img_dims - 2) + ensure_tuple(self._zoom[-1]) zoomer = Zoom(self._zoom, keep_size=self.keep_size) - for idx, key in enumerate(self.keys): + for key, mode, padding_mode, align_corners in self.key_iterator(d, self.mode, self.padding_mode, self.align_corners): d[key] = zoomer( d[key], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, ) return d diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 1693e08176..ff80b92b40 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -226,18 +226,32 @@ def __call__(self, data): """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def generator( + def key_iterator( self, data: Dict[str, Any], extra_iterables: Optional[Iterable] = None, ): + """ + Iterate across keys and optionally extra iterables. If key is missing, exception is raised if + `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped. + + Args: + data: data that the transform will be applied to + extra_iterables: anything else to be iterated through + """ if extra_iterables is None: for key in self.keys: - if key not in data.keys() and self.allow_missing_keys: - continue + if key not in data.keys(): + if self.allow_missing_keys: + continue + else: + raise KeyError("Key was missing and allow_missing_keys==False") yield key else: for key, *extra_iterables in zip(self.keys, extra_iterables): - if key not in data.keys() and self.allow_missing_keys: - continue + if key not in data.keys(): + if self.allow_missing_keys: + continue + else: + raise KeyError("Key was missing and allow_missing_keys==False") yield [key] + extra_iterables diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index e9d923d0fd..665ad45297 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -133,21 +133,22 @@ class Identityd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Identity`. """ - def __init__(self, keys: KeysCollection) -> None: + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + allow_missing_keys: don't raise exception if key is missing. 
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.identity = Identity() def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.identity(d[key]) return d @@ -157,19 +158,20 @@ class AsChannelFirstd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelFirst`. """ - def __init__(self, keys: KeysCollection, channel_dim: int = -1) -> None: + def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` channel_dim: which dimension of input image is the channel, default is the last dimension. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = AsChannelFirst(channel_dim=channel_dim) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -179,19 +181,20 @@ class AsChannelLastd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelLast`. """ - def __init__(self, keys: KeysCollection, channel_dim: int = 0) -> None: + def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` channel_dim: which dimension of input image is the channel, default is the first dimension. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = AsChannelLast(channel_dim=channel_dim) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -201,18 +204,19 @@ class AddChanneld(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AddChannel`. """ - def __init__(self, keys: KeysCollection) -> None: + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.adder = AddChannel() def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.adder(d[key]) return d @@ -222,19 +226,20 @@ class RepeatChanneld(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.RepeatChannel`. """ - def __init__(self, keys: KeysCollection, repeats: int) -> None: + def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` repeats: the number of repetitions for each element. + allow_missing_keys: don't raise exception if key is missing. 
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.repeater = RepeatChannel(repeats) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.repeater(d[key]) return d @@ -244,19 +249,20 @@ class RemoveRepeatedChanneld(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.RemoveRepeatedChannel`. """ - def __init__(self, keys: KeysCollection, repeats: int) -> None: + def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` repeats: the number of repetitions for each element. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.repeater = RemoveRepeatedChannel(repeats) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.repeater(d[key]) return d @@ -273,6 +279,7 @@ def __init__( keys: KeysCollection, output_postfixes: Optional[Sequence[str]] = None, channel_dim: Optional[int] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -286,9 +293,10 @@ def __init__( to automatically select: if data is numpy array, channel_dim is 0 as `numpy array` is used in the pre transforms, if PyTorch Tensor, channel_dim is 1 as in most of the cases `Tensor` is uses in the post transforms. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.output_postfixes = output_postfixes self.splitter = SplitChannel(channel_dim=channel_dim) @@ -296,7 +304,7 @@ def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): rets = self.splitter(d[key]) postfixes: Sequence = list(range(len(rets))) if self.output_postfixes is None else self.output_postfixes if len(postfixes) != len(rets): @@ -318,6 +326,7 @@ def __init__( self, keys: KeysCollection, dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], DtypeLike, torch.dtype] = np.float32, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -326,9 +335,10 @@ def __init__( dtype: convert image to this data type, default is `np.float32`. it also can be a sequence of dtypes or torch.dtype, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. """ - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.converter = CastToType() @@ -336,8 +346,8 @@ def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: d = dict(data) - for idx, key in enumerate(self.keys): - d[key] = self.converter(d[key], dtype=self.dtype[idx]) + for key, dtype in self.key_iterator(d, self.dtype): + d[key] = self.converter(d[key], dtype=dtype) return d @@ -347,20 +357,21 @@ class ToTensord(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToTensor`. 
""" - def __init__(self, keys: KeysCollection) -> None: + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = ToTensor() def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]] ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -370,20 +381,21 @@ class ToNumpyd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToNumpy`. """ - def __init__(self, keys: KeysCollection) -> None: + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = ToNumpy() def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]] ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -393,20 +405,21 @@ class ToPILd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToNumpy`. """ - def __init__(self, keys: KeysCollection) -> None: + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = ToPIL() def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]] ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -418,7 +431,7 @@ class DeleteItemsd(MapTransform): """ def __call__(self, data): - return {key: val for key, val in data.items() if key not in self.keys} + return {key: val for key, val in data.items() if key not in self.key_iterator(data)} class SelectItemsd(MapTransform): @@ -428,7 +441,7 @@ class SelectItemsd(MapTransform): """ def __call__(self, data): - result = {key: val for key, val in data.items() if key in self.keys} + result = {key: val for key, val in data.items() if key in self.key_iterator(d)} return result @@ -437,19 +450,20 @@ class SqueezeDimd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.SqueezeDim`. """ - def __init__(self, keys: KeysCollection, dim: int = 0) -> None: + def __init__(self, keys: KeysCollection, dim: int = 0, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` dim: dimension to be squeezed. Default: 0 (the first dimension) + allow_missing_keys: don't raise exception if key is missing. 
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = SqueezeDim(dim=dim) def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -468,6 +482,7 @@ def __init__( data_value: Union[Sequence[bool], bool] = False, additional_info: Optional[Union[Sequence[Callable], Callable]] = None, logger_handler: Optional[logging.Handler] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -487,9 +502,10 @@ def __init__( corresponds to a key in ``keys``. logger_handler: add additional handler to output data: save to file, etc. add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.prefix = ensure_tuple_rep(prefix, len(self.keys)) self.data_shape = ensure_tuple_rep(data_shape, len(self.keys)) self.value_range = ensure_tuple_rep(value_range, len(self.keys)) @@ -500,14 +516,14 @@ def __init__( def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: d = dict(data) - for idx, key in enumerate(self.keys): + for key, prefix, data_shape, value_range, data_value, additional_info in self.key_iterator(d, self.prefix, self.data_shape, self.value_range, self.data_value, self.additional_info): d[key] = self.printer( d[key], - self.prefix[idx], - self.data_shape[idx], - self.value_range[idx], - self.data_value[idx], - self.additional_info[idx], + prefix, + data_shape, + value_range, + data_value, + additional_info, ) return d @@ -517,23 +533,24 @@ class SimulateDelayd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.SimulateDelay`. """ - def __init__(self, keys: KeysCollection, delay_time: Union[Sequence[float], float] = 0.0) -> None: + def __init__(self, keys: KeysCollection, delay_time: Union[Sequence[float], float] = 0.0, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` delay_time: The minimum amount of time, in fractions of seconds, to accomplish this identity task. It also can be a sequence of string, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.delay_time = ensure_tuple_rep(delay_time, len(self.keys)) self.delayer = SimulateDelay() def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: d = dict(data) - for idx, key in enumerate(self.keys): - d[key] = self.delayer(d[key], delay_time=self.delay_time[idx]) + for key, delay_time in self.key_iterator(d, self.delay_time): + d[key] = self.delayer(d[key], delay_time=delay_time) return d @@ -596,19 +613,20 @@ class ConcatItemsd(MapTransform): """ - def __init__(self, keys: KeysCollection, name: str, dim: int = 0) -> None: + def __init__(self, keys: KeysCollection, name: str, dim: int = 0, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be concatenated together. See also: :py:class:`monai.transforms.compose.MapTransform` name: the name corresponding to the key to store the concatenated data. dim: on which dimension to concatenate the items, default is 0. 
+ allow_missing_keys: don't raise exception if key is missing. Raises: ValueError: When insufficient keys are given (``len(self.keys) < 2``). """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) if len(self.keys) < 2: raise ValueError("Concatenation requires at least 2 keys.") self.name = name @@ -624,7 +642,7 @@ def __call__(self, data): d = dict(data) output = [] data_type = None - for key in self.keys: + for key in self.key_iterator(d): if data_type is None: data_type = type(d[key]) elif not isinstance(d[key], data_type): @@ -660,6 +678,7 @@ class Lambdad(MapTransform): each element corresponds to a key in ``keys``. overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output. default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( @@ -667,17 +686,18 @@ def __init__( keys: KeysCollection, func: Union[Sequence[Callable], Callable], overwrite: Union[Sequence[bool], bool] = True, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.func = ensure_tuple_rep(func, len(self.keys)) self.overwrite = ensure_tuple_rep(overwrite, len(self.keys)) self._lambd = Lambda() def __call__(self, data): d = dict(data) - for idx, key in enumerate(self.keys): - ret = self._lambd(d[key], func=self.func[idx]) - if self.overwrite[idx]: + for key, func, overwrite in self.key_iterator(d, self.func, self.overwrite): + ret = self._lambd(d[key], func=func) + if overwrite: d[key] = ret return d @@ -715,6 +735,7 @@ class LabelToMaskd(MapTransform): `select_labels` is the expected channel indices. merge_channels: whether to use `np.any()` to merge the result on channel dim. if yes, will return a single channel mask with binary data. + allow_missing_keys: don't raise exception if key is missing. """ @@ -723,13 +744,14 @@ def __init__( # pytype: disable=annotation-type-mismatch keys: KeysCollection, select_labels: Union[Sequence[int], int], merge_channels: bool = False, + allow_missing_keys: bool = False, ) -> None: # pytype: disable=annotation-type-mismatch - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = LabelToMask(select_labels=select_labels, merge_channels=merge_channels) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -751,6 +773,7 @@ class FgBgToIndicesd(MapTransform): image_threshold: if enabled image_key, use ``image > image_threshold`` to determine the valid image content area and select background only in this area. output_shape: expected shape of output indices. if not None, unravel indices to specified shape. + allow_missing_keys: don't raise exception if key is missing. 
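The `ConcatItemsd.__call__` hunk above enforces that every input shares one container type before concatenating them under a new name. A minimal numpy-only sketch of that behaviour (assumed simplification; the real transform also handles torch tensors):

```python
import numpy as np

data = {"t1": np.ones((1, 4, 4)), "t2": np.zeros((1, 4, 4))}
keys, name, dim = ("t1", "t2"), "image", 0

items = [data[k] for k in keys]
if not all(isinstance(x, type(items[0])) for x in items):
    raise TypeError("Unsupported data type: all items must share one type.")
data[name] = np.concatenate(items, axis=dim)
assert data["image"].shape == (2, 4, 4)
```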
""" @@ -762,8 +785,9 @@ def __init__( image_key: Optional[str] = None, image_threshold: float = 0.0, output_shape: Optional[Sequence[int]] = None, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.fg_postfix = fg_postfix self.bg_postfix = bg_postfix self.image_key = image_key @@ -772,7 +796,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) image = d[self.image_key] if self.image_key else None - for key in self.keys: + for key in self.key_iterator(d): d[str(key) + self.fg_postfix], d[str(key) + self.bg_postfix] = self.converter(d[key], image) return d @@ -789,13 +813,13 @@ class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): and ET (Enhancing tumor). """ - def __init__(self, keys: KeysCollection): - super().__init__(keys) + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False): + super().__init__(keys, allow_missing_keys) self.converter = ConvertToMultiChannelBasedOnBratsClasses() def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -815,6 +839,7 @@ class AddExtremePointsChanneld(RandomizableTransform, MapTransform): use it for all spatial dimensions. rescale_min: minimum value of output data. rescale_max: maximum value of output data. + allow_missing_keys: don't raise exception if key is missing. """ @@ -827,8 +852,9 @@ def __init__( sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 3.0, rescale_min: float = -1.0, rescale_max: float = 1.0, + allow_missing_keys: bool = False, ): - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) self.background = background self.pert = pert self.points: List[Tuple[int, ...]] = [] @@ -849,17 +875,16 @@ def __call__(self, data): # Generate extreme points self.randomize(label[0, :]) - for key in data.keys(): - if key in self.keys: - img = d[key] - points_image = extreme_points_to_image( - points=self.points, - label=label, - sigma=self.sigma, - rescale_min=self.rescale_min, - rescale_max=self.rescale_max, - ) - d[key] = np.concatenate([img, points_image], axis=0) + for key in self.key_iterator(d): + img = d[key] + points_image = extreme_points_to_image( + points=self.points, + label=label, + sigma=self.sigma, + rescale_min=self.rescale_min, + rescale_max=self.rescale_max, + ) + d[key] = np.concatenate([img, points_image], axis=0) return d @@ -870,22 +895,23 @@ class TorchVisiond(MapTransform): data to be dict of PyTorch Tensors, users can easily call `ToTensord` transform to convert Numpy to Tensor. """ - def __init__(self, keys: KeysCollection, name: str, *args, **kwargs) -> None: + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` name: The transform name in TorchVision package. + allow_missing_keys: don't raise exception if key is missing. args: parameters for the TorchVision transform. kwargs: parameters for the TorchVision transform. 
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.trans = TorchVision(name, *args, **kwargs) def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.trans(d[key]) return d From a1931a6fe9866f1a3d2bbf494ff692bac3ba2362 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Mar 2021 15:03:09 +0000 Subject: [PATCH 3/6] current progress Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/apps/deepgrow/transforms.py | 9 ++- monai/transforms/croppad/dictionary.py | 35 ++++++++--- monai/transforms/intensity/dictionary.py | 77 ++++++++++++++++++++---- monai/transforms/post/dictionary.py | 4 +- monai/transforms/spatial/dictionary.py | 43 ++++++++----- monai/transforms/transform.py | 35 +++++------ monai/transforms/utility/dictionary.py | 26 +++++--- 7 files changed, 161 insertions(+), 68 deletions(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 9a7f0f7458..644507092d 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -907,7 +907,14 @@ class Fetch2DSliced(MapTransform): allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys, guidance="guidance", axis: int = 0, meta_key_postfix: str = "meta_dict", allow_missing_keys: bool = False): + def __init__( + self, + keys, + guidance="guidance", + axis: int = 0, + meta_key_postfix: str = "meta_dict", + allow_missing_keys: bool = False, + ): super().__init__(keys, allow_missing_keys) self.guidance = guidance self.axis = axis diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 30652d1be8..9cf0ca3205 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -134,6 +134,7 @@ def __init__( keys: KeysCollection, spatial_border: Union[Sequence[int], int], mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -155,15 +156,16 @@ def __init__( One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padder = BorderPad(spatial_border=spatial_border) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key, m in zip(self.keys, self.mode): + for key, m in self.key_iterator(d, self.mode): d[key] = self.padder(d[key], mode=m) return d @@ -175,7 +177,11 @@ class DivisiblePadd(MapTransform): """ def __init__( - self, keys: KeysCollection, k: Union[Sequence[int], int], mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT + self, + keys: KeysCollection, + k: Union[Sequence[int], int], + mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -189,17 +195,18 @@ def __init__( One of the listed string values or a user supplied function. Defaults to ``"constant"``. 
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. See also :py:class:`monai.transforms.SpatialPad` """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padder = DivisiblePad(k=k) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key, m in zip(self.keys, self.mode): + for key, m in self.key_iterator(d, self.mode): d[key] = self.padder(d[key], mode=m) return d @@ -252,7 +259,9 @@ class CenterSpatialCropd(MapTransform): allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, roi_size: Union[Sequence[int], int], allow_missing_keys: bool = False) -> None: + def __init__( + self, keys: KeysCollection, roi_size: Union[Sequence[int], int], allow_missing_keys: bool = False + ) -> None: super().__init__(keys, allow_missing_keys) self.cropper = CenterSpatialCrop(roi_size) @@ -338,6 +347,7 @@ class RandSpatialCropSamplesd(RandomizableTransform, MapTransform): random_center: crop at random position as center or the image center. random_size: crop with random size or specific size ROI. The actual size is sampled from `randint(roi_size, img_size)`. + allow_missing_keys: don't raise exception if key is missing. Raises: ValueError: When ``num_samples`` is nonpositive. @@ -351,13 +361,14 @@ def __init__( num_samples: int, random_center: bool = True, random_size: bool = True, + allow_missing_keys: bool = False, ) -> None: RandomizableTransform.__init__(self) - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) if num_samples < 1: raise ValueError(f"num_samples must be positive, got {num_samples}.") self.num_samples = num_samples - self.cropper = RandSpatialCropd(keys, roi_size, random_center, random_size) + self.cropper = RandSpatialCropd(keys, roi_size, random_center, random_size, allow_missing_keys) def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -656,7 +667,13 @@ class BoundingRectd(MapTransform): allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, bbox_key_postfix: str = "bbox", select_fn: Callable = lambda x: x > 0, allow_missing_keys: bool = False): + def __init__( + self, + keys: KeysCollection, + bbox_key_postfix: str = "bbox", + select_fn: Callable = lambda x: x > 0, + allow_missing_keys: bool = False, + ): super().__init__(keys, allow_missing_keys) self.bbox = BoundingRect(select_fn=select_fn) self.bbox_key_postfix = bbox_key_postfix diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index d2bca28678..4602d59379 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -103,12 +103,18 @@ class RandGaussianNoised(RandomizableTransform, MapTransform): prob: Probability to add Gaussian noise. mean: Mean or “centre” of the distribution. std: Standard deviation (spread) of distribution. + allow_missing_keys: don't raise exception if key is missing. 
""" def __init__( - self, keys: KeysCollection, prob: float = 0.1, mean: Union[Sequence[float], float] = 0.0, std: float = 0.1 + self, + keys: KeysCollection, + prob: float = 0.1, + mean: Union[Sequence[float], float] = 0.0, + std: float = 0.1, + allow_missing_keys: bool = False, ) -> None: - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.mean = ensure_tuple_rep(mean, len(self.keys)) self.std = std @@ -129,7 +135,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda raise AssertionError if not self._do_transform: return d - for noise, key in zip(self._noise, self.keys): + for key, noise in self.key_iterator(d, self._noise): dtype = dtype_torch_to_numpy(d[key].dtype) if isinstance(d[key], torch.Tensor) else d[key].dtype d[key] = d[key] + noise.astype(dtype) return d @@ -163,7 +169,13 @@ class RandShiftIntensityd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandShiftIntensity`. """ - def __init__(self, keys: KeysCollection, offsets: Union[Tuple[float, float], float], prob: float = 0.1, allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + offsets: Union[Tuple[float, float], float], + prob: float = 0.1, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. @@ -207,7 +219,12 @@ class ScaleIntensityd(MapTransform): """ def __init__( - self, keys: KeysCollection, minv: float = 0.0, maxv: float = 1.0, factor: Optional[float] = None, allow_missing_keys: bool = False + self, + keys: KeysCollection, + minv: float = 0.0, + maxv: float = 1.0, + factor: Optional[float] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -234,7 +251,13 @@ class RandScaleIntensityd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandScaleIntensity`. """ - def __init__(self, keys: KeysCollection, factors: Union[Tuple[float, float], float], prob: float = 0.1, allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + factors: Union[Tuple[float, float], float], + prob: float = 0.1, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. @@ -299,7 +322,7 @@ def __init__( dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -322,8 +345,15 @@ class ThresholdIntensityd(MapTransform): allow_missing_keys: don't raise exception if key is missing. 
""" - def __init__(self, keys: KeysCollection, threshold: float, above: bool = True, cval: float = 0.0, allow_missing_keys: bool = False) -> None: - super().__init__(keys) + def __init__( + self, + keys: KeysCollection, + threshold: float, + above: bool = True, + cval: float = 0.0, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) self.filter = ThresholdIntensity(threshold, above, cval) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -349,7 +379,14 @@ class ScaleIntensityRanged(MapTransform): """ def __init__( - self, keys: KeysCollection, a_min: float, a_max: float, b_min: float, b_max: float, clip: bool = False, allow_missing_keys: bool = False + self, + keys: KeysCollection, + a_min: float, + a_max: float, + b_min: float, + b_max: float, + clip: bool = False, + allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip) @@ -403,7 +440,11 @@ class RandAdjustContrastd(RandomizableTransform, MapTransform): """ def __init__( - self, keys: KeysCollection, prob: float = 0.1, gamma: Union[Tuple[float, float], float] = (0.5, 4.5), allow_missing_keys: bool = False + self, + keys: KeysCollection, + prob: float = 0.1, + gamma: Union[Tuple[float, float], float] = (0.5, 4.5), + allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) @@ -527,7 +568,13 @@ class GaussianSmoothd(MapTransform): """ - def __init__(self, keys: KeysCollection, sigma: Union[Sequence[float], float], approx: str = "erf", allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + sigma: Union[Sequence[float], float], + approx: str = "erf", + allow_missing_keys: bool = False, + ) -> None: super().__init__(keys, allow_missing_keys) self.converter = GaussianSmooth(sigma, approx=approx) @@ -718,7 +765,11 @@ class RandHistogramShiftd(RandomizableTransform, MapTransform): """ def __init__( - self, keys: KeysCollection, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1, allow_missing_keys: bool = False + self, + keys: KeysCollection, + num_control_points: Union[Tuple[int, int], int] = 10, + prob: float = 0.1, + allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index b1b3dfbf00..42796e2412 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -142,7 +142,9 @@ def __init__( def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh in self.key_iterator(d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh): + for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh in self.key_iterator( + d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh + ): d[key] = self.converter( d[key], argmax, diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 6c33ae00a4..a81aeb432b 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -181,7 +181,9 @@ def __call__( self, data: Mapping[Union[Hashable, str], Dict[str, np.ndarray]] ) -> 
Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: d: Dict = dict(data) - for key, mode, padding_mode, align_corners, dtype in self.key_iterator(d, self.mode, self.padding_mode, self.align_corners, self.dtype): + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners, self.dtype + ): meta_data = d[f"{key}_{self.meta_key_postfix}"] # resample array of each corresponding key # using affine fetched from d[affine_key] @@ -264,7 +266,9 @@ class Rotate90d(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. """ - def __init__(self, keys: KeysCollection, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1), allow_missing_keys: bool = False) -> None: + def __init__( + self, keys: KeysCollection, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1), allow_missing_keys: bool = False + ) -> None: """ Args: k: number of times to rotate by 90 degrees. @@ -590,9 +594,7 @@ def __call__( grid = create_grid(spatial_size=sp_size) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - d[key] = self.rand_2d_elastic.resampler( - d[key], grid, mode=mode, padding_mode=padding_mode - ) + d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) return d @@ -661,7 +663,7 @@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. """ - MapTransform.__init__(self, keys, allow_missing_keys=) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.rand_3d_elastic = Rand3DElastic( sigma_range=sigma_range, @@ -706,9 +708,7 @@ def __call__( grid = self.rand_3d_elastic.rand_affine_grid(grid=grid) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - d[key] = self.rand_3d_elastic.resampler( - d[key], grid, mode=mode, padding_mode=padding_mode - ) + d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) return d @@ -725,7 +725,12 @@ class Flipd(MapTransform): allow_missing_keys: don't raise exception if key is missing. 
""" - def __init__(self, keys: KeysCollection, spatial_axis: Optional[Union[Sequence[int], int]] = None, allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + spatial_axis: Optional[Union[Sequence[int], int]] = None, + allow_missing_keys: bool = False, + ) -> None: super().__init__(keys, allow_missing_keys) self.flipper = Flip(spatial_axis=spatial_axis) @@ -757,7 +762,7 @@ def __init__( spatial_axis: Optional[Union[Sequence[int], int]] = None, allow_missing_keys: bool = False, ) -> None: - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.spatial_axis = spatial_axis @@ -855,7 +860,9 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key, mode, padding_mode, align_corners, dtype in self.key_iterator(d, self.mode, self.padding_mode, self.align_corners, self.dtype): + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners, self.dtype + ): d[key] = self.rotator( d[key], mode=mode, @@ -952,7 +959,9 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda angle=self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z), keep_size=self.keep_size, ) - for key, mode, padding_mode, align_corners, dtype in self.key_iterator(d, self.mode, self.padding_mode, self.align_corners, self.dtype): + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners, self.dtype + ): d[key] = rotator( d[key], mode=mode, @@ -1006,7 +1015,9 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key, mode, padding_mode, align_corners in self.key_iterator(d, self.mode, self.padding_mode, self.align_corners): + for key, mode, padding_mode, align_corners in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners + ): d[key] = self.zoomer( d[key], mode=mode, @@ -1094,7 +1105,9 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim self._zoom = ensure_tuple_rep(self._zoom[0], img_dims - 2) + ensure_tuple(self._zoom[-1]) zoomer = Zoom(self._zoom, keep_size=self.keep_size) - for key, mode, padding_mode, align_corners in self.key_iterator(d, self.mode, self.padding_mode, self.align_corners): + for key, mode, padding_mode, align_corners in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners + ): d[key] = zoomer( d[key], mode=mode, diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index ff80b92b40..a1f2f5a360 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -13,7 +13,7 @@ """ from abc import ABC, abstractmethod -from typing import Any, Dict, Hashable, Iterable, Optional, Tuple +from typing import Any, Dict, Generator, Hashable, Iterable, Optional, Tuple import numpy as np @@ -228,9 +228,9 @@ def __call__(self, data): def key_iterator( self, - data: Dict[str, Any], - extra_iterables: Optional[Iterable] = None, - ): + data: Dict[Hashable, Any], + *extra_iterables: Optional[Iterable], + ) -> Generator: """ Iterate across keys and optionally extra iterables. If key is missing, exception is raised if `allow_missing_keys==False` (default). 
If `allow_missing_keys==True`, key is skipped. @@ -239,19 +239,14 @@ def key_iterator( data: data that the transform will be applied to extra_iterables: anything else to be iterated through """ - if extra_iterables is None: - for key in self.keys: - if key not in data.keys(): - if self.allow_missing_keys: - continue - else: - raise KeyError("Key was missing and allow_missing_keys==False") - yield key - else: - for key, *extra_iterables in zip(self.keys, extra_iterables): - if key not in data.keys(): - if self.allow_missing_keys: - continue - else: - raise KeyError("Key was missing and allow_missing_keys==False") - yield [key] + extra_iterables + # if no extra iterables given, create a dummy list of Nones + ex_iters = extra_iterables if extra_iterables else [None] * len(self.keys) + + # loop over keys and any extra iterables + for key, *_ex_iters in zip(self.keys, ex_iters): + # all normal, yield (what we yield depends on whether extra iterables were given) + if str(key) in data.keys(): + yield (key,) + tuple(_ex_iters) if extra_iterables else key + # if missing keys not allowed, raise + elif not self.allow_missing_keys: + raise KeyError(f"Key was missing ({key}) and allow_missing_keys==False") diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 665ad45297..1d3be3d29f 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -516,7 +516,9 @@ def __init__( def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: d = dict(data) - for key, prefix, data_shape, value_range, data_value, additional_info in self.key_iterator(d, self.prefix, self.data_shape, self.value_range, self.data_value, self.additional_info): + for key, prefix, data_shape, value_range, data_value, additional_info in self.key_iterator( + d, self.prefix, self.data_shape, self.value_range, self.data_value, self.additional_info + ): d[key] = self.printer( d[key], prefix, @@ -533,7 +535,9 @@ class SimulateDelayd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.SimulateDelay`. """ - def __init__(self, keys: KeysCollection, delay_time: Union[Sequence[float], float] = 0.0, allow_missing_keys: bool = False) -> None: + def __init__( + self, keys: KeysCollection, delay_time: Union[Sequence[float], float] = 0.0, allow_missing_keys: bool = False + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. @@ -561,7 +565,9 @@ class CopyItemsd(MapTransform): """ - def __init__(self, keys: KeysCollection, times: int, names: KeysCollection) -> None: + def __init__( + self, keys: KeysCollection, times: int, names: KeysCollection, allow_missing_keys: bool = False + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. @@ -571,13 +577,14 @@ def __init__(self, keys: KeysCollection, times: int, names: KeysCollection) -> N names: the names corresponding to the newly copied data, the length should match `len(keys) x times`. for example, if keys is ["img", "seg"] and times is 2, names can be: ["img_1", "seg_1", "img_2", "seg_2"]. + allow_missing_keys: don't raise exception if key is missing. Raises: ValueError: When ``times`` is nonpositive. ValueError: When ``len(names)`` is not ``len(keys) * times``. Incompatible values. 
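The rewritten `key_iterator` above is the heart of this series. A self-contained stand-in (simplified; the real method lives on `MapTransform`, and the helper name here is invented) exercises both call styles and the missing-key behaviour:

```python
def key_iterator_sketch(keys, data, *extra_iterables, allow_missing_keys=False):
    """Simplified stand-in for MapTransform.key_iterator."""
    # with no extra iterables, pad with a dummy list of Nones
    ex_iters = extra_iterables if extra_iterables else [[None] * len(keys)]
    for key, *extras in zip(keys, *ex_iters):
        if key in data:
            # bare key without extras, (key, extra0, ...) with them
            yield (key,) + tuple(extras) if extra_iterables else key
        elif not allow_missing_keys:
            raise KeyError(f"Key was missing ({key}) and allow_missing_keys==False")

data = {"image": 0}
assert list(key_iterator_sketch(("image",), data)) == ["image"]
assert list(key_iterator_sketch(("image",), data, ("bilinear",))) == [("image", "bilinear")]
assert list(key_iterator_sketch(("label",), data, allow_missing_keys=True)) == []
```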
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) if times < 1: raise ValueError(f"times must be positive, got {times}.") self.times = times @@ -596,13 +603,14 @@ def __call__(self, data): """ d = dict(data) - for key, new_key in zip(self.keys * self.times, self.names): + for new_key in self.names: if new_key in d: raise KeyError(f"Key {new_key} already exists in data.") - if isinstance(d[key], torch.Tensor): - d[new_key] = d[key].detach().clone() - else: - d[new_key] = copy.deepcopy(d[key]) + for key in self.key_iterator(d): + if isinstance(d[key], torch.Tensor): + d[new_key] = d[key].detach().clone() + else: + d[new_key] = copy.deepcopy(d[key]) return d From 502c9e0f20fb8c6459baaa22ef44bdf1bc185c71 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Mar 2021 15:36:26 +0000 Subject: [PATCH 4/6] current progress Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index a1f2f5a360..0eeedd27c9 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -243,7 +243,7 @@ def key_iterator( ex_iters = extra_iterables if extra_iterables else [None] * len(self.keys) # loop over keys and any extra iterables - for key, *_ex_iters in zip(self.keys, ex_iters): + for key, *_ex_iters in zip(self.keys, *ex_iters): # all normal, yield (what we yield depends on whether extra iterables were given) if str(key) in data.keys(): yield (key,) + tuple(_ex_iters) if extra_iterables else key From 3ce990ba108ddad9dff8c67e89eb19653dffa84e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Mar 2021 15:41:31 +0000 Subject: [PATCH 5/6] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 0eeedd27c9..0af3343da9 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -13,7 +13,7 @@ """ from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, Hashable, Iterable, Optional, Tuple +from typing import Any, Dict, Generator, Hashable, Iterable, List, Optional, Tuple import numpy as np @@ -243,6 +243,7 @@ def key_iterator( ex_iters = extra_iterables if extra_iterables else [None] * len(self.keys) # loop over keys and any extra iterables + _ex_iters: List[Any] for key, *_ex_iters in zip(self.keys, *ex_iters): # all normal, yield (what we yield depends on whether extra iterables were given) if str(key) in data.keys(): From 69b46ecf419ee448732e780f34789457c22b4cef Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Mar 2021 17:57:35 +0000 Subject: [PATCH 6/6] unit tests working locally Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 32 +++++++++++++------------- monai/transforms/transform.py | 6 ++--- monai/transforms/utility/dictionary.py | 2 +- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 9cf0ca3205..823b2dd3f4 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -489,22 +489,22 @@ def 
__call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n _spatial_size = fall_back_tuple(self.spatial_size, d[self.w_key].shape[1:]) results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)] - for key in data.keys(): - for key in self.key_iterator(d): - img = d[key] - if img.shape[1:] != d[self.w_key].shape[1:]: - raise ValueError( - f"data {key} and weight map {self.w_key} spatial shape mismatch: " - f"{img.shape[1:]} vs {d[self.w_key].shape[1:]}." - ) - for i, center in enumerate(self.centers): - cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) - results[i][key] = cropper(img) - if self.center_coord_key: - results[i][self.center_coord_key] = center - else: - for i in range(self.num_samples): - results[i][key] = data[key] + for key in self.key_iterator(d): + img = d[key] + if img.shape[1:] != d[self.w_key].shape[1:]: + raise ValueError( + f"data {key} and weight map {self.w_key} spatial shape mismatch: " + f"{img.shape[1:]} vs {d[self.w_key].shape[1:]}." + ) + for i, center in enumerate(self.centers): + cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) + results[i][key] = cropper(img) + if self.center_coord_key: + results[i][self.center_coord_key] = center + # fill in the extra keys with unmodified data + for key in set(data.keys()).difference(set(self.keys)): + for i in range(self.num_samples): + results[i][key] = data[key] return results diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 0af3343da9..7a09efa6d5 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -187,7 +187,7 @@ def __call__(self, data): """ - def __init__(self, keys: KeysCollection, allow_missing_keys: bool) -> None: + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: self.keys: Tuple[Hashable, ...] = ensure_tuple(keys) self.allow_missing_keys = allow_missing_keys if not self.keys: @@ -240,13 +240,13 @@ def key_iterator( extra_iterables: anything else to be iterated through """ # if no extra iterables given, create a dummy list of Nones - ex_iters = extra_iterables if extra_iterables else [None] * len(self.keys) + ex_iters = extra_iterables if extra_iterables else [[None] * len(self.keys)] # loop over keys and any extra iterables _ex_iters: List[Any] for key, *_ex_iters in zip(self.keys, *ex_iters): # all normal, yield (what we yield depends on whether extra iterables were given) - if str(key) in data.keys(): + if key in data.keys(): yield (key,) + tuple(_ex_iters) if extra_iterables else key # if missing keys not allowed, raise elif not self.allow_missing_keys: diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 1d3be3d29f..fbe594e641 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -441,7 +441,7 @@ class SelectItemsd(MapTransform): """ def __call__(self, data): - result = {key: val for key, val in data.items() if key in self.key_iterator(d)} + result = {key: data[key] for key in self.key_iterator(data)} return result
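The one-character fix in PATCH 4 above (`zip(self.keys, ex_iters)` to `zip(self.keys, *ex_iters)`) is easy to gloss over. A worked example of the difference between zipping the list of iterables and unpacking it:

```python
keys = ("image", "label")
extras = [("bilinear", "nearest")]  # one extra iterable: per-key modes

# zip(keys, extras): pairs the FIRST key with the WHOLE modes tuple -> wrong
assert list(zip(keys, extras)) == [("image", ("bilinear", "nearest"))]

# zip(keys, *extras): pairs each key with its own mode -> intended behaviour
assert list(zip(keys, *extras)) == [("image", "bilinear"), ("label", "nearest")]
```

The same pairing question is worth checking in the `CopyItemsd` rewrite from PATCH 3: as written, its nested loops bind every source key to every new name in turn, whereas a zip-style pairing (one source key per new name, repeated `times` times) is presumably the intended correspondence.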