From e0cda55d4ad86efc74e912c023d9bcc5f6d12608 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 8 Aug 2025 09:19:36 +0200 Subject: [PATCH 01/16] added list extend to MultiSampleTrait --- monai/transforms/transform.py | 102 +++++++++++++++++++++++++--------- 1 file changed, 77 insertions(+), 25 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 1a365b8d8e..e7b8268432 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -90,12 +90,22 @@ def _apply_transform( """ from monai.transforms.lazy.functional import apply_pending_transforms_in_order - data = apply_pending_transforms_in_order(transform, data, lazy, overrides, logger_name) + data = apply_pending_transforms_in_order( + transform, data, lazy, overrides, logger_name + ) if isinstance(data, tuple) and unpack_parameters: - return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data) + return ( + transform(*data, lazy=lazy) + if isinstance(transform, LazyTrait) + else transform(*data) + ) - return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data) + return ( + transform(data, lazy=lazy) + if isinstance(transform, LazyTrait) + else transform(data) + ) def apply_transform( @@ -143,31 +153,49 @@ def apply_transform( try: map_items_ = int(map_items) if isinstance(map_items, bool) else map_items if isinstance(data, (list, tuple)) and map_items_ > 0: - return [ - apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) - for item in data - ] - return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats) + res = [] + for item in data: + res_item = _apply_transform( + transform, item, unpack_items, lazy, overrides, log_stats + ) + if isinstance(res_item, list | tuple): + res.extend(res_item) + else: + res.append(res_item) + return res + return _apply_transform( + transform, data, unpack_items, lazy, overrides, log_stats + ) except Exception as e: # if in debug mode, don't swallow exception so that the breakpoint # appears where the exception was raised. if MONAIEnvVars.debug(): raise - if log_stats is not False and not isinstance(transform, transforms.compose.Compose): + if log_stats is not False and not isinstance( + transform, transforms.compose.Compose + ): # log the input data information of exact transform in the transform chain if isinstance(log_stats, str): - datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False, name=log_stats) + datastats = transforms.utility.array.DataStats( + data_shape=False, value_range=False, name=log_stats + ) else: - datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False) + datastats = transforms.utility.array.DataStats( + data_shape=False, value_range=False + ) logger = logging.getLogger(datastats._logger_name) - logger.error(f"\n=== Transform input info -- {type(transform).__name__} ===") + logger.error( + f"\n=== Transform input info -- {type(transform).__name__} ===" + ) if isinstance(data, (list, tuple)): data = data[0] def _log_stats(data, prefix: str | None = "Data"): if isinstance(data, (np.ndarray, torch.Tensor)): # log data type, shape, range for array - datastats(img=data, data_shape=True, value_range=True, prefix=prefix) + datastats( + img=data, data_shape=True, value_range=True, prefix=prefix + ) else: # log data type and value for other metadata datastats(img=data, data_value=True, prefix=prefix) @@ -194,7 +222,9 @@ class Randomizable(ThreadUnsafe, RandomizableTrait): R: np.random.RandomState = np.random.RandomState() - def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Randomizable: + def set_random_state( + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> Randomizable: """ Set the random state locally, to control the randomness, the derived classes should use :py:attr:`self.R` instead of `np.random` to introduce random @@ -212,14 +242,20 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState """ if seed is not None: - _seed = np.int64(id(seed) if not isinstance(seed, (int, np.integer)) else seed) - _seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64 + _seed = np.int64( + id(seed) if not isinstance(seed, (int, np.integer)) else seed + ) + _seed = ( + _seed % MAX_SEED + ) # need to account for Numpy2.0 which doesn't silently convert to int64 self.R = np.random.RandomState(_seed) return self if state is not None: if not isinstance(state, np.random.RandomState): - raise TypeError(f"state must be None or a np.random.RandomState but is {type(state).__name__}.") + raise TypeError( + f"state must be None or a np.random.RandomState but is {type(state).__name__}." + ) self.R = state return self @@ -238,7 +274,9 @@ def randomize(self, data: Any) -> None: Raises: NotImplementedError: When the subclass does not override this method. """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + raise NotImplementedError( + f"Subclass {self.__class__.__name__} must implement this method." + ) class Transform(ABC): @@ -294,7 +332,9 @@ def __call__(self, data: Any): NotImplementedError: When the subclass does not override this method. """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + raise NotImplementedError( + f"Subclass {self.__class__.__name__} must implement this method." + ) class LazyTransform(Transform, LazyTrait): @@ -397,11 +437,15 @@ def __call__(self, data): def __new__(cls, *args, **kwargs): if config.USE_META_DICT: # call_update after MapTransform.__call__ - cls.__call__ = transforms.attach_hook(cls.__call__, MapTransform.call_update, "post") # type: ignore + cls.__call__ = transforms.attach_hook( + cls.__call__, MapTransform.call_update, "post" + ) # type: ignore if hasattr(cls, "inverse"): # inverse_update before InvertibleTransform.inverse - cls.inverse: Any = transforms.attach_hook(cls.inverse, transforms.InvertibleTransform.inverse_update) + cls.inverse: Any = transforms.attach_hook( + cls.inverse, transforms.InvertibleTransform.inverse_update + ) return Transform.__new__(cls) def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: @@ -412,7 +456,9 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No raise ValueError("keys must be non empty.") for key in self.keys: if not isinstance(key, Hashable): - raise TypeError(f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}.") + raise TypeError( + f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}." + ) def call_update(self, data): """ @@ -432,7 +478,9 @@ def call_update(self, data): for k in dict_i: if not isinstance(dict_i[k], MetaTensor): continue - list_d[idx] = transforms.sync_meta_info(k, dict_i, t=not isinstance(self, transforms.InvertD)) + list_d[idx] = transforms.sync_meta_info( + k, dict_i, t=not isinstance(self, transforms.InvertD) + ) return list_d[0] if is_dict else list_d @abstractmethod @@ -460,9 +508,13 @@ def __call__(self, data): An updated dictionary version of ``data`` by applying the transform. """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + raise NotImplementedError( + f"Subclass {self.__class__.__name__} must implement this method." + ) - def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None) -> Generator: + def key_iterator( + self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None + ) -> Generator: """ Iterate across keys and optionally extra iterables. If key is missing, exception is raised if `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped. From 1ad24af4315caba871b1bc1604951518755fa784 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 8 Aug 2025 09:19:36 +0200 Subject: [PATCH 02/16] DCO Remediation Commit for Lukas Folle I, Lukas Folle , hereby add my Signed-off-by to this commit: e0cda55d4ad86efc74e912c023d9bcc5f6d12608 Signed-off-by: Lukas Folle --- monai/transforms/transform.py | 102 +++++++++++++++++++++++++--------- 1 file changed, 77 insertions(+), 25 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 1a365b8d8e..e7b8268432 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -90,12 +90,22 @@ def _apply_transform( """ from monai.transforms.lazy.functional import apply_pending_transforms_in_order - data = apply_pending_transforms_in_order(transform, data, lazy, overrides, logger_name) + data = apply_pending_transforms_in_order( + transform, data, lazy, overrides, logger_name + ) if isinstance(data, tuple) and unpack_parameters: - return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data) + return ( + transform(*data, lazy=lazy) + if isinstance(transform, LazyTrait) + else transform(*data) + ) - return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data) + return ( + transform(data, lazy=lazy) + if isinstance(transform, LazyTrait) + else transform(data) + ) def apply_transform( @@ -143,31 +153,49 @@ def apply_transform( try: map_items_ = int(map_items) if isinstance(map_items, bool) else map_items if isinstance(data, (list, tuple)) and map_items_ > 0: - return [ - apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) - for item in data - ] - return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats) + res = [] + for item in data: + res_item = _apply_transform( + transform, item, unpack_items, lazy, overrides, log_stats + ) + if isinstance(res_item, list | tuple): + res.extend(res_item) + else: + res.append(res_item) + return res + return _apply_transform( + transform, data, unpack_items, lazy, overrides, log_stats + ) except Exception as e: # if in debug mode, don't swallow exception so that the breakpoint # appears where the exception was raised. if MONAIEnvVars.debug(): raise - if log_stats is not False and not isinstance(transform, transforms.compose.Compose): + if log_stats is not False and not isinstance( + transform, transforms.compose.Compose + ): # log the input data information of exact transform in the transform chain if isinstance(log_stats, str): - datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False, name=log_stats) + datastats = transforms.utility.array.DataStats( + data_shape=False, value_range=False, name=log_stats + ) else: - datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False) + datastats = transforms.utility.array.DataStats( + data_shape=False, value_range=False + ) logger = logging.getLogger(datastats._logger_name) - logger.error(f"\n=== Transform input info -- {type(transform).__name__} ===") + logger.error( + f"\n=== Transform input info -- {type(transform).__name__} ===" + ) if isinstance(data, (list, tuple)): data = data[0] def _log_stats(data, prefix: str | None = "Data"): if isinstance(data, (np.ndarray, torch.Tensor)): # log data type, shape, range for array - datastats(img=data, data_shape=True, value_range=True, prefix=prefix) + datastats( + img=data, data_shape=True, value_range=True, prefix=prefix + ) else: # log data type and value for other metadata datastats(img=data, data_value=True, prefix=prefix) @@ -194,7 +222,9 @@ class Randomizable(ThreadUnsafe, RandomizableTrait): R: np.random.RandomState = np.random.RandomState() - def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Randomizable: + def set_random_state( + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> Randomizable: """ Set the random state locally, to control the randomness, the derived classes should use :py:attr:`self.R` instead of `np.random` to introduce random @@ -212,14 +242,20 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState """ if seed is not None: - _seed = np.int64(id(seed) if not isinstance(seed, (int, np.integer)) else seed) - _seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64 + _seed = np.int64( + id(seed) if not isinstance(seed, (int, np.integer)) else seed + ) + _seed = ( + _seed % MAX_SEED + ) # need to account for Numpy2.0 which doesn't silently convert to int64 self.R = np.random.RandomState(_seed) return self if state is not None: if not isinstance(state, np.random.RandomState): - raise TypeError(f"state must be None or a np.random.RandomState but is {type(state).__name__}.") + raise TypeError( + f"state must be None or a np.random.RandomState but is {type(state).__name__}." + ) self.R = state return self @@ -238,7 +274,9 @@ def randomize(self, data: Any) -> None: Raises: NotImplementedError: When the subclass does not override this method. """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + raise NotImplementedError( + f"Subclass {self.__class__.__name__} must implement this method." + ) class Transform(ABC): @@ -294,7 +332,9 @@ def __call__(self, data: Any): NotImplementedError: When the subclass does not override this method. """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + raise NotImplementedError( + f"Subclass {self.__class__.__name__} must implement this method." + ) class LazyTransform(Transform, LazyTrait): @@ -397,11 +437,15 @@ def __call__(self, data): def __new__(cls, *args, **kwargs): if config.USE_META_DICT: # call_update after MapTransform.__call__ - cls.__call__ = transforms.attach_hook(cls.__call__, MapTransform.call_update, "post") # type: ignore + cls.__call__ = transforms.attach_hook( + cls.__call__, MapTransform.call_update, "post" + ) # type: ignore if hasattr(cls, "inverse"): # inverse_update before InvertibleTransform.inverse - cls.inverse: Any = transforms.attach_hook(cls.inverse, transforms.InvertibleTransform.inverse_update) + cls.inverse: Any = transforms.attach_hook( + cls.inverse, transforms.InvertibleTransform.inverse_update + ) return Transform.__new__(cls) def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: @@ -412,7 +456,9 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No raise ValueError("keys must be non empty.") for key in self.keys: if not isinstance(key, Hashable): - raise TypeError(f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}.") + raise TypeError( + f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}." + ) def call_update(self, data): """ @@ -432,7 +478,9 @@ def call_update(self, data): for k in dict_i: if not isinstance(dict_i[k], MetaTensor): continue - list_d[idx] = transforms.sync_meta_info(k, dict_i, t=not isinstance(self, transforms.InvertD)) + list_d[idx] = transforms.sync_meta_info( + k, dict_i, t=not isinstance(self, transforms.InvertD) + ) return list_d[0] if is_dict else list_d @abstractmethod @@ -460,9 +508,13 @@ def __call__(self, data): An updated dictionary version of ``data`` by applying the transform. """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + raise NotImplementedError( + f"Subclass {self.__class__.__name__} must implement this method." + ) - def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None) -> Generator: + def key_iterator( + self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None + ) -> Generator: """ Iterate across keys and optionally extra iterables. If key is missing, exception is raised if `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped. From eeb7e12d604a4c46f5b44d7484b40aa9cac6e9d3 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 8 Aug 2025 09:35:57 +0200 Subject: [PATCH 03/16] fixed type errors --- monai/transforms/transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index e7b8268432..73c4093792 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -153,12 +153,12 @@ def apply_transform( try: map_items_ = int(map_items) if isinstance(map_items, bool) else map_items if isinstance(data, (list, tuple)) and map_items_ > 0: - res = [] + res: list[ReturnType] = [] for item in data: res_item = _apply_transform( transform, item, unpack_items, lazy, overrides, log_stats ) - if isinstance(res_item, list | tuple): + if isinstance(res_item, (list, tuple)): res.extend(res_item) else: res.append(res_item) From c011103a67f54b7994f035d61dc4edc6e1fefb5a Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 8 Aug 2025 09:35:57 +0200 Subject: [PATCH 04/16] DCO Remediation Commit for Lukas Folle I, Lukas Folle , hereby add my Signed-off-by to this commit: e0cda55d4ad86efc74e912c023d9bcc5f6d12608 Signed-off-by: Lukas Folle --- monai/transforms/transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index e7b8268432..73c4093792 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -153,12 +153,12 @@ def apply_transform( try: map_items_ = int(map_items) if isinstance(map_items, bool) else map_items if isinstance(data, (list, tuple)) and map_items_ > 0: - res = [] + res: list[ReturnType] = [] for item in data: res_item = _apply_transform( transform, item, unpack_items, lazy, overrides, log_stats ) - if isinstance(res_item, list | tuple): + if isinstance(res_item, (list, tuple)): res.extend(res_item) else: res.append(res_item) From 77c138d4751826f810a674744c9f949ee48e6f0d Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 8 Aug 2025 09:54:20 +0200 Subject: [PATCH 05/16] DCO Remediation Commit for Lukas Folle I, Lukas Folle , hereby add my Signed-off-by to this commit: e0cda55d4ad86efc74e912c023d9bcc5f6d12608 Signed-off-by: Lukas Folle --- monai/transforms/transform.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 73c4093792..0c9d8c3cdf 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -155,9 +155,7 @@ def apply_transform( if isinstance(data, (list, tuple)) and map_items_ > 0: res: list[ReturnType] = [] for item in data: - res_item = _apply_transform( - transform, item, unpack_items, lazy, overrides, log_stats - ) + res_item = _apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) if isinstance(res_item, (list, tuple)): res.extend(res_item) else: From 7df8cb919c43f8343c76b2dd750c3ad832ccdf1b Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 8 Aug 2025 11:46:18 +0200 Subject: [PATCH 06/16] avoided breaking map_item functionality DCO Remediation Commit for Lukas Folle I, Lukas Folle , hereby add my Signed-off-by to this commit: e0cda55d4ad86efc74e912c023d9bcc5f6d12608 Signed-off-by: Lukas Folle --- monai/transforms/transform.py | 102 ++++++++++------------------------ 1 file changed, 30 insertions(+), 72 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 0c9d8c3cdf..65ef429e33 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -90,22 +90,12 @@ def _apply_transform( """ from monai.transforms.lazy.functional import apply_pending_transforms_in_order - data = apply_pending_transforms_in_order( - transform, data, lazy, overrides, logger_name - ) + data = apply_pending_transforms_in_order(transform, data, lazy, overrides, logger_name) if isinstance(data, tuple) and unpack_parameters: - return ( - transform(*data, lazy=lazy) - if isinstance(transform, LazyTrait) - else transform(*data) - ) + return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data) - return ( - transform(data, lazy=lazy) - if isinstance(transform, LazyTrait) - else transform(data) - ) + return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data) def apply_transform( @@ -155,45 +145,38 @@ def apply_transform( if isinstance(data, (list, tuple)) and map_items_ > 0: res: list[ReturnType] = [] for item in data: - res_item = _apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) - if isinstance(res_item, (list, tuple)): - res.extend(res_item) + res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) + # Only extend if we're at the leaf level (map_items_ == 1) and the transform + # actually returned a list (not preserving nested structure) + if isinstance(res_item, list) and map_items_ == 1: + if not isinstance(item, (list, tuple)): + res.extend(res_item) + else: + res.append(res_item) else: res.append(res_item) return res - return _apply_transform( - transform, data, unpack_items, lazy, overrides, log_stats - ) + return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats) except Exception as e: # if in debug mode, don't swallow exception so that the breakpoint # appears where the exception was raised. if MONAIEnvVars.debug(): raise - if log_stats is not False and not isinstance( - transform, transforms.compose.Compose - ): + if log_stats is not False and not isinstance(transform, transforms.compose.Compose): # log the input data information of exact transform in the transform chain if isinstance(log_stats, str): - datastats = transforms.utility.array.DataStats( - data_shape=False, value_range=False, name=log_stats - ) + datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False, name=log_stats) else: - datastats = transforms.utility.array.DataStats( - data_shape=False, value_range=False - ) + datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False) logger = logging.getLogger(datastats._logger_name) - logger.error( - f"\n=== Transform input info -- {type(transform).__name__} ===" - ) + logger.error(f"\n=== Transform input info -- {type(transform).__name__} ===") if isinstance(data, (list, tuple)): data = data[0] def _log_stats(data, prefix: str | None = "Data"): if isinstance(data, (np.ndarray, torch.Tensor)): # log data type, shape, range for array - datastats( - img=data, data_shape=True, value_range=True, prefix=prefix - ) + datastats(img=data, data_shape=True, value_range=True, prefix=prefix) else: # log data type and value for other metadata datastats(img=data, data_value=True, prefix=prefix) @@ -220,9 +203,7 @@ class Randomizable(ThreadUnsafe, RandomizableTrait): R: np.random.RandomState = np.random.RandomState() - def set_random_state( - self, seed: int | None = None, state: np.random.RandomState | None = None - ) -> Randomizable: + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Randomizable: """ Set the random state locally, to control the randomness, the derived classes should use :py:attr:`self.R` instead of `np.random` to introduce random @@ -240,20 +221,14 @@ def set_random_state( """ if seed is not None: - _seed = np.int64( - id(seed) if not isinstance(seed, (int, np.integer)) else seed - ) - _seed = ( - _seed % MAX_SEED - ) # need to account for Numpy2.0 which doesn't silently convert to int64 + _seed = np.int64(id(seed) if not isinstance(seed, (int, np.integer)) else seed) + _seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64 self.R = np.random.RandomState(_seed) return self if state is not None: if not isinstance(state, np.random.RandomState): - raise TypeError( - f"state must be None or a np.random.RandomState but is {type(state).__name__}." - ) + raise TypeError(f"state must be None or a np.random.RandomState but is {type(state).__name__}.") self.R = state return self @@ -272,9 +247,7 @@ def randomize(self, data: Any) -> None: Raises: NotImplementedError: When the subclass does not override this method. """ - raise NotImplementedError( - f"Subclass {self.__class__.__name__} must implement this method." - ) + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") class Transform(ABC): @@ -330,9 +303,7 @@ def __call__(self, data: Any): NotImplementedError: When the subclass does not override this method. """ - raise NotImplementedError( - f"Subclass {self.__class__.__name__} must implement this method." - ) + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") class LazyTransform(Transform, LazyTrait): @@ -435,15 +406,11 @@ def __call__(self, data): def __new__(cls, *args, **kwargs): if config.USE_META_DICT: # call_update after MapTransform.__call__ - cls.__call__ = transforms.attach_hook( - cls.__call__, MapTransform.call_update, "post" - ) # type: ignore + cls.__call__ = transforms.attach_hook(cls.__call__, MapTransform.call_update, "post") # type: ignore if hasattr(cls, "inverse"): # inverse_update before InvertibleTransform.inverse - cls.inverse: Any = transforms.attach_hook( - cls.inverse, transforms.InvertibleTransform.inverse_update - ) + cls.inverse: Any = transforms.attach_hook(cls.inverse, transforms.InvertibleTransform.inverse_update) return Transform.__new__(cls) def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: @@ -454,9 +421,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No raise ValueError("keys must be non empty.") for key in self.keys: if not isinstance(key, Hashable): - raise TypeError( - f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}." - ) + raise TypeError(f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}.") def call_update(self, data): """ @@ -476,9 +441,7 @@ def call_update(self, data): for k in dict_i: if not isinstance(dict_i[k], MetaTensor): continue - list_d[idx] = transforms.sync_meta_info( - k, dict_i, t=not isinstance(self, transforms.InvertD) - ) + list_d[idx] = transforms.sync_meta_info(k, dict_i, t=not isinstance(self, transforms.InvertD)) return list_d[0] if is_dict else list_d @abstractmethod @@ -506,13 +469,9 @@ def __call__(self, data): An updated dictionary version of ``data`` by applying the transform. """ - raise NotImplementedError( - f"Subclass {self.__class__.__name__} must implement this method." - ) + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def key_iterator( - self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None - ) -> Generator: + def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None) -> Generator: """ Iterate across keys and optionally extra iterables. If key is missing, exception is raised if `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped. @@ -532,8 +491,7 @@ def key_iterator( yield (key,) + tuple(_ex_iters) if extra_iterables else key elif not self.allow_missing_keys: raise KeyError( - f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data" - " and allow_missing_keys==False." + f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data and allow_missing_keys==False." ) def first_key(self, data: dict[Hashable, Any]): From be4601826787f1991b34369bb3a678d5da252c55 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 8 Aug 2025 13:26:34 +0200 Subject: [PATCH 07/16] fixed wrong type annotation DCO Remediation Commit for Lukas Folle I, Lukas Folle , hereby add my Signed-off-by to this commit: e0cda55d4ad86efc74e912c023d9bcc5f6d12608 Signed-off-by: Lukas Folle --- monai/transforms/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 65ef429e33..d9a16d53e7 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -143,7 +143,7 @@ def apply_transform( try: map_items_ = int(map_items) if isinstance(map_items, bool) else map_items if isinstance(data, (list, tuple)) and map_items_ > 0: - res: list[ReturnType] = [] + res: list[Any] = [] for item in data: res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) # Only extend if we're at the leaf level (map_items_ == 1) and the transform From 2d5877455022444675e200afee36077a709fa784 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Tue, 16 Sep 2025 11:36:53 +0200 Subject: [PATCH 08/16] added test for many multisample transforms; refactored code --- monai/transforms/transform.py | 7 ++----- tests/transforms/compose/test_compose.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index d9a16d53e7..05a08e0743 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -148,11 +148,8 @@ def apply_transform( res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) # Only extend if we're at the leaf level (map_items_ == 1) and the transform # actually returned a list (not preserving nested structure) - if isinstance(res_item, list) and map_items_ == 1: - if not isinstance(item, (list, tuple)): - res.extend(res_item) - else: - res.append(res_item) + if isinstance(res_item, list) and map_items_ == 1 and not isinstance(item, (list, tuple)): + res.extend(res_item) else: res.append(res_item) return res diff --git a/tests/transforms/compose/test_compose.py b/tests/transforms/compose/test_compose.py index e6727c976f..9abf635b13 100644 --- a/tests/transforms/compose/test_compose.py +++ b/tests/transforms/compose/test_compose.py @@ -282,6 +282,28 @@ def test_flatten_and_len(self): def test_backwards_compatible_imports(self): from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401 + def test_list_extend_multi_sample_trait(self): + from monai.transforms import CenterSpatialCrop, RandSpatialCropSamples + + center_crop = CenterSpatialCrop([128, 128]) + multi_sample_transform = RandSpatialCropSamples([64, 64], 1) + + img = torch.zeros([1, 512, 512]) + + assert execute_compose(img, [center_crop]).shape == torch.Size([1, 128, 128]) + single_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, center_crop]) + assert ( + isinstance(single_multi_sample_trait_result, list) + and len(single_multi_sample_trait_result) == 1 + and single_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64]) + ) + double_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, multi_sample_transform, center_crop]) + assert ( + isinstance(double_multi_sample_trait_result, list) + and len(double_multi_sample_trait_result) == 1 + and double_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64]) + ) + TEST_COMPOSE_EXECUTE_TEST_CASES = [ [None, tuple()], From ee74761cb623688b7d8a7a8100f78dcfdd16e365 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Tue, 16 Sep 2025 11:36:53 +0200 Subject: [PATCH 09/16] DCO Remediation Commit for Lukas Folle I, Lukas Folle , hereby add my Signed-off-by to this commit: 2d5877455022444675e200afee36077a709fa784 Signed-off-by: Lukas Folle --- monai/transforms/transform.py | 7 ++----- tests/transforms/compose/test_compose.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index d9a16d53e7..05a08e0743 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -148,11 +148,8 @@ def apply_transform( res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) # Only extend if we're at the leaf level (map_items_ == 1) and the transform # actually returned a list (not preserving nested structure) - if isinstance(res_item, list) and map_items_ == 1: - if not isinstance(item, (list, tuple)): - res.extend(res_item) - else: - res.append(res_item) + if isinstance(res_item, list) and map_items_ == 1 and not isinstance(item, (list, tuple)): + res.extend(res_item) else: res.append(res_item) return res diff --git a/tests/transforms/compose/test_compose.py b/tests/transforms/compose/test_compose.py index e6727c976f..9abf635b13 100644 --- a/tests/transforms/compose/test_compose.py +++ b/tests/transforms/compose/test_compose.py @@ -282,6 +282,28 @@ def test_flatten_and_len(self): def test_backwards_compatible_imports(self): from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401 + def test_list_extend_multi_sample_trait(self): + from monai.transforms import CenterSpatialCrop, RandSpatialCropSamples + + center_crop = CenterSpatialCrop([128, 128]) + multi_sample_transform = RandSpatialCropSamples([64, 64], 1) + + img = torch.zeros([1, 512, 512]) + + assert execute_compose(img, [center_crop]).shape == torch.Size([1, 128, 128]) + single_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, center_crop]) + assert ( + isinstance(single_multi_sample_trait_result, list) + and len(single_multi_sample_trait_result) == 1 + and single_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64]) + ) + double_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, multi_sample_transform, center_crop]) + assert ( + isinstance(double_multi_sample_trait_result, list) + and len(double_multi_sample_trait_result) == 1 + and double_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64]) + ) + TEST_COMPOSE_EXECUTE_TEST_CASES = [ [None, tuple()], From 9377b63faf9cbd16d35990532df12ea94b61c6e6 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Tue, 16 Sep 2025 11:51:08 +0200 Subject: [PATCH 10/16] added slight cleanup and additional test DCO Remediation Commit for Lukas Folle I, Lukas Folle , hereby add my Signed-off-by to this commit: 2d5877455022444675e200afee36077a709fa784 Signed-off-by: Lukas Folle --- tests/transforms/compose/test_compose.py | 36 ++++++++++++++---------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/tests/transforms/compose/test_compose.py b/tests/transforms/compose/test_compose.py index 9abf635b13..01c7d92e7d 100644 --- a/tests/transforms/compose/test_compose.py +++ b/tests/transforms/compose/test_compose.py @@ -283,26 +283,32 @@ def test_backwards_compatible_imports(self): from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401 def test_list_extend_multi_sample_trait(self): - from monai.transforms import CenterSpatialCrop, RandSpatialCropSamples - - center_crop = CenterSpatialCrop([128, 128]) - multi_sample_transform = RandSpatialCropSamples([64, 64], 1) + center_crop = mt.CenterSpatialCrop([128, 128]) + multi_sample_transform = mt.RandSpatialCropSamples([64, 64], 1) img = torch.zeros([1, 512, 512]) - assert execute_compose(img, [center_crop]).shape == torch.Size([1, 128, 128]) + self.assertEqual(execute_compose(img, [center_crop]).shape, torch.Size([1, 128, 128])) single_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, center_crop]) - assert ( - isinstance(single_multi_sample_trait_result, list) - and len(single_multi_sample_trait_result) == 1 - and single_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64]) - ) + self.assertIsInstance(single_multi_sample_trait_result, list) + self.assertEqual(len(single_multi_sample_trait_result), 1) + self.assertEqual(single_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64])) + double_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, multi_sample_transform, center_crop]) - assert ( - isinstance(double_multi_sample_trait_result, list) - and len(double_multi_sample_trait_result) == 1 - and double_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64]) - ) + self.assertIsInstance(double_multi_sample_trait_result, list) + self.assertEqual(len(double_multi_sample_trait_result), 1) + self.assertEqual(double_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64])) + + def test_multi_sample_trait_cardinality(self): + img = torch.zeros([1, 128, 128]) + t2 = mt.RandSpatialCropSamples([32, 32], num_samples=2) + + # chaining should multiply counts: 2 x 2 = 4, flattened + res = execute_compose(img, [t2, t2]) + self.assertIsInstance(res, list) + self.assertEqual(len(res), 4) + for r in res: + self.assertEqual(r.shape, torch.Size([1, 32, 32])) TEST_COMPOSE_EXECUTE_TEST_CASES = [ From 5135fb42af4a11b6ed2fd38a4907caae430781df Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 10 Oct 2025 09:08:53 +0200 Subject: [PATCH 11/16] changed compose to explicit flattening Signed-off-by: Lukas Folle --- monai/transforms/__init__.py | 4 ++++ monai/transforms/traits.py | 13 ++++++++++++- monai/transforms/transform.py | 18 ++++++------------ monai/transforms/utility/array.py | 16 +++++++++++++++- monai/transforms/utility/dictionary.py | 24 +++++++++++++++++++++++- tests/transforms/compose/test_compose.py | 10 +++++++--- 6 files changed, 67 insertions(+), 18 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index d15042181b..5f4c2d9289 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -546,6 +546,7 @@ TorchVision, ToTensor, Transpose, + FlattenSequence, ) from .utility.dictionary import ( AddCoordinateChannelsd, @@ -671,6 +672,9 @@ Transposed, TransposeD, TransposeDict, + FlattenSequenced, + FlattenSequenceD, + FlattenSequenceDict, ) from .utils import ( Fourier, diff --git a/monai/transforms/traits.py b/monai/transforms/traits.py index 016effc59d..45d081f2e6 100644 --- a/monai/transforms/traits.py +++ b/monai/transforms/traits.py @@ -14,7 +14,7 @@ from __future__ import annotations -__all__ = ["LazyTrait", "InvertibleTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe"] +__all__ = ["LazyTrait", "InvertibleTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe", "ReduceTrait"] from typing import Any @@ -99,3 +99,14 @@ class ThreadUnsafe: """ pass + + +class ReduceTrait: + """ + An interface to indicate that the transform has the capability to reduce multiple samples + into a single sample. + This interface can be extended from by people adapting transforms to the MONAI framework as well + as by implementors of MONAI transforms. + """ + + pass diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 05a08e0743..a0575d997d 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -25,7 +25,7 @@ from monai import config, transforms from monai.config import KeysCollection from monai.data.meta_tensor import MetaTensor -from monai.transforms.traits import LazyTrait, RandomizableTrait, ThreadUnsafe +from monai.transforms.traits import LazyTrait, RandomizableTrait, ThreadUnsafe, ReduceTrait from monai.utils import MAX_SEED, ensure_tuple, first from monai.utils.enums import TransformBackends from monai.utils.misc import MONAIEnvVars @@ -142,17 +142,11 @@ def apply_transform( """ try: map_items_ = int(map_items) if isinstance(map_items, bool) else map_items - if isinstance(data, (list, tuple)) and map_items_ > 0: - res: list[Any] = [] - for item in data: - res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) - # Only extend if we're at the leaf level (map_items_ == 1) and the transform - # actually returned a list (not preserving nested structure) - if isinstance(res_item, list) and map_items_ == 1 and not isinstance(item, (list, tuple)): - res.extend(res_item) - else: - res.append(res_item) - return res + if isinstance(data, (list, tuple)) and map_items_ > 0 and not isinstance(transform, ReduceTrait): + return [ + apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) + for item in data + ] return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats) except Exception as e: # if in debug mode, don't swallow exception so that the breakpoint diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 18a0f7f32f..1ba0d599ed 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -43,7 +43,7 @@ median_filter, ) from monai.transforms.inverse import InvertibleTransform, TraceableTransform -from monai.transforms.traits import MultiSampleTrait +from monai.transforms.traits import MultiSampleTrait, ReduceTrait from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform from monai.transforms.utils import ( apply_affine_to_points, @@ -110,6 +110,7 @@ "ImageFilter", "RandImageFilter", "ApplyTransformToPoints", + "FlattenSequence" ] @@ -1950,3 +1951,16 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: data = inverse_transform(data, transform[TraceKeys.EXTRA_INFO]["image_affine"]) return data + + +class FlattenSequence(Transform, ReduceTrait): + def __init__(self): + super().__init__() + + def __call__(self, data): + if isinstance(data, (list, tuple)): + if len(data) == 0: + return data + if isinstance(data[0], (list, tuple)): + return [item for sublist in data for item in sublist] + return data diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 7dd2397a74..1d21e5d2d9 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -30,7 +30,7 @@ from monai.data.meta_tensor import MetaObj, MetaTensor from monai.data.utils import no_collation from monai.transforms.inverse import InvertibleTransform -from monai.transforms.traits import MultiSampleTrait, RandomizableTrait +from monai.transforms.traits import MultiSampleTrait, RandomizableTrait, ReduceTrait from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform from monai.transforms.utility.array import ( AddCoordinateChannels, @@ -64,6 +64,7 @@ TorchVision, ToTensor, Transpose, + FlattenSequence ) from monai.transforms.utils import extreme_points_to_image, get_extreme_points from monai.transforms.utils_pytorch_numpy_unification import concatenate @@ -191,6 +192,9 @@ "ApplyTransformToPointsd", "ApplyTransformToPointsD", "ApplyTransformToPointsDict", + "FlattenSequenced", + "FlattenSequenceD", + "FlattenSequenceDict" ] DEFAULT_POST_FIX = PostFix.meta() @@ -1906,6 +1910,23 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch return d +class FlattenSequenced(MapTransform, ReduceTrait): + def __init__( + self, + keys: KeysCollection, + allow_missing_keys: bool = False, + **kwargs, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.flatten_sequence = FlattenSequence(**kwargs) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.flatten_sequence(d[key]) + return d + + RandImageFilterD = RandImageFilterDict = RandImageFilterd ImageFilterD = ImageFilterDict = ImageFilterd IdentityD = IdentityDict = Identityd @@ -1949,3 +1970,4 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd +FlattenSequenceD = FlattenSequenceDict = FlattenSequenced diff --git a/tests/transforms/compose/test_compose.py b/tests/transforms/compose/test_compose.py index 01c7d92e7d..afbf26a8ab 100644 --- a/tests/transforms/compose/test_compose.py +++ b/tests/transforms/compose/test_compose.py @@ -285,16 +285,19 @@ def test_backwards_compatible_imports(self): def test_list_extend_multi_sample_trait(self): center_crop = mt.CenterSpatialCrop([128, 128]) multi_sample_transform = mt.RandSpatialCropSamples([64, 64], 1) + flatten_sequence_transform = mt.FlattenSequence() img = torch.zeros([1, 512, 512]) self.assertEqual(execute_compose(img, [center_crop]).shape, torch.Size([1, 128, 128])) - single_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, center_crop]) + single_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, center_crop, flatten_sequence_transform]) self.assertIsInstance(single_multi_sample_trait_result, list) self.assertEqual(len(single_multi_sample_trait_result), 1) self.assertEqual(single_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64])) - double_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, multi_sample_transform, center_crop]) + double_multi_sample_trait_result = execute_compose(img, [ + multi_sample_transform, multi_sample_transform, flatten_sequence_transform, center_crop + ]) self.assertIsInstance(double_multi_sample_trait_result, list) self.assertEqual(len(double_multi_sample_trait_result), 1) self.assertEqual(double_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64])) @@ -302,9 +305,10 @@ def test_list_extend_multi_sample_trait(self): def test_multi_sample_trait_cardinality(self): img = torch.zeros([1, 128, 128]) t2 = mt.RandSpatialCropSamples([32, 32], num_samples=2) + flatten_sequence_transform = mt.FlattenSequence() # chaining should multiply counts: 2 x 2 = 4, flattened - res = execute_compose(img, [t2, t2]) + res = execute_compose(img, [t2, t2, flatten_sequence_transform]) self.assertIsInstance(res, list) self.assertEqual(len(res), 4) for r in res: From fee6cd30f2864b64aa6f4d343c7e706a30ef825d Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 10 Oct 2025 09:09:20 +0200 Subject: [PATCH 12/16] added documentation Signed-off-by: Lukas Folle --- docs/source/transforms.rst | 17 +++++++++++++++++ monai/transforms/utility/array.py | 16 ++++++++++++++++ monai/transforms/utility/dictionary.py | 9 +++++++++ 3 files changed, 42 insertions(+) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index d2585daf63..2d5d452dc0 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -37,6 +37,11 @@ Generic Interfaces .. autoclass:: MultiSampleTrait :members: +`ReduceTrait` +^^^^^^^^^^^^^^^^^^ +.. autoclass:: ReduceTrait + :members: + `Randomizable` ^^^^^^^^^^^^^^ .. autoclass:: Randomizable @@ -1252,6 +1257,12 @@ Utility :members: :special-members: __call__ +`FlattenSequence` +"""""""""""""""""""""""" +.. autoclass:: FlattenSequence + :members: + :special-members: __call__ + Dictionary Transforms --------------------- @@ -2337,6 +2348,12 @@ Utility (Dict) :members: :special-members: __call__ +`FlattenSequenced` +""""""""""""""""""""""""" +.. autoclass:: FlattenSequenced + :members: + :special-members: __call__ + MetaTensor ^^^^^^^^^^ diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 1ba0d599ed..e7d589d13c 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1954,6 +1954,22 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: class FlattenSequence(Transform, ReduceTrait): + """ + Flatten a nested sequence (list or tuple) by one level. + If the input is a sequence of sequences, it will flatten them into a single sequence. + Non-nested sequences and other data types are returned unchanged. + + For example: + + .. code-block:: python + + flatten = FlattenSequence() + data = [[1, 2], [3, 4], [5, 6]] + print(flatten(data)) + [1, 2, 3, 4, 5, 6] + + """ + def __init__(self): super().__init__() diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 1d21e5d2d9..5043d8931b 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1911,6 +1911,15 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch class FlattenSequenced(MapTransform, ReduceTrait): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.FlattenSequence`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: monai.transforms.MapTransform + allow_missing_keys: + Don't raise exception if key is missing. + """ def __init__( self, keys: KeysCollection, From 416584d89b26d78061ca61511ee4e43ef4827c38 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 10 Oct 2025 09:20:16 +0200 Subject: [PATCH 13/16] fixed doc build; fixed isort Signed-off-by: Lukas Folle --- monai/transforms/__init__.py | 10 +++++----- monai/transforms/transform.py | 2 +- monai/transforms/utility/array.py | 2 +- monai/transforms/utility/dictionary.py | 12 ++++-------- tests/transforms/compose/test_compose.py | 10 ++++++---- 5 files changed, 17 insertions(+), 19 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 5f4c2d9289..0ab9fe63d5 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -506,7 +506,7 @@ ZoomDict, ) from .spatial.functional import spatial_resample -from .traits import LazyTrait, MultiSampleTrait, RandomizableTrait, ThreadUnsafe +from .traits import LazyTrait, MultiSampleTrait, RandomizableTrait, ReduceTrait, ThreadUnsafe from .transform import LazyTransform, MapTransform, Randomizable, RandomizableTransform, Transform, apply_transform from .utility.array import ( AddCoordinateChannels, @@ -521,6 +521,7 @@ EnsureChannelFirst, EnsureType, FgBgToIndices, + FlattenSequence, Identity, ImageFilter, IntensityStats, @@ -546,7 +547,6 @@ TorchVision, ToTensor, Transpose, - FlattenSequence, ) from .utility.dictionary import ( AddCoordinateChannelsd, @@ -594,6 +594,9 @@ FgBgToIndicesd, FgBgToIndicesD, FgBgToIndicesDict, + FlattenSequenced, + FlattenSequenceD, + FlattenSequenceDict, FlattenSubKeysd, FlattenSubKeysD, FlattenSubKeysDict, @@ -672,9 +675,6 @@ Transposed, TransposeD, TransposeDict, - FlattenSequenced, - FlattenSequenceD, - FlattenSequenceDict, ) from .utils import ( Fourier, diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index a0575d997d..1eedc7c333 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -25,7 +25,7 @@ from monai import config, transforms from monai.config import KeysCollection from monai.data.meta_tensor import MetaTensor -from monai.transforms.traits import LazyTrait, RandomizableTrait, ThreadUnsafe, ReduceTrait +from monai.transforms.traits import LazyTrait, RandomizableTrait, ReduceTrait, ThreadUnsafe from monai.utils import MAX_SEED, ensure_tuple, first from monai.utils.enums import TransformBackends from monai.utils.misc import MONAIEnvVars diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index e7d589d13c..3f89ef899f 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -110,7 +110,7 @@ "ImageFilter", "RandImageFilter", "ApplyTransformToPoints", - "FlattenSequence" + "FlattenSequence", ] diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 5043d8931b..996b603e00 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -45,6 +45,7 @@ EnsureChannelFirst, EnsureType, FgBgToIndices, + FlattenSequence, Identity, ImageFilter, IntensityStats, @@ -64,7 +65,6 @@ TorchVision, ToTensor, Transpose, - FlattenSequence ) from monai.transforms.utils import extreme_points_to_image, get_extreme_points from monai.transforms.utils_pytorch_numpy_unification import concatenate @@ -194,7 +194,7 @@ "ApplyTransformToPointsDict", "FlattenSequenced", "FlattenSequenceD", - "FlattenSequenceDict" + "FlattenSequenceDict", ] DEFAULT_POST_FIX = PostFix.meta() @@ -1920,12 +1920,8 @@ class FlattenSequenced(MapTransform, ReduceTrait): allow_missing_keys: Don't raise exception if key is missing. """ - def __init__( - self, - keys: KeysCollection, - allow_missing_keys: bool = False, - **kwargs, - ) -> None: + + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False, **kwargs) -> None: super().__init__(keys, allow_missing_keys) self.flatten_sequence = FlattenSequence(**kwargs) diff --git a/tests/transforms/compose/test_compose.py b/tests/transforms/compose/test_compose.py index afbf26a8ab..12547f9ec2 100644 --- a/tests/transforms/compose/test_compose.py +++ b/tests/transforms/compose/test_compose.py @@ -290,14 +290,16 @@ def test_list_extend_multi_sample_trait(self): img = torch.zeros([1, 512, 512]) self.assertEqual(execute_compose(img, [center_crop]).shape, torch.Size([1, 128, 128])) - single_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, center_crop, flatten_sequence_transform]) + single_multi_sample_trait_result = execute_compose( + img, [multi_sample_transform, center_crop, flatten_sequence_transform] + ) self.assertIsInstance(single_multi_sample_trait_result, list) self.assertEqual(len(single_multi_sample_trait_result), 1) self.assertEqual(single_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64])) - double_multi_sample_trait_result = execute_compose(img, [ - multi_sample_transform, multi_sample_transform, flatten_sequence_transform, center_crop - ]) + double_multi_sample_trait_result = execute_compose( + img, [multi_sample_transform, multi_sample_transform, flatten_sequence_transform, center_crop] + ) self.assertIsInstance(double_multi_sample_trait_result, list) self.assertEqual(len(double_multi_sample_trait_result), 1) self.assertEqual(double_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64])) From a8f3fe9c20266ecab5ac3736bace5644e9701ba4 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 10 Oct 2025 09:40:32 +0200 Subject: [PATCH 14/16] added type hints and fixed potential bug Signed-off-by: Lukas Folle --- monai/transforms/utility/array.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 3f89ef899f..0332bcd646 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1973,10 +1973,17 @@ class FlattenSequence(Transform, ReduceTrait): def __init__(self): super().__init__() - def __call__(self, data): - if isinstance(data, (list, tuple)): - if len(data) == 0: - return data - if isinstance(data[0], (list, tuple)): - return [item for sublist in data for item in sublist] - return data +def __call__(self, data: list | tuple | Any) -> list | tuple | Any: + """ + Flatten a nested sequence by one level. + Args: + data: Input data, can be a nested sequence. + Returns: + Flattened list if input is a nested sequence, otherwise returns data unchanged. + """ + if isinstance(data, (list, tuple)): + if len(data) == 0: + return data + if all(isinstance(item, (list, tuple)) for item in data): + return [item for sublist in data for item in sublist] + return data From c707a2ce0b812c447daf621621b4080af6595381 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 10 Oct 2025 09:52:59 +0200 Subject: [PATCH 15/16] formatted Signed-off-by: Lukas Folle --- monai/transforms/utility/array.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 0332bcd646..2ac37f2f81 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1973,17 +1973,17 @@ class FlattenSequence(Transform, ReduceTrait): def __init__(self): super().__init__() -def __call__(self, data: list | tuple | Any) -> list | tuple | Any: - """ - Flatten a nested sequence by one level. - Args: - data: Input data, can be a nested sequence. - Returns: - Flattened list if input is a nested sequence, otherwise returns data unchanged. - """ - if isinstance(data, (list, tuple)): - if len(data) == 0: - return data - if all(isinstance(item, (list, tuple)) for item in data): - return [item for sublist in data for item in sublist] - return data + def __call__(self, data: list | tuple | Any) -> list | tuple | Any: + """ + Flatten a nested sequence by one level. + Args: + data: Input data, can be a nested sequence. + Returns: + Flattened list if input is a nested sequence, otherwise returns data unchanged. + """ + if isinstance(data, (list, tuple)): + if len(data) == 0: + return data + if all(isinstance(item, (list, tuple)) for item in data): + return [item for sublist in data for item in sublist] + return data From 2644f47ce91b934ebe8ed55527c84a516c56440a Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 10 Oct 2025 10:25:09 +0200 Subject: [PATCH 16/16] ignored mypy error Signed-off-by: Lukas Folle --- monai/transforms/utility/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 996b603e00..95c59e07bc 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1928,7 +1928,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False, **kwa def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): - d[key] = self.flatten_sequence(d[key]) + d[key] = self.flatten_sequence(d[key]) # type: ignore[assignment] return d