diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index cc01a717ad..644507092d 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -437,6 +437,7 @@ class SpatialCropForegroundd(MapTransform): end_coord_key: key to record the end coordinate of spatial bounding box for foreground. original_shape_key: key to record original shape for foreground. cropped_shape_key: key to record cropped shape for foreground. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( @@ -452,8 +453,9 @@ def __init__( end_coord_key: str = "foreground_end_coord", original_shape_key: str = "foreground_original_shape", cropped_shape_key: str = "foreground_cropped_shape", + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.source_key = source_key self.spatial_size = list(spatial_size) @@ -482,7 +484,7 @@ def __call__(self, data): else: cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) - for key in self.keys: + for key in self.key_iterator(d): meta_key = f"{key}_{self.meta_key_postfix}" d[meta_key][self.start_coord_key] = box_start d[meta_key][self.end_coord_key] = box_end @@ -629,6 +631,7 @@ class SpatialCropGuidanced(MapTransform): end_coord_key: key to record the end coordinate of spatial bounding box for foreground. original_shape_key: key to record original shape for foreground. cropped_shape_key: key to record cropped shape for foreground. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( @@ -642,8 +645,9 @@ def __init__( end_coord_key: str = "foreground_end_coord", original_shape_key: str = "foreground_original_shape", cropped_shape_key: str = "foreground_cropped_shape", + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.guidance = guidance self.spatial_size = list(spatial_size) @@ -697,7 +701,7 @@ def __call__(self, data): cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) box_start, box_end = cropper.roi_start, cropper.roi_end - for key in self.keys: + for key in self.key_iterator(d): if not np.array_equal(d[key].shape[1:], original_spatial_shape): raise RuntimeError("All the images specified in keys should have the same spatial shape") meta_key = f"{key}_{self.meta_key_postfix}" @@ -804,6 +808,7 @@ class RestoreLabeld(MapTransform): end_coord_key: key that records the end coordinate of spatial bounding box for foreground. original_shape_key: key that records original shape for foreground. cropped_shape_key: key that records cropped shape for foreground. + allow_missing_keys: don't raise exception if key is missing.
""" def __init__( @@ -818,8 +823,9 @@ def __init__( end_coord_key: str = "foreground_end_coord", original_shape_key: str = "foreground_original_shape", cropped_shape_key: str = "foreground_cropped_shape", + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.ref_image = ref_image self.slice_only = slice_only self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -834,15 +840,15 @@ def __call__(self, data): d = dict(data) meta_dict: Dict = d[f"{self.ref_image}_{self.meta_key_postfix}"] - for idx, key in enumerate(self.keys): + for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): image = d[key] # Undo Resize current_shape = image.shape cropped_shape = meta_dict[self.cropped_shape_key] if np.any(np.not_equal(current_shape, cropped_shape)): - resizer = Resize(spatial_size=cropped_shape[1:], mode=self.mode[idx]) - image = resizer(image, mode=self.mode[idx], align_corners=self.align_corners[idx]) + resizer = Resize(spatial_size=cropped_shape[1:], mode=mode) + image = resizer(image, mode=mode, align_corners=align_corners) # Undo Crop original_shape = meta_dict[self.original_shape_key] @@ -862,8 +868,8 @@ def __call__(self, data): spatial_size = spatial_shape[-len(current_size) :] if np.any(np.not_equal(current_size, spatial_size)): - resizer = Resize(spatial_size=spatial_size, mode=self.mode[idx]) - result = resizer(result, mode=self.mode[idx], align_corners=self.align_corners[idx]) + resizer = Resize(spatial_size=spatial_size, mode=mode) + result = resizer(result, mode=mode, align_corners=align_corners) # Undo Slicing slice_idx = meta_dict.get("slice_idx") @@ -898,10 +904,18 @@ class Fetch2DSliced(MapTransform): default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys, guidance="guidance", axis: int = 0, meta_key_postfix: str = "meta_dict"): - super().__init__(keys) + def __init__( + self, + keys, + guidance="guidance", + axis: int = 0, + meta_key_postfix: str = "meta_dict", + allow_missing_keys: bool = False, + ): + super().__init__(keys, allow_missing_keys) self.guidance = guidance self.axis = axis self.meta_key_postfix = meta_key_postfix @@ -920,7 +934,7 @@ def __call__(self, data): guidance = d[self.guidance] if len(guidance) < 3: raise RuntimeError("Guidance does not container slice_idx!") - for key in self.keys: + for key in self.key_iterator(d): img_slice, idx = self._apply(d[key], guidance) d[key] = img_slice d[f"{key}_{self.meta_key_postfix}"]["slice_idx"] = idx diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 9739c6322f..823b2dd3f4 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -94,6 +94,7 @@ def __init__( spatial_size: Union[Sequence[int], int], method: Union[Method, str] = Method.SYMMETRIC, mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -108,15 +109,16 @@ def __init__( One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. 
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padder = SpatialPad(spatial_size, method) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key, m in zip(self.keys, self.mode): + for key, m in self.key_iterator(d, self.mode): d[key] = self.padder(d[key], mode=m) return d @@ -132,6 +134,7 @@ def __init__( keys: KeysCollection, spatial_border: Union[Sequence[int], int], mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -153,15 +156,16 @@ def __init__( One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padder = BorderPad(spatial_border=spatial_border) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key, m in zip(self.keys, self.mode): + for key, m in self.key_iterator(d, self.mode): d[key] = self.padder(d[key], mode=m) return d @@ -173,7 +177,11 @@ class DivisiblePadd(MapTransform): """ def __init__( - self, keys: KeysCollection, k: Union[Sequence[int], int], mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT + self, + keys: KeysCollection, + k: Union[Sequence[int], int], + mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -187,17 +195,18 @@ def __init__( One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. See also :py:class:`monai.transforms.SpatialPad` """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padder = DivisiblePad(k=k) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key, m in zip(self.keys, self.mode): + for key, m in self.key_iterator(d, self.mode): d[key] = self.padder(d[key], mode=m) return d @@ -216,6 +225,7 @@ def __init__( roi_size: Optional[Sequence[int]] = None, roi_start: Optional[Sequence[int]] = None, roi_end: Optional[Sequence[int]] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -225,13 +235,14 @@ def __init__( roi_size: size of the crop ROI. roi_start: voxel coordinates for start of the crop ROI. roi_end: voxel coordinates for end of the crop ROI. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.cropper(d[key]) return d @@ -245,15 +256,18 @@ class CenterSpatialCropd(MapTransform): See also: monai.transforms.MapTransform roi_size: the size of the crop region e.g. 
[224,224,128] If its components have non-positive values, the corresponding size of input image will be used. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, roi_size: Union[Sequence[int], int]) -> None: - super().__init__(keys) + def __init__( + self, keys: KeysCollection, roi_size: Union[Sequence[int], int], allow_missing_keys: bool = False + ) -> None: + super().__init__(keys, allow_missing_keys) self.cropper = CenterSpatialCrop(roi_size) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.cropper(d[key]) return d @@ -274,6 +288,7 @@ class RandSpatialCropd(RandomizableTransform, MapTransform): random_center: crop at random position as center or the image center. random_size: crop with random size or specific size ROI. The actual size is sampled from `randint(roi_size, img_size)`. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( @@ -282,9 +297,10 @@ def __init__( roi_size: Union[Sequence[int], int], random_center: bool = True, random_size: bool = True, + allow_missing_keys: bool = False, ) -> None: RandomizableTransform.__init__(self) - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) self.roi_size = roi_size self.random_center = random_center self.random_size = random_size @@ -304,7 +320,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize(d[self.keys[0]].shape[1:]) # image shape from the first data key if self._size is None: raise AssertionError - for key in self.keys: + for key in self.key_iterator(d): if self.random_center: d[key] = d[key][self._slices] else: @@ -331,6 +347,7 @@ class RandSpatialCropSamplesd(RandomizableTransform, MapTransform): random_center: crop at random position as center or the image center. random_size: crop with random size or specific size ROI. The actual size is sampled from `randint(roi_size, img_size)`. + allow_missing_keys: don't raise exception if key is missing. Raises: ValueError: When ``num_samples`` is nonpositive. @@ -344,13 +361,14 @@ def __init__( num_samples: int, random_center: bool = True, random_size: bool = True, + allow_missing_keys: bool = False, ) -> None: RandomizableTransform.__init__(self) - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) if num_samples < 1: raise ValueError(f"num_samples must be positive, got {num_samples}.") self.num_samples = num_samples - self.cropper = RandSpatialCropd(keys, roi_size, random_center, random_size) + self.cropper = RandSpatialCropd(keys, roi_size, random_center, random_size, allow_missing_keys) def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -388,6 +406,7 @@ def __init__( margin: int = 0, start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -400,8 +419,9 @@ def __init__( margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. start_coord_key: key to record the start coordinate of spatial bounding box for foreground. end_coord_key: key to record the end coordinate of spatial bounding box for foreground. + allow_missing_keys: don't raise exception if key is missing. 
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.source_key = source_key self.select_fn = select_fn self.channel_indices = ensure_tuple(channel_indices) if channel_indices is not None else None @@ -417,7 +437,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[self.start_coord_key] = np.asarray(box_start) d[self.end_coord_key] = np.asarray(box_end) cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) - for key in self.keys: + for key in self.key_iterator(d): d[key] = cropper(d[key]) return d @@ -435,6 +455,7 @@ class RandWeightedCropd(RandomizableTransform, MapTransform): If its components have non-positive values, the corresponding size of `img` will be used. num_samples: number of samples (image patches) to take in the returned list. center_coord_key: if specified, the actual sampling location will be stored with the corresponding key. + allow_missing_keys: don't raise exception if key is missing. See Also: :py:class:`monai.transforms.RandWeightedCrop` @@ -447,9 +468,10 @@ def __init__( spatial_size: Union[Sequence[int], int], num_samples: int = 1, center_coord_key: Optional[str] = None, + allow_missing_keys: bool = False, ): RandomizableTransform.__init__(self) - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) self.spatial_size = ensure_tuple(spatial_size) self.w_key = w_key self.num_samples = int(num_samples) @@ -467,22 +489,22 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n _spatial_size = fall_back_tuple(self.spatial_size, d[self.w_key].shape[1:]) results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)] - for key in data.keys(): - if key in self.keys: - img = d[key] - if img.shape[1:] != d[self.w_key].shape[1:]: - raise ValueError( - f"data {key} and weight map {self.w_key} spatial shape mismatch: " - f"{img.shape[1:]} vs {d[self.w_key].shape[1:]}." - ) - for i, center in enumerate(self.centers): - cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) - results[i][key] = cropper(img) - if self.center_coord_key: - results[i][self.center_coord_key] = center - else: - for i in range(self.num_samples): - results[i][key] = data[key] + for key in self.key_iterator(d): + img = d[key] + if img.shape[1:] != d[self.w_key].shape[1:]: + raise ValueError( + f"data {key} and weight map {self.w_key} spatial shape mismatch: " + f"{img.shape[1:]} vs {d[self.w_key].shape[1:]}." + ) + for i, center in enumerate(self.centers): + cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) + results[i][key] = cropper(img) + if self.center_coord_key: + results[i][self.center_coord_key] = center + # fill in the extra keys with unmodified data + for key in set(data.keys()).difference(set(self.keys)): + for i in range(self.num_samples): + results[i][key] = data[key] return results @@ -517,6 +539,7 @@ class RandCropByPosNegLabeld(RandomizableTransform, MapTransform): `image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices_key` and `bg_indices_key` together, expect to be 1 dim array of spatial indices after flattening. a typical usage is to call `FgBgToIndicesd` transform first and cache the results. + allow_missing_keys: don't raise exception if key is missing. Raises: ValueError: When ``pos`` or ``neg`` are negative. 
@@ -536,9 +559,10 @@ def __init__( image_threshold: float = 0.0, fg_indices_key: Optional[str] = None, bg_indices_key: Optional[str] = None, + allow_missing_keys: bool = False, ) -> None: RandomizableTransform.__init__(self) - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) self.label_key = label_key self.spatial_size: Union[Tuple[int, ...], Sequence[int], int] = spatial_size if pos < 0 or neg < 0: @@ -583,15 +607,15 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n if self.centers is None: raise AssertionError results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)] - for key in data.keys(): - if key in self.keys: + + for i, center in enumerate(self.centers): + for key in self.key_iterator(d): img = d[key] - for i, center in enumerate(self.centers): - cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore - results[i][key] = cropper(img) - else: - for i in range(self.num_samples): - results[i][key] = data[key] + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + results[i][key] = cropper(img) + # fill in the extra keys with unmodified data + for key in set(data.keys()).difference(set(self.keys)): + results[i][key] = data[key] return results @@ -609,6 +633,7 @@ class ResizeWithPadOrCropd(MapTransform): ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} One of the listed string values or a user supplied function for padding. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + allow_missing_keys: don't raise exception if key is missing. """ @@ -617,13 +642,14 @@ def __init__( keys: KeysCollection, spatial_size: Union[Sequence[int], int], mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, mode=mode) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.padcropper(d[key]) return d @@ -638,10 +664,17 @@ class BoundingRectd(MapTransform): bbox_key_postfix: the output bounding box coordinates will be written to the value of `{key}_{bbox_key_postfix}`. select_fn: function to select expected foreground, default is to select values > 0. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, bbox_key_postfix: str = "bbox", select_fn: Callable = lambda x: x > 0): - super().__init__(keys=keys) + def __init__( + self, + keys: KeysCollection, + bbox_key_postfix: str = "bbox", + select_fn: Callable = lambda x: x > 0, + allow_missing_keys: bool = False, + ): + super().__init__(keys, allow_missing_keys) self.bbox = BoundingRect(select_fn=select_fn) self.bbox_key_postfix = bbox_key_postfix @@ -650,7 +683,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. 
""" d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): bbox = self.bbox(d[key]) key_to_add = f"{key}_{self.bbox_key_postfix}" if key_to_add in d: diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 7d0d66d2ba..4602d59379 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -103,12 +103,18 @@ class RandGaussianNoised(RandomizableTransform, MapTransform): prob: Probability to add Gaussian noise. mean: Mean or “centre” of the distribution. std: Standard deviation (spread) of distribution. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( - self, keys: KeysCollection, prob: float = 0.1, mean: Union[Sequence[float], float] = 0.0, std: float = 0.1 + self, + keys: KeysCollection, + prob: float = 0.1, + mean: Union[Sequence[float], float] = 0.0, + std: float = 0.1, + allow_missing_keys: bool = False, ) -> None: - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.mean = ensure_tuple_rep(mean, len(self.keys)) self.std = std @@ -129,7 +135,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda raise AssertionError if not self._do_transform: return d - for noise, key in zip(self._noise, self.keys): + for key, noise in self.key_iterator(d, self._noise): dtype = dtype_torch_to_numpy(d[key].dtype) if isinstance(d[key], torch.Tensor) else d[key].dtype d[key] = d[key] + noise.astype(dtype) return d @@ -140,19 +146,20 @@ class ShiftIntensityd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ShiftIntensity`. """ - def __init__(self, keys: KeysCollection, offset: float) -> None: + def __init__(self, keys: KeysCollection, offset: float, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` offset: offset value to shift the intensity of image. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.shifter = ShiftIntensity(offset) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.shifter(d[key]) return d @@ -162,7 +169,13 @@ class RandShiftIntensityd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandShiftIntensity`. """ - def __init__(self, keys: KeysCollection, offsets: Union[Tuple[float, float], float], prob: float = 0.1) -> None: + def __init__( + self, + keys: KeysCollection, + offsets: Union[Tuple[float, float], float], + prob: float = 0.1, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. @@ -171,8 +184,9 @@ def __init__(self, keys: KeysCollection, offsets: Union[Tuple[float, float], flo if single number, offset value is picked from (-offsets, offsets). prob: probability of rotating. (Default 0.1, with 10% probability it returns a rotated array.) + allow_missing_keys: don't raise exception if key is missing. 
""" - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) if isinstance(offsets, (int, float)): @@ -192,7 +206,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda if not self._do_transform: return d shifter = ShiftIntensity(self._offset) - for key in self.keys: + for key in self.key_iterator(d): d[key] = shifter(d[key]) return d @@ -205,7 +219,12 @@ class ScaleIntensityd(MapTransform): """ def __init__( - self, keys: KeysCollection, minv: float = 0.0, maxv: float = 1.0, factor: Optional[float] = None + self, + keys: KeysCollection, + minv: float = 0.0, + maxv: float = 1.0, + factor: Optional[float] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -214,14 +233,15 @@ def __init__( minv: minimum value of output data. maxv: maximum value of output data. factor: factor scale by ``v = v * (1 + factor)``. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensity(minv, maxv, factor) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.scaler(d[key]) return d @@ -231,7 +251,13 @@ class RandScaleIntensityd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandScaleIntensity`. """ - def __init__(self, keys: KeysCollection, factors: Union[Tuple[float, float], float], prob: float = 0.1) -> None: + def __init__( + self, + keys: KeysCollection, + factors: Union[Tuple[float, float], float], + prob: float = 0.1, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. @@ -240,9 +266,10 @@ def __init__(self, keys: KeysCollection, factors: Union[Tuple[float, float], flo if single number, factor value is picked from (-factors, factors). prob: probability of rotating. (Default 0.1, with 10% probability it returns a rotated array.) + allow_missing_keys: don't raise exception if key is missing. """ - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) if isinstance(factors, (int, float)): @@ -262,7 +289,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda if not self._do_transform: return d scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor) - for key in self.keys: + for key in self.key_iterator(d): d[key] = scaler(d[key]) return d @@ -282,6 +309,7 @@ class NormalizeIntensityd(MapTransform): channel_wise: if using calculated mean and std, calculate on each channel separately or calculate on the entire image directly. dtype: output data type, defaut to float32. + allow_missing_keys: don't raise exception if key is missing. 
""" def __init__( @@ -292,13 +320,14 @@ def __init__( nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.normalizer(d[key]) return d @@ -313,15 +342,23 @@ class ThresholdIntensityd(MapTransform): threshold: the threshold to filter intensity values. above: filter values above the threshold or below the threshold, default is True. cval: value to fill the remaining parts of the image, default is 0. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, threshold: float, above: bool = True, cval: float = 0.0) -> None: - super().__init__(keys) + def __init__( + self, + keys: KeysCollection, + threshold: float, + above: bool = True, + cval: float = 0.0, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) self.filter = ThresholdIntensity(threshold, above, cval) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.filter(d[key]) return d @@ -338,17 +375,25 @@ class ScaleIntensityRanged(MapTransform): b_min: intensity target range min. b_max: intensity target range max. clip: whether to perform clip after scaling. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( - self, keys: KeysCollection, a_min: float, a_max: float, b_min: float, b_max: float, clip: bool = False + self, + keys: KeysCollection, + a_min: float, + a_max: float, + b_min: float, + b_max: float, + clip: bool = False, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.scaler(d[key]) return d @@ -364,15 +409,16 @@ class AdjustContrastd(MapTransform): keys: keys of the corresponding items to be transformed. See also: monai.transforms.MapTransform gamma: gamma value to adjust the contrast as function. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, gamma: float) -> None: - super().__init__(keys) + def __init__(self, keys: KeysCollection, gamma: float, allow_missing_keys: bool = False) -> None: + super().__init__(keys, allow_missing_keys) self.adjuster = AdjustContrast(gamma) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.adjuster(d[key]) return d @@ -390,12 +436,17 @@ class RandAdjustContrastd(RandomizableTransform, MapTransform): prob: Probability of adjustment. gamma: Range of gamma values. If single number, value is picked from (0.5, gamma), default is (0.5, 4.5). + allow_missing_keys: don't raise exception if key is missing. 
""" def __init__( - self, keys: KeysCollection, prob: float = 0.1, gamma: Union[Tuple[float, float], float] = (0.5, 4.5) + self, + keys: KeysCollection, + prob: float = 0.1, + gamma: Union[Tuple[float, float], float] = (0.5, 4.5), + allow_missing_keys: bool = False, ) -> None: - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) if isinstance(gamma, (int, float)): @@ -423,7 +474,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda if not self._do_transform: return d adjuster = AdjustContrast(self.gamma_value) - for key in self.keys: + for key in self.key_iterator(d): d[key] = adjuster(d[key]) return d @@ -441,6 +492,7 @@ class ScaleIntensityRangePercentilesd(MapTransform): b_max: intensity target range max. clip: whether to perform clip after scaling. relative: whether to scale to the corresponding percentiles of [b_min, b_max] + allow_missing_keys: don't raise exception if key is missing. """ def __init__( @@ -452,13 +504,14 @@ def __init__( b_max: float, clip: bool = False, relative: bool = False, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.scaler(d[key]) return d @@ -477,6 +530,7 @@ class MaskIntensityd(MapTransform): if None, will extract the mask data from input data based on `mask_key`. mask_key: the key to extract mask data from input dictionary, only works when `mask_data` is None. + allow_missing_keys: don't raise exception if key is missing. """ @@ -485,14 +539,15 @@ def __init__( keys: KeysCollection, mask_data: Optional[np.ndarray] = None, mask_key: Optional[str] = None, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = MaskIntensity(mask_data) self.mask_key = mask_key if mask_data is None else None def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key], d[self.mask_key]) if self.mask_key is not None else self.converter(d[key]) return d @@ -509,16 +564,23 @@ class GaussianSmoothd(MapTransform): use it for all spatial dimensions. approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace". see also :py:meth:`monai.networks.layers.GaussianFilter`. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, sigma: Union[Sequence[float], float], approx: str = "erf") -> None: - super().__init__(keys) + def __init__( + self, + keys: KeysCollection, + sigma: Union[Sequence[float], float], + approx: str = "erf", + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) self.converter = GaussianSmooth(sigma, approx=approx) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -536,6 +598,7 @@ class RandGaussianSmoothd(RandomizableTransform, MapTransform): approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace". 
see also :py:meth:`monai.networks.layers.GaussianFilter`. prob: probability of Gaussian smooth. + allow_missing_keys: don't raise exception if key is missing. """ @@ -547,8 +610,9 @@ def __init__( sigma_z: Tuple[float, float] = (0.25, 1.5), approx: str = "erf", prob: float = 0.1, + allow_missing_keys: bool = False, ) -> None: - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.sigma_x = sigma_x self.sigma_y = sigma_y @@ -566,7 +630,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize() if not self._do_transform: return d - for key in self.keys: + for key in self.key_iterator(d): sigma = ensure_tuple_size(tup=(self.x, self.y, self.z), dim=d[key].ndim - 1) d[key] = GaussianSmooth(sigma=sigma, approx=self.approx)(d[key]) return d @@ -588,6 +652,7 @@ class GaussianSharpend(MapTransform): alpha: weight parameter to compute the final result. approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace". see also :py:meth:`monai.networks.layers.GaussianFilter`. + allow_missing_keys: don't raise exception if key is missing. """ @@ -598,13 +663,14 @@ def __init__( sigma2: Union[Sequence[float], float] = 1.0, alpha: float = 30.0, approx: str = "erf", + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = GaussianSharpen(sigma1, sigma2, alpha, approx=approx) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -629,6 +695,7 @@ class RandGaussianSharpend(RandomizableTransform, MapTransform): approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace". see also :py:meth:`monai.networks.layers.GaussianFilter`. prob: probability of Gaussian sharpen. + allow_missing_keys: don't raise exception if key is missing. """ @@ -644,8 +711,9 @@ def __init__( alpha: Tuple[float, float] = (10.0, 30.0), approx: str = "erf", prob: float = 0.1, + allow_missing_keys: bool = False, ): - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.sigma1_x = sigma1_x self.sigma1_y = sigma1_y @@ -674,7 +742,7 @@ def __call__(self, data): self.randomize() if not self._do_transform: return d - for key in self.keys: + for key in self.key_iterator(d): sigma1 = ensure_tuple_size(tup=(self.x1, self.y1, self.z1), dim=d[key].ndim - 1) sigma2 = ensure_tuple_size(tup=(self.x2, self.y2, self.z2), dim=d[key].ndim - 1) d[key] = GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(d[key]) @@ -693,12 +761,17 @@ class RandHistogramShiftd(RandomizableTransform, MapTransform): a smaller number of control points allows for larger intensity shifts. if two values provided, number of control points selecting from range (min_value, max_value). prob: probability of histogram shift. + allow_missing_keys: don't raise exception if key is missing. 
""" def __init__( - self, keys: KeysCollection, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1 + self, + keys: KeysCollection, + num_control_points: Union[Tuple[int, int], int] = 10, + prob: float = 0.1, + allow_missing_keys: bool = False, ) -> None: - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) if isinstance(num_control_points, int): if num_control_points <= 2: @@ -726,7 +799,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize() if not self._do_transform: return d - for key in self.keys: + for key in self.key_iterator(d): img_min, img_max = d[key].min(), d[key].max() reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index d9b6b5e6ab..ea965255d5 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -60,6 +60,7 @@ def __init__( meta_key_postfix: str = "meta_dict", overwriting: bool = False, image_only: bool = False, + allow_missing_keys: bool = False, *args, **kwargs, ) -> None: @@ -79,10 +80,11 @@ def __init__( default is False, which will raise exception if encountering existing key. image_only: if True return dictionary containing just only the image volumes, otherwise return dictionary containing image data array and header dict per input key. + allow_missing_keys: don't raise exception if key is missing. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self._loader = LoadImage(reader, image_only, dtype, *args, **kwargs) if not isinstance(meta_key_postfix, str): raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.") @@ -99,7 +101,7 @@ def __call__(self, data, reader: Optional[ImageReader] = None): """ d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): data = self._loader(d[key], reader) if self._loader.image_only: if not isinstance(data, np.ndarray): @@ -163,6 +165,7 @@ class SaveImaged(MapTransform): it's used for NIfTI format only. save_batch: whether the import image is a batch data, default to `False`. usually pre-transforms run for channel first data, while post-transforms run for batch data. + allow_missing_keys: don't raise exception if key is missing. 
""" @@ -180,8 +183,9 @@ def __init__( dtype: DtypeLike = np.float64, output_dtype: DtypeLike = np.float32, save_batch: bool = False, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.meta_key_postfix = meta_key_postfix self._saver = SaveImage( output_dir=output_dir, @@ -198,7 +202,7 @@ def __init__( def __call__(self, data): d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): meta_data = d[f"{key}_{self.meta_key_postfix}"] if self.meta_key_postfix is not None else None self._saver(img=d[key], meta_data=meta_data) return d diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 85abdac0ac..42796e2412 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -71,6 +71,7 @@ def __init__( sigmoid: Union[Sequence[bool], bool] = False, softmax: Union[Sequence[bool], bool] = False, other: Optional[Union[Sequence[Callable], Callable]] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -83,9 +84,10 @@ def __init__( other: callable function to execute other activation layers, for example: `other = lambda x: torch.tanh(x)`. it also can be a sequence of Callable, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.sigmoid = ensure_tuple_rep(sigmoid, len(self.keys)) self.softmax = ensure_tuple_rep(softmax, len(self.keys)) self.other = ensure_tuple_rep(other, len(self.keys)) @@ -93,8 +95,8 @@ def __init__( def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for idx, key in enumerate(self.keys): - d[key] = self.converter(d[key], self.sigmoid[idx], self.softmax[idx], self.other[idx]) + for key, sigmoid, softmax, other in self.key_iterator(d, self.sigmoid, self.softmax, self.other): + d[key] = self.converter(d[key], sigmoid, softmax, other) return d @@ -111,6 +113,7 @@ def __init__( n_classes: Optional[Union[Sequence[int], int]] = None, threshold_values: Union[Sequence[bool], bool] = False, logit_thresh: Union[Sequence[float], float] = 0.5, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -126,9 +129,10 @@ def __init__( it also can be a sequence of bool, each element corresponds to a key in ``keys``. logit_thresh: the threshold value for thresholding operation, default is 0.5. it also can be a sequence of float, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. 
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.argmax = ensure_tuple_rep(argmax, len(self.keys)) self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys)) self.n_classes = ensure_tuple_rep(n_classes, len(self.keys)) @@ -138,14 +142,16 @@ def __init__( def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for idx, key in enumerate(self.keys): + for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh in self.key_iterator( + d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh + ): d[key] = self.converter( d[key], - self.argmax[idx], - self.to_onehot[idx], - self.n_classes[idx], - self.threshold_values[idx], - self.logit_thresh[idx], + argmax, + to_onehot, + n_classes, + threshold_values, + logit_thresh, ) return d @@ -161,6 +167,7 @@ def __init__( applied_labels: Union[Sequence[int], int], independent: bool = True, connectivity: Optional[int] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -175,14 +182,15 @@ def __init__( connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. Accepted values are ranging from 1 to input.ndim. If ``None``, a full connectivity of ``input.ndim`` is used. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = KeepLargestConnectedComponent(applied_labels, independent, connectivity) def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -192,20 +200,21 @@ class LabelToContourd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.LabelToContour`. """ - def __init__(self, keys: KeysCollection, kernel_type: str = "Laplace") -> None: + def __init__(self, keys: KeysCollection, kernel_type: str = "Laplace", allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` kernel_type: the method applied to do edge detection, default is "Laplace". + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = LabelToContour(kernel_type=kernel_type) def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -221,6 +230,7 @@ def __init__( keys: KeysCollection, ensemble: Callable[[Union[Sequence[torch.Tensor], torch.Tensor]], torch.Tensor], output_key: Optional[str] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -229,13 +239,14 @@ def __init__( output_key: the key to store ensemble result in the dictionary. ensemble: callable method to execute ensemble on specified data. if only 1 key provided in `keys`, `output_key` can be None and use `keys` as default. + allow_missing_keys: don't raise exception if key is missing. Raises: TypeError: When ``ensemble`` is not ``callable``. ValueError: When ``len(keys) > 1`` and ``output_key=None``. Incompatible values. 
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) if not callable(ensemble): raise TypeError(f"ensemble must be callable but is {type(ensemble).__name__}.") self.ensemble = ensemble @@ -249,7 +260,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc if len(self.keys) == 1: items = d[self.keys[0]] else: - items = [d[key] for key in self.keys] + items = [d[key] for key in self.key_iterator(d)] d[self.output_key] = self.ensemble(items) return d diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index f29258bf28..a81aeb432b 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -125,6 +125,7 @@ def __init__( align_corners: Union[Sequence[bool], bool] = False, dtype: Optional[Union[Sequence[DtypeLike], DtypeLike]] = np.float64, meta_key_postfix: str = "meta_dict", + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -160,12 +161,13 @@ def __init__( default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. + allow_missing_keys: don't raise exception if key is missing. Raises: TypeError: When ``meta_key_postfix`` is not a ``str``. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.spacing_transform = Spacing(pixdim, diagonal=diagonal) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) @@ -179,17 +181,19 @@ def __call__( self, data: Mapping[Union[Hashable, str], Dict[str, np.ndarray]] ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: d: Dict = dict(data) - for idx, key in enumerate(self.keys): + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners, self.dtype + ): meta_data = d[f"{key}_{self.meta_key_postfix}"] # resample array of each corresponding key # using affine fetched from d[affine_key] d[key], _, new_affine = self.spacing_transform( data_array=np.asarray(d[key]), affine=meta_data["affine"], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], - dtype=self.dtype[idx], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, ) # set the 'affine' key meta_data["affine"] = new_affine @@ -214,6 +218,7 @@ def __init__( as_closest_canonical: bool = False, labels: Optional[Sequence[Tuple[str, str]]] = tuple(zip("LPI", "RAS")), meta_key_postfix: str = "meta_dict", + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -230,6 +235,7 @@ def __init__( default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. + allow_missing_keys: don't raise exception if key is missing. Raises: TypeError: When ``meta_key_postfix`` is not a ``str``. @@ -238,7 +244,7 @@ def __init__( `nibabel.orientations.ornt2axcodes`. 
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.ornt_transform = Orientation(axcodes=axcodes, as_closest_canonical=as_closest_canonical, labels=labels) if not isinstance(meta_key_postfix, str): raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.") @@ -248,7 +254,7 @@ def __call__( self, data: Mapping[Union[Hashable, str], Dict[str, np.ndarray]] ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: d: Dict = dict(data) - for key in self.keys: + for key in self.key_iterator(d): meta_data = d[f"{key}_{self.meta_key_postfix}"] d[key], _, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) meta_data["affine"] = new_affine @@ -260,19 +266,22 @@ class Rotate90d(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. """ - def __init__(self, keys: KeysCollection, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: + def __init__( + self, keys: KeysCollection, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1), allow_missing_keys: bool = False + ) -> None: """ Args: k: number of times to rotate by 90 degrees. spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.rotator = Rotate90(k, spatial_axes) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.rotator(d[key]) return d @@ -290,6 +299,7 @@ def __init__( prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, int] = (0, 1), + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -301,8 +311,9 @@ def __init__( (Default 3) spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. + allow_missing_keys: don't raise exception if key is missing. """ - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.max_k = max_k @@ -319,7 +330,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. d = dict(data) rotator = Rotate90(self._rand_k, self.spatial_axes) - for key in self.keys: + for key in self.key_iterator(d): if self._do_transform: d[key] = rotator(d[key]) return d @@ -344,6 +355,7 @@ class Resized(MapTransform): 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. 
""" def __init__( @@ -352,16 +364,17 @@ def __init__( spatial_size: Union[Sequence[int], int], mode: InterpolateModeSequence = InterpolateMode.AREA, align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.resizer = Resize(spatial_size=spatial_size) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for idx, key in enumerate(self.keys): - d[key] = self.resizer(d[key], mode=self.mode[idx], align_corners=self.align_corners[idx]) + for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): + d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners) return d @@ -383,6 +396,7 @@ def __init__( padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, as_tensor_output: bool = True, device: Optional[torch.device] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -416,12 +430,13 @@ def __init__( as_tensor_output: the computation is implemented using pytorch tensors, this option specifies whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. + allow_missing_keys: don't raise exception if key is missing. See also: - :py:class:`monai.transforms.compose.MapTransform` - :py:class:`RandAffineGrid` for the random affine parameters configurations. """ - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.rand_affine = RandAffine( prob=1.0, # because probability handled in this class @@ -459,8 +474,8 @@ def __call__( else: grid = create_grid(spatial_size=sp_size) - for idx, key in enumerate(self.keys): - d[key] = self.rand_affine.resampler(d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx]) + for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) return d @@ -484,6 +499,7 @@ def __init__( padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, as_tensor_output: bool = False, device: Optional[torch.device] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -521,12 +537,13 @@ def __init__( as_tensor_output: the computation is implemented using pytorch tensors, this option specifies whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. + allow_missing_keys: don't raise exception if key is missing. See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. 
""" - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.rand_2d_elastic = Rand2DElastic( spacing=spacing, @@ -576,10 +593,8 @@ def __call__( else: grid = create_grid(spatial_size=sp_size) - for idx, key in enumerate(self.keys): - d[key] = self.rand_2d_elastic.resampler( - d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx] - ) + for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) return d @@ -603,6 +618,7 @@ def __init__( padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, as_tensor_output: bool = False, device: Optional[torch.device] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -641,12 +657,13 @@ def __init__( as_tensor_output: the computation is implemented using pytorch tensors, this option specifies whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. + allow_missing_keys: don't raise exception if key is missing. See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. """ - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.rand_3d_elastic = Rand3DElastic( sigma_range=sigma_range, @@ -690,10 +707,8 @@ def __call__( grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude grid = self.rand_3d_elastic.rand_affine_grid(grid=grid) - for idx, key in enumerate(self.keys): - d[key] = self.rand_3d_elastic.resampler( - d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx] - ) + for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) return d @@ -707,15 +722,21 @@ class Flipd(MapTransform): Args: keys: Keys to pick data for transformation. spatial_axis: Spatial axes along which to flip over. Default is None. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: - super().__init__(keys) + def __init__( + self, + keys: KeysCollection, + spatial_axis: Optional[Union[Sequence[int], int]] = None, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) self.flipper = Flip(spatial_axis=spatial_axis) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.flipper(d[key]) return d @@ -731,6 +752,7 @@ class RandFlipd(RandomizableTransform, MapTransform): keys: Keys to pick data for transformation. prob: Probability of flipping. spatial_axis: Spatial axes along which to flip over. Default is None. + allow_missing_keys: don't raise exception if key is missing. 
""" def __init__( @@ -738,8 +760,9 @@ def __init__( keys: KeysCollection, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None, + allow_missing_keys: bool = False, ) -> None: - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.spatial_axis = spatial_axis @@ -748,7 +771,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: self.randomize(None) d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): if self._do_transform: d[key] = self.flipper(d[key]) return d @@ -764,11 +787,12 @@ class RandAxisFlipd(RandomizableTransform, MapTransform): Args: keys: Keys to pick data for transformation. prob: Probability of flipping. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, prob: float = 0.1) -> None: - MapTransform.__init__(self, keys) + def __init__(self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: bool = False) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self._axis: Optional[int] = None @@ -781,7 +805,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda flipper = Flip(spatial_axis=self._axis) d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): if self._do_transform: d[key] = flipper(d[key]) return d @@ -812,6 +836,7 @@ class Rotated(MapTransform): If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( @@ -823,8 +848,9 @@ def __init__( padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.rotator = Rotate(angle=angle, keep_size=keep_size) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -834,13 +860,15 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for idx, key in enumerate(self.keys): + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners, self.dtype + ): d[key] = self.rotator( d[key], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], - dtype=self.dtype[idx], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, ) return d @@ -877,6 +905,7 @@ class RandRotated(RandomizableTransform, MapTransform): If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. 
""" def __init__( @@ -891,8 +920,9 @@ def __init__( padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + allow_missing_keys: bool = False, ) -> None: - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.range_x = ensure_tuple(range_x) if len(self.range_x) == 1: @@ -929,13 +959,15 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda angle=self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z), keep_size=self.keep_size, ) - for idx, key in enumerate(self.keys): + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners, self.dtype + ): d[key] = rotator( d[key], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], - dtype=self.dtype[idx], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, ) return d @@ -962,6 +994,7 @@ class Zoomd(MapTransform): See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. keep_size: Should keep original size (pad if needed), default is True. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( @@ -972,8 +1005,9 @@ def __init__( padding_mode: NumpyPadModeSequence = NumpyPadMode.EDGE, align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, keep_size: bool = True, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) @@ -981,12 +1015,14 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for idx, key in enumerate(self.keys): + for key, mode, padding_mode, align_corners in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners + ): d[key] = self.zoomer( d[key], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, ) return d @@ -1021,6 +1057,7 @@ class RandZoomd(RandomizableTransform, MapTransform): See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. keep_size: Should keep original size (pad if needed), default is True. + allow_missing_keys: don't raise exception if key is missing. 
""" def __init__( @@ -1033,8 +1070,9 @@ def __init__( padding_mode: NumpyPadModeSequence = NumpyPadMode.EDGE, align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, keep_size: bool = True, + allow_missing_keys: bool = False, ) -> None: - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.min_zoom = ensure_tuple(min_zoom) self.max_zoom = ensure_tuple(max_zoom) @@ -1067,12 +1105,14 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim self._zoom = ensure_tuple_rep(self._zoom[0], img_dims - 2) + ensure_tuple(self._zoom[-1]) zoomer = Zoom(self._zoom, keep_size=self.keep_size) - for idx, key in enumerate(self.keys): + for key, mode, padding_mode, align_corners in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners + ): d[key] = zoomer( d[key], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, ) return d diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 9c9729d250..7a09efa6d5 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -13,7 +13,7 @@ """ from abc import ABC, abstractmethod -from typing import Any, Hashable, Optional, Tuple +from typing import Any, Dict, Generator, Hashable, Iterable, List, Optional, Tuple import numpy as np @@ -178,7 +178,7 @@ def __call__(self, data): if key in data: # update output data with some_transform_function(data[key]). else: - # do nothing or some exceptions handling. + # raise exception unless allow_missing_keys==True. return data Raises: @@ -187,8 +187,9 @@ def __call__(self, data): """ - def __init__(self, keys: KeysCollection) -> None: + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: self.keys: Tuple[Hashable, ...] = ensure_tuple(keys) + self.allow_missing_keys = allow_missing_keys if not self.keys: raise ValueError("keys must be non empty.") for key in self.keys: @@ -224,3 +225,29 @@ def __call__(self, data): """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + def key_iterator( + self, + data: Dict[Hashable, Any], + *extra_iterables: Optional[Iterable], + ) -> Generator: + """ + Iterate across keys and optionally extra iterables. If key is missing, exception is raised if + `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped. 
+
+        Args:
+            data: data that the transform will be applied to
+            extra_iterables: anything else to be iterated through
+        """
+        # if no extra iterables given, create a dummy list of Nones
+        ex_iters = extra_iterables if extra_iterables else [[None] * len(self.keys)]
+
+        # loop over keys and any extra iterables
+        _ex_iters: List[Any]
+        for key, *_ex_iters in zip(self.keys, *ex_iters):
+            # all normal, yield (what we yield depends on whether extra iterables were given)
+            if key in data.keys():
+                yield (key,) + tuple(_ex_iters) if extra_iterables else key
+            # if missing keys not allowed, raise
+            elif not self.allow_missing_keys:
+                raise KeyError(f"Key was missing ({key}) and allow_missing_keys==False")
diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index 4a0808fdbb..14f34fb663 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -137,21 +137,22 @@ class Identityd(MapTransform):
     Dictionary-based wrapper of :py:class:`monai.transforms.Identity`.
     """

-    def __init__(self, keys: KeysCollection) -> None:
+    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
         """
         Args:
             keys: keys of the corresponding items to be transformed.
                 See also: :py:class:`monai.transforms.compose.MapTransform`
+            allow_missing_keys: don't raise exception if key is missing.
         """
-        super().__init__(keys)
+        super().__init__(keys, allow_missing_keys)
         self.identity = Identity()

     def __call__(
         self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]
     ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
         d = dict(data)
-        for key in self.keys:
+        for key in self.key_iterator(d):
             d[key] = self.identity(d[key])
         return d

@@ -161,19 +162,20 @@ class AsChannelFirstd(MapTransform):
     Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelFirst`.
     """

-    def __init__(self, keys: KeysCollection, channel_dim: int = -1) -> None:
+    def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_keys: bool = False) -> None:
         """
         Args:
             keys: keys of the corresponding items to be transformed.
                 See also: :py:class:`monai.transforms.compose.MapTransform`
             channel_dim: which dimension of input image is the channel, default is the last dimension.
+            allow_missing_keys: don't raise exception if key is missing.
         """
-        super().__init__(keys)
+        super().__init__(keys, allow_missing_keys)
         self.converter = AsChannelFirst(channel_dim=channel_dim)

     def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
         d = dict(data)
-        for key in self.keys:
+        for key in self.key_iterator(d):
             d[key] = self.converter(d[key])
         return d

@@ -183,19 +185,20 @@ class AsChannelLastd(MapTransform):
     Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelLast`.
     """

-    def __init__(self, keys: KeysCollection, channel_dim: int = 0) -> None:
+    def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_keys: bool = False) -> None:
         """
         Args:
             keys: keys of the corresponding items to be transformed.
                 See also: :py:class:`monai.transforms.compose.MapTransform`
             channel_dim: which dimension of input image is the channel, default is the first dimension.
+            allow_missing_keys: don't raise exception if key is missing.
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = AsChannelLast(channel_dim=channel_dim) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -205,18 +208,19 @@ class AddChanneld(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AddChannel`. """ - def __init__(self, keys: KeysCollection) -> None: + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.adder = AddChannel() def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.adder(d[key]) return d @@ -252,19 +256,20 @@ class RepeatChanneld(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.RepeatChannel`. """ - def __init__(self, keys: KeysCollection, repeats: int) -> None: + def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` repeats: the number of repetitions for each element. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.repeater = RepeatChannel(repeats) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.repeater(d[key]) return d @@ -274,19 +279,20 @@ class RemoveRepeatedChanneld(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.RemoveRepeatedChannel`. """ - def __init__(self, keys: KeysCollection, repeats: int) -> None: + def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` repeats: the number of repetitions for each element. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.repeater = RemoveRepeatedChannel(repeats) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.repeater(d[key]) return d @@ -303,6 +309,7 @@ def __init__( keys: KeysCollection, output_postfixes: Optional[Sequence[str]] = None, channel_dim: Optional[int] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -316,9 +323,10 @@ def __init__( to automatically select: if data is numpy array, channel_dim is 0 as `numpy array` is used in the pre transforms, if PyTorch Tensor, channel_dim is 1 as in most of the cases `Tensor` is uses in the post transforms. + allow_missing_keys: don't raise exception if key is missing. 
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.output_postfixes = output_postfixes self.splitter = SplitChannel(channel_dim=channel_dim) @@ -326,7 +334,7 @@ def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): rets = self.splitter(d[key]) postfixes: Sequence = list(range(len(rets))) if self.output_postfixes is None else self.output_postfixes if len(postfixes) != len(rets): @@ -348,6 +356,7 @@ def __init__( self, keys: KeysCollection, dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], DtypeLike, torch.dtype] = np.float32, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -356,9 +365,10 @@ def __init__( dtype: convert image to this data type, default is `np.float32`. it also can be a sequence of dtypes or torch.dtype, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. """ - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.converter = CastToType() @@ -366,8 +376,8 @@ def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: d = dict(data) - for idx, key in enumerate(self.keys): - d[key] = self.converter(d[key], dtype=self.dtype[idx]) + for key, dtype in self.key_iterator(d, self.dtype): + d[key] = self.converter(d[key], dtype=dtype) return d @@ -377,20 +387,21 @@ class ToTensord(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToTensor`. """ - def __init__(self, keys: KeysCollection) -> None: + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = ToTensor() def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]] ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -400,20 +411,21 @@ class ToNumpyd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToNumpy`. """ - def __init__(self, keys: KeysCollection) -> None: + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = ToNumpy() def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]] ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -423,20 +435,21 @@ class ToPILd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToNumpy`. 
""" - def __init__(self, keys: KeysCollection) -> None: + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = ToPIL() def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]] ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor, PILImageImage]]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -448,7 +461,7 @@ class DeleteItemsd(MapTransform): """ def __call__(self, data): - return {key: val for key, val in data.items() if key not in self.keys} + return {key: val for key, val in data.items() if key not in self.key_iterator(data)} class SelectItemsd(MapTransform): @@ -458,7 +471,7 @@ class SelectItemsd(MapTransform): """ def __call__(self, data): - result = {key: val for key, val in data.items() if key in self.keys} + result = {key: data[key] for key in self.key_iterator(data)} return result @@ -467,19 +480,20 @@ class SqueezeDimd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.SqueezeDim`. """ - def __init__(self, keys: KeysCollection, dim: int = 0) -> None: + def __init__(self, keys: KeysCollection, dim: int = 0, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` dim: dimension to be squeezed. Default: 0 (the first dimension) + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = SqueezeDim(dim=dim) def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -498,6 +512,7 @@ def __init__( data_value: Union[Sequence[bool], bool] = False, additional_info: Optional[Union[Sequence[Callable], Callable]] = None, logger_handler: Optional[logging.Handler] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -517,9 +532,10 @@ def __init__( corresponds to a key in ``keys``. logger_handler: add additional handler to output data: save to file, etc. add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html + allow_missing_keys: don't raise exception if key is missing. 
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.prefix = ensure_tuple_rep(prefix, len(self.keys)) self.data_shape = ensure_tuple_rep(data_shape, len(self.keys)) self.value_range = ensure_tuple_rep(value_range, len(self.keys)) @@ -530,14 +546,16 @@ def __init__( def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: d = dict(data) - for idx, key in enumerate(self.keys): + for key, prefix, data_shape, value_range, data_value, additional_info in self.key_iterator( + d, self.prefix, self.data_shape, self.value_range, self.data_value, self.additional_info + ): d[key] = self.printer( d[key], - self.prefix[idx], - self.data_shape[idx], - self.value_range[idx], - self.data_value[idx], - self.additional_info[idx], + prefix, + data_shape, + value_range, + data_value, + additional_info, ) return d @@ -547,23 +565,26 @@ class SimulateDelayd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.SimulateDelay`. """ - def __init__(self, keys: KeysCollection, delay_time: Union[Sequence[float], float] = 0.0) -> None: + def __init__( + self, keys: KeysCollection, delay_time: Union[Sequence[float], float] = 0.0, allow_missing_keys: bool = False + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` delay_time: The minimum amount of time, in fractions of seconds, to accomplish this identity task. It also can be a sequence of string, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.delay_time = ensure_tuple_rep(delay_time, len(self.keys)) self.delayer = SimulateDelay() def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: d = dict(data) - for idx, key in enumerate(self.keys): - d[key] = self.delayer(d[key], delay_time=self.delay_time[idx]) + for key, delay_time in self.key_iterator(d, self.delay_time): + d[key] = self.delayer(d[key], delay_time=delay_time) return d @@ -574,7 +595,9 @@ class CopyItemsd(MapTransform): """ - def __init__(self, keys: KeysCollection, times: int, names: KeysCollection) -> None: + def __init__( + self, keys: KeysCollection, times: int, names: KeysCollection, allow_missing_keys: bool = False + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. @@ -584,13 +607,14 @@ def __init__(self, keys: KeysCollection, times: int, names: KeysCollection) -> N names: the names corresponding to the newly copied data, the length should match `len(keys) x times`. for example, if keys is ["img", "seg"] and times is 2, names can be: ["img_1", "seg_1", "img_2", "seg_2"]. + allow_missing_keys: don't raise exception if key is missing. Raises: ValueError: When ``times`` is nonpositive. ValueError: When ``len(names)`` is not ``len(keys) * times``. Incompatible values. 
""" - super().__init__(keys) + super().__init__(keys, allow_missing_keys) if times < 1: raise ValueError(f"times must be positive, got {times}.") self.times = times @@ -609,13 +633,14 @@ def __call__(self, data): """ d = dict(data) - for key, new_key in zip(self.keys * self.times, self.names): + for new_key in self.names: if new_key in d: raise KeyError(f"Key {new_key} already exists in data.") - if isinstance(d[key], torch.Tensor): - d[new_key] = d[key].detach().clone() - else: - d[new_key] = copy.deepcopy(d[key]) + for key in self.key_iterator(d): + if isinstance(d[key], torch.Tensor): + d[new_key] = d[key].detach().clone() + else: + d[new_key] = copy.deepcopy(d[key]) return d @@ -626,19 +651,20 @@ class ConcatItemsd(MapTransform): """ - def __init__(self, keys: KeysCollection, name: str, dim: int = 0) -> None: + def __init__(self, keys: KeysCollection, name: str, dim: int = 0, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be concatenated together. See also: :py:class:`monai.transforms.compose.MapTransform` name: the name corresponding to the key to store the concatenated data. dim: on which dimension to concatenate the items, default is 0. + allow_missing_keys: don't raise exception if key is missing. Raises: ValueError: When insufficient keys are given (``len(self.keys) < 2``). """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) if len(self.keys) < 2: raise ValueError("Concatenation requires at least 2 keys.") self.name = name @@ -654,7 +680,7 @@ def __call__(self, data): d = dict(data) output = [] data_type = None - for key in self.keys: + for key in self.key_iterator(d): if data_type is None: data_type = type(d[key]) elif not isinstance(d[key], data_type): @@ -690,6 +716,7 @@ class Lambdad(MapTransform): each element corresponds to a key in ``keys``. overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output. default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. """ def __init__( @@ -697,17 +724,18 @@ def __init__( keys: KeysCollection, func: Union[Sequence[Callable], Callable], overwrite: Union[Sequence[bool], bool] = True, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.func = ensure_tuple_rep(func, len(self.keys)) self.overwrite = ensure_tuple_rep(overwrite, len(self.keys)) self._lambd = Lambda() def __call__(self, data): d = dict(data) - for idx, key in enumerate(self.keys): - ret = self._lambd(d[key], func=self.func[idx]) - if self.overwrite[idx]: + for key, func, overwrite in self.key_iterator(d, self.func, self.overwrite): + ret = self._lambd(d[key], func=func) + if overwrite: d[key] = ret return d @@ -745,6 +773,7 @@ class LabelToMaskd(MapTransform): `select_labels` is the expected channel indices. merge_channels: whether to use `np.any()` to merge the result on channel dim. if yes, will return a single channel mask with binary data. + allow_missing_keys: don't raise exception if key is missing. 
""" @@ -753,13 +782,14 @@ def __init__( # pytype: disable=annotation-type-mismatch keys: KeysCollection, select_labels: Union[Sequence[int], int], merge_channels: bool = False, + allow_missing_keys: bool = False, ) -> None: # pytype: disable=annotation-type-mismatch - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.converter = LabelToMask(select_labels=select_labels, merge_channels=merge_channels) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -781,6 +811,7 @@ class FgBgToIndicesd(MapTransform): image_threshold: if enabled image_key, use ``image > image_threshold`` to determine the valid image content area and select background only in this area. output_shape: expected shape of output indices. if not None, unravel indices to specified shape. + allow_missing_keys: don't raise exception if key is missing. """ @@ -792,8 +823,9 @@ def __init__( image_key: Optional[str] = None, image_threshold: float = 0.0, output_shape: Optional[Sequence[int]] = None, + allow_missing_keys: bool = False, ) -> None: - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.fg_postfix = fg_postfix self.bg_postfix = bg_postfix self.image_key = image_key @@ -802,7 +834,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) image = d[self.image_key] if self.image_key else None - for key in self.keys: + for key in self.key_iterator(d): d[str(key) + self.fg_postfix], d[str(key) + self.bg_postfix] = self.converter(d[key], image) return d @@ -819,13 +851,13 @@ class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): and ET (Enhancing tumor). """ - def __init__(self, keys: KeysCollection): - super().__init__(keys) + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False): + super().__init__(keys, allow_missing_keys) self.converter = ConvertToMultiChannelBasedOnBratsClasses() def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d @@ -845,6 +877,7 @@ class AddExtremePointsChanneld(RandomizableTransform, MapTransform): use it for all spatial dimensions. rescale_min: minimum value of output data. rescale_max: maximum value of output data. + allow_missing_keys: don't raise exception if key is missing. 
""" @@ -857,8 +890,9 @@ def __init__( sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 3.0, rescale_min: float = -1.0, rescale_max: float = 1.0, + allow_missing_keys: bool = False, ): - MapTransform.__init__(self, keys) + MapTransform.__init__(self, keys, allow_missing_keys) self.background = background self.pert = pert self.points: List[Tuple[int, ...]] = [] @@ -879,17 +913,16 @@ def __call__(self, data): # Generate extreme points self.randomize(label[0, :]) - for key in data.keys(): - if key in self.keys: - img = d[key] - points_image = extreme_points_to_image( - points=self.points, - label=label, - sigma=self.sigma, - rescale_min=self.rescale_min, - rescale_max=self.rescale_max, - ) - d[key] = np.concatenate([img, points_image], axis=0) + for key in self.key_iterator(d): + img = d[key] + points_image = extreme_points_to_image( + points=self.points, + label=label, + sigma=self.sigma, + rescale_min=self.rescale_min, + rescale_max=self.rescale_max, + ) + d[key] = np.concatenate([img, points_image], axis=0) return d @@ -900,22 +933,23 @@ class TorchVisiond(MapTransform): data to be dict of PyTorch Tensors, users can easily call `ToTensord` transform to convert Numpy to Tensor. """ - def __init__(self, keys: KeysCollection, name: str, *args, **kwargs) -> None: + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` name: The transform name in TorchVision package. + allow_missing_keys: don't raise exception if key is missing. args: parameters for the TorchVision transform. kwargs: parameters for the TorchVision transform. """ - super().__init__(keys) + super().__init__(keys, allow_missing_keys) self.trans = TorchVision(name, *args, **kwargs) def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for key in self.keys: + for key in self.key_iterator(d): d[key] = self.trans(d[key]) return d