From e84ee42c776a8832e54ccefaed247f1f9e24012c Mon Sep 17 00:00:00 2001 From: Tuan Chien Date: Tue, 14 Mar 2023 23:06:42 +1300 Subject: [PATCH 1/7] Add SomeOf transform composer Signed-off-by: Tuan Chien --- docs/source/transforms.rst | 5 + monai/transforms/__init__.py | 2 +- monai/transforms/compose.py | 142 ++++++++++++++++++++++- tests/test_some_of.py | 210 +++++++++++++++++++++++++++++++++++ 4 files changed, 357 insertions(+), 2 deletions(-) create mode 100644 tests/test_some_of.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 56fe4bc1e7..aaa4364de6 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -88,6 +88,11 @@ Generic Interfaces .. autoclass:: RandomOrder :members: +`SomeOf` +^^^^^^^^^^^^^ +.. autoclass:: SomeOf + :members: + Functionals ----------- diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 940485cbe0..11790e639b 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -12,7 +12,7 @@ from __future__ import annotations from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs -from .compose import Compose, OneOf, RandomOrder +from .compose import Compose, OneOf, RandomOrder, SomeOf from .croppad.array import ( BorderPad, BoundingRect, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 45e706e143..bc0f58e3ee 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -34,7 +34,7 @@ from monai.utils import MAX_SEED, ensure_tuple, get_seed from monai.utils.enums import TraceKeys -__all__ = ["Compose", "OneOf", "RandomOrder"] +__all__ = ["Compose", "OneOf", "RandomOrder", "SomeOf"] class Compose(Randomizable, InvertibleTransform): @@ -358,3 +358,143 @@ def inverse(self, data): for o in reversed(applied_order): data = apply_transform(self.transforms[o].inverse, data, self.map_items, self.unpack_items, self.log_stats) return data + + +class SomeOf(Compose): + """ + ``SomeOf`` samples a different sequence of transforms to apply each time it is called. + + It can be configured to sample a fixed or varying number of transforms each time its called. Samples are drawn uniformly, + or from user supplied transform weights. When varying the number of transforms sampled per call, the number of transforms + to sample that call is sampled uniformly from a range supplied by the user. + + Args: + transforms: list of callables. + map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple. + Defaults to `True`. + unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. + Defaults to `False`. + log_stats: whether to log the detailed information of data and applied transform when error happened, + for NumPy array and PyTorch Tensor, log the data shape and value range, + for other metadata, log the values directly. Default to `False`. + num_transforms: a 2-tuple, int, or None. The 2-tuple specifies the minimum and maximum (inclusive) number of transforms to sample at each iteration. + If an int is given, the lower and upper bounds are set equal. None sets it to `len(transforms)`. Default to `None`. + replace: whether to sample with replacement. Defaults to `False`. + weights: weights to use in for sampling transforms. Will be normalized to 1. Default: None (uniform). 
+ """ + + def __init__( + self, + transforms: Sequence[Callable] | Callable | None = None, + map_items: bool = True, + unpack_items: bool = False, + log_stats: bool = False, + *, + num_transforms: int | tuple[int, int] = None, + replace: bool = False, + weights: list[int] | None = None, + ) -> None: + super().__init__(transforms, map_items, unpack_items, log_stats) + self.min_num_transforms, self.max_num_transforms = self._ensure_valid_num_transforms(num_transforms) + self.replace = replace + self.weights = self._normalize_probabilities(weights) + + def _ensure_valid_num_transforms(self, num_transforms: int | tuple[int, int]) -> tuple[int, int]: + if ( + not isinstance(num_transforms, tuple) + and not isinstance(num_transforms, list) + and not isinstance(num_transforms, int) + and num_transforms is not None + ): + raise ValueError( + f"Expected num_transforms to be of type int, list, tuple or None, but it's {type(num_transforms)}" + ) + + if num_transforms is None: + result = [len(self.transforms), len(self.transforms)] + elif isinstance(num_transforms, int): + n = min(len(self.transforms), num_transforms) + result = [n, n] + else: + if len(num_transforms) != 2: + raise ValueError(f"Expected len(num_transforms)=2, but it was {len(num_transforms)}") + if not isinstance(num_transforms[0], int) or not isinstance(num_transforms[1], int): + raise ValueError( + f"Expected (int,int), but received ({type(num_transforms[0])}, {type(num_transforms[1])})" + ) + + result = [num_transforms[0], num_transforms[1]] + + if result[0] < 0 or result[1] > len(self.transforms): + raise ValueError(f"num_transforms={num_transforms} are out of the bounds [0, {len(self.transforms)}].") + + return ensure_tuple(result) + + # Modified from OneOf + def _normalize_probabilities(self, weights): + if weights is None or len(self.transforms) == 0: + return None + + weights = np.array(weights) + + n_weights = len(weights) + if n_weights != len(self.transforms): + raise ValueError(f"Expected len(weights)={len(self.transforms)}, got: {n_weights}.") + + if np.any(weights < 0): + raise ValueError(f"Probabilities must be greater than or equal to zero, got {weights}.") + + if np.all(weights == 0): + raise ValueError(f"At least one probability must be greater than zero, got {weights}.") + + weights = weights / weights.sum() + + return ensure_tuple(list(weights)) + + def __call__(self, data): + if len(self.transforms) == 0: + return data + + sample_size = self.R.randint(self.min_num_transforms, self.max_num_transforms + 1) + applied_order = self.R.choice(len(self.transforms), sample_size, replace=self.replace, p=self.weights).tolist() + for i in applied_order: + data = apply_transform(self.transforms[i], data, self.map_items, self.unpack_items, self.log_stats) + + if isinstance(data, monai.data.MetaTensor): + self.push_transform(data, extra_info={"applied_order": applied_order}) + elif isinstance(data, Mapping): + for key in data: # dictionary not change size during iteration + if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data: + self.push_transform(data, key, extra_info={"applied_order": applied_order}) + + return data + + # From RandomOrder + def inverse(self, data): + if len(self.transforms) == 0: + return data + + applied_order = None + if isinstance(data, monai.data.MetaTensor): + applied_order = self.pop_transform(data)[TraceKeys.EXTRA_INFO]["applied_order"] + elif isinstance(data, Mapping): + for key in data: + if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data: + 
applied_order = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["applied_order"] + else: + raise RuntimeError( + f"Inverse only implemented for Mapping (dictionary) or MetaTensor data, got type {type(data)}." + ) + if applied_order is None: + # no invertible transforms have been applied + return data + + # loop backwards over transforms + for o in reversed(applied_order): + transform = self.transforms[o] + if isinstance(transform, InvertibleTransform): + data = apply_transform( + self.transforms[o].inverse, data, self.map_items, self.unpack_items, self.log_stats + ) + + return data diff --git a/tests/test_some_of.py b/tests/test_some_of.py new file mode 100644 index 0000000000..0cc903bb2d --- /dev/null +++ b/tests/test_some_of.py @@ -0,0 +1,210 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from parameterized import parameterized + +from monai.data import MetaTensor +from monai.transforms import TraceableTransform, Transform +from monai.transforms.compose import Compose, SomeOf +from monai.utils import set_determinism +from monai.utils.enums import TraceKeys +from tests.test_one_of import NonInv +from tests.test_random_order import InvC, InvD + + +class A(Transform): + def __call__(self, x): + return 2 * x + + +class B(Transform): + def __call__(self, x): + return 3 * x + + +class C(Transform): + def __call__(self, x): + return 5 * x + + +class D(Transform): + def __call__(self, x): + return 7 * x + + +KEYS = ["x", "y"] +TEST_COMPOUND = [ + (SomeOf((A(), B(), C()), num_transforms=3), 2 * 3 * 5), + (Compose((SomeOf((A(), B(), C()), num_transforms=3), D())), 2 * 3 * 5 * 7), + (SomeOf((A(), B(), C(), Compose(D())), num_transforms=4), 2 * 3 * 5 * 7), + (SomeOf(()), 1), + (SomeOf(None), 1), +] + +# Modified from RandomOrder +TEST_INVERSES = [ + (SomeOf((InvC(KEYS), InvD(KEYS))), True, True), + (Compose((SomeOf((InvC(KEYS), InvD(KEYS))), SomeOf((InvD(KEYS), InvC(KEYS))))), True, False), + (SomeOf((SomeOf((InvC(KEYS), InvD(KEYS))), SomeOf((InvD(KEYS), InvC(KEYS))))), True, False), + (SomeOf((Compose((InvC(KEYS), InvD(KEYS))), Compose((InvD(KEYS), InvC(KEYS))))), True, False), + (SomeOf((NonInv(KEYS), NonInv(KEYS))), False, False), + (SomeOf(()), False, False), +] + + +class TestSomeOf(unittest.TestCase): + def setUp(self): + set_determinism(seed=0) + + def tearDown(self): + set_determinism(None) + + def update_transform_count(self, counts, output): + op_count = 0 + + if output % 2 == 0: + counts[0] += 1 + op_count += 1 + if output % 3 == 0: + counts[1] += 1 + op_count += 1 + if output % 5 == 0: + counts[2] += 1 + op_count += 1 + + return op_count + + def test_fixed(self): + iterations = 10000 + num_transforms = 3 + transform_counts = 3 * [0] + subset_size_counts = 4 * [0] + + s = SomeOf((A(), B(), C()), num_transforms=num_transforms) + + for _ in range(iterations): + output = s(1) + subset_size = self.update_transform_count(transform_counts, output) + subset_size_counts[subset_size] += 1 + + for i in 
range(3): + self.assertEqual(transform_counts[i], iterations) + + for i in range(3): + self.assertEqual(subset_size_counts[i], 0) + + self.assertEqual(subset_size_counts[3], iterations) + + def test_unfixed(self): + iterations = 10000 + num_transforms = (0, 3) + transform_counts = 3 * [0] + subset_size_counts = 4 * [0] + + s = SomeOf((A(), B(), C()), num_transforms=num_transforms) + + for _ in range(iterations): + output = s(1) + subset_size = self.update_transform_count(transform_counts, output) + subset_size_counts[subset_size] += 1 + + for i in range(3): + self.assertAlmostEqual(transform_counts[i] / iterations, 0.5, delta=0.01) + + for i in range(4): + self.assertAlmostEqual(subset_size_counts[i] / iterations, 0.25, delta=0.01) + + def test_non_dict_metatensor(self): + data = MetaTensor(1) + s = SomeOf([A()], num_transforms=1) + out = s(data) + self.assertEqual(out, 2) + inv = s.inverse(out) # A() is not invertible, nothing happens + self.assertEqual(inv, 2) + + @parameterized.expand(TEST_COMPOUND) + def test_compound_pipeline(self, transform, expected_value): + output = transform(1) + self.assertEqual(output, expected_value) + + # Modified from RandomOrder + @parameterized.expand(TEST_INVERSES) + def test_inverse(self, transform, invertible, use_metatensor): + data = {k: (i + 1) * 10.0 if not use_metatensor else MetaTensor((i + 1) * 10.0) for i, k in enumerate(KEYS)} + fwd_data1 = transform(data) + # test call twice won't affect inverse + fwd_data2 = transform(data) + + if invertible: + for k in KEYS: + t = ( + fwd_data1[TraceableTransform.trace_key(k)][-1] + if not use_metatensor + else fwd_data1[k].applied_operations[-1] + ) + # make sure the SomeOf applied_order was stored + self.assertEqual(t[TraceKeys.CLASS_NAME], SomeOf.__name__) + + # call the inverse + fwd_inv_data1 = transform.inverse(fwd_data1) + fwd_inv_data2 = transform.inverse(fwd_data2) + + fwd_data = [fwd_data1, fwd_data2] + fwd_inv_data = [fwd_inv_data1, fwd_inv_data2] + for i, _fwd_inv_data in enumerate(fwd_inv_data): + if invertible: + for k in KEYS: + # check transform was removed + if not use_metatensor: + self.assertTrue( + len(_fwd_inv_data[TraceableTransform.trace_key(k)]) + < len(fwd_data[i][TraceableTransform.trace_key(k)]) + ) + # check data is same as original (and different from forward) + self.assertEqual(_fwd_inv_data[k], data[k]) + self.assertNotEqual(_fwd_inv_data[k], fwd_data[i][k]) + else: + # if not invertible, should not change the data + self.assertDictEqual(fwd_data[i], _fwd_inv_data) + + def test_bad_inverse_data(self): + tr = SomeOf((A(), B(), C()), num_transforms=1, weights=(1, 2, 1)) + self.assertRaises(RuntimeError, tr.inverse, []) + + def test_normalize_weights(self): + tr = SomeOf((A(), B(), C()), num_transforms=1, weights=(1, 2, 1)) + self.assertTupleEqual(tr.weights, (0.25, 0.5, 0.25)) + + tr = SomeOf((), num_transforms=1, weights=(1, 2, 1)) + self.assertIsNone(tr.weights) + + def test_no_weights_arg(self): + tr = SomeOf((A(), B(), C(), D()), num_transforms=1) + self.assertIsNone(tr.weights) + + def test_bad_weights(self): + self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=1, weights=(1, 2)) + self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=1, weights=(0, 0, 0)) + self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=1, weights=(-1, 1, 1)) + + def test_bad_num_transforms(self): + self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=(-1, 2)) + self.assertRaises(ValueError, SomeOf, (A(), B(), C()), 
num_transforms="str") + self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=(1, 2, 3)) + self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=("a", 1)) + + +if __name__ == "__main__": + unittest.main() From 542462e525859c2f56dfdd220ed3d629e7941e01 Mon Sep 17 00:00:00 2001 From: Tuan Chien Date: Fri, 17 Mar 2023 21:18:01 +1300 Subject: [PATCH 2/7] Formatting fix Signed-off-by: Tuan Chien --- monai/transforms/compose.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index bc0f58e3ee..89341ec0f1 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -364,9 +364,9 @@ class SomeOf(Compose): """ ``SomeOf`` samples a different sequence of transforms to apply each time it is called. - It can be configured to sample a fixed or varying number of transforms each time its called. Samples are drawn uniformly, - or from user supplied transform weights. When varying the number of transforms sampled per call, the number of transforms - to sample that call is sampled uniformly from a range supplied by the user. + It can be configured to sample a fixed or varying number of transforms each time its called. Samples are drawn + uniformly, or from user supplied transform weights. When varying the number of transforms sampled per call, + the number of transforms to sample that call is sampled uniformly from a range supplied by the user. Args: transforms: list of callables. @@ -377,8 +377,10 @@ class SomeOf(Compose): log_stats: whether to log the detailed information of data and applied transform when error happened, for NumPy array and PyTorch Tensor, log the data shape and value range, for other metadata, log the values directly. Default to `False`. - num_transforms: a 2-tuple, int, or None. The 2-tuple specifies the minimum and maximum (inclusive) number of transforms to sample at each iteration. - If an int is given, the lower and upper bounds are set equal. None sets it to `len(transforms)`. Default to `None`. + num_transforms: a 2-tuple, int, or None. The 2-tuple specifies the minimum and maximum (inclusive) number of + transforms to sample at each iteration. + If an int is given, the lower and upper bounds are set equal. None sets it to `len(transforms)`. + Default to `None`. replace: whether to sample with replacement. Defaults to `False`. weights: weights to use in for sampling transforms. Will be normalized to 1. Default: None (uniform). """ From 7cbda26e099bd5e28732941487b734f693862d70 Mon Sep 17 00:00:00 2001 From: Tuan Chien Date: Fri, 17 Mar 2023 21:29:10 +1300 Subject: [PATCH 3/7] Docstring fix Signed-off-by: Tuan Chien --- monai/transforms/compose.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 89341ec0f1..1c37745e67 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -378,9 +378,8 @@ class SomeOf(Compose): for NumPy array and PyTorch Tensor, log the data shape and value range, for other metadata, log the values directly. Default to `False`. num_transforms: a 2-tuple, int, or None. The 2-tuple specifies the minimum and maximum (inclusive) number of - transforms to sample at each iteration. - If an int is given, the lower and upper bounds are set equal. None sets it to `len(transforms)`. - Default to `None`. + transforms to sample at each iteration. If an int is given, the lower and upper bounds are set equal. 
+ None sets it to `len(transforms)`. Default to `None`. replace: whether to sample with replacement. Defaults to `False`. weights: weights to use in for sampling transforms. Will be normalized to 1. Default: None (uniform). """ From db4fdebec6de251b4de0d9b404922d7ec10bc5be Mon Sep 17 00:00:00 2001 From: Tuan Chien Date: Fri, 17 Mar 2023 21:54:54 +1300 Subject: [PATCH 4/7] Linting fix Signed-off-by: Tuan Chien --- monai/transforms/compose.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 1c37745e67..7b87426ed8 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -400,7 +400,7 @@ def __init__( self.replace = replace self.weights = self._normalize_probabilities(weights) - def _ensure_valid_num_transforms(self, num_transforms: int | tuple[int, int]) -> tuple[int, int]: + def _ensure_valid_num_transforms(self, num_transforms: int | tuple[int, int]) -> tuple: if ( not isinstance(num_transforms, tuple) and not isinstance(num_transforms, list) From f5c7e32b96c545f76c7825c5e16fc9cc81d0ca87 Mon Sep 17 00:00:00 2001 From: Tuan Chien Date: Fri, 17 Mar 2023 22:30:17 +1300 Subject: [PATCH 5/7] Linting fix Signed-off-by: Tuan Chien --- monai/transforms/compose.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 7b87426ed8..a67dee4503 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -391,7 +391,7 @@ def __init__( unpack_items: bool = False, log_stats: bool = False, *, - num_transforms: int | tuple[int, int] = None, + num_transforms: int | tuple[int, int] | None = None, replace: bool = False, weights: list[int] | None = None, ) -> None: From 3eb721755310863aec2fc9335df8fb2b40d3b21f Mon Sep 17 00:00:00 2001 From: Tuan Chien Date: Fri, 17 Mar 2023 23:35:27 +1300 Subject: [PATCH 6/7] Linting fix Signed-off-by: Tuan Chien --- monai/transforms/compose.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index a67dee4503..3a67513ea4 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -400,7 +400,7 @@ def __init__( self.replace = replace self.weights = self._normalize_probabilities(weights) - def _ensure_valid_num_transforms(self, num_transforms: int | tuple[int, int]) -> tuple: + def _ensure_valid_num_transforms(self, num_transforms: int | tuple[int, int] | None) -> tuple: if ( not isinstance(num_transforms, tuple) and not isinstance(num_transforms, list) From 39e93697e6fb649e47899e368729ac5aa483fef8 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 23 Mar 2023 20:31:52 +0000 Subject: [PATCH 7/7] update Signed-off-by: Wenqi Li --- monai/utils/__init__.py | 1 + monai/utils/misc.py | 1 + 2 files changed, 2 insertions(+) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 8210ec924c..75806ce120 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -88,6 +88,7 @@ star_zip_with, str2bool, str2list, + to_tuple_of_dictionaries, zip_with, ) from .module import ( diff --git a/monai/utils/misc.py b/monai/utils/misc.py index f22716a376..0b2df36a5b 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -41,6 +41,7 @@ "ensure_tuple", "ensure_tuple_size", "ensure_tuple_rep", + "to_tuple_of_dictionaries", "fall_back_tuple", "is_scalar_tensor", "is_scalar",
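
Usage sketch (illustrative only, not part of the patches above): assuming the series is applied, the new SomeOf composer can be exercised as below. The wrapped random transforms (RandFlipd, RandRotate90d, RandZoomd) and the toy image are stand-in examples chosen for this sketch; any MONAI transforms, invertible or not, can be composed the same way.

    import torch

    from monai.transforms import RandFlipd, RandRotate90d, RandZoomd, SomeOf

    image = torch.rand(1, 32, 32)  # channel-first toy image, stand-in data

    # Draw 1 or 2 of the three transforms on each call (bounds are inclusive),
    # without replacement, sampling the flip twice as often as the others.
    pipeline = SomeOf(
        [
            RandFlipd(keys="image", prob=1.0, spatial_axis=0),
            RandRotate90d(keys="image", prob=1.0),
            RandZoomd(keys="image", prob=1.0, min_zoom=0.9, max_zoom=1.1),
        ],
        num_transforms=(1, 2),  # an int fixes the count; None means len(transforms)
        replace=False,
        weights=[2, 1, 1],      # normalized internally; None gives uniform sampling
    )

    out = pipeline({"image": image})   # applies a freshly sampled subset of transforms
    restored = pipeline.inverse(out)   # undoes the applied invertible transforms in reverse order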