diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 584f67bc62..e045a7e741 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -88,6 +88,11 @@ Generic Interfaces
 .. autoclass:: RandomOrder
     :members:
 
+`SomeOf`
+^^^^^^^^^^^^^
+.. autoclass:: SomeOf
+    :members:
+
 Functionals
 -----------
 
diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index 940485cbe0..11790e639b 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -12,7 +12,7 @@
 from __future__ import annotations
 
 from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs
-from .compose import Compose, OneOf, RandomOrder
+from .compose import Compose, OneOf, RandomOrder, SomeOf
 from .croppad.array import (
     BorderPad,
     BoundingRect,
diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py
index 0997d53dad..60456e9bc7 100644
--- a/monai/transforms/compose.py
+++ b/monai/transforms/compose.py
@@ -34,12 +34,11 @@
     Transform,
     apply_transform,
 )
-from monai.utils import MAX_SEED, TraceKeys, ensure_tuple, get_seed
-from monai.utils.misc import to_tuple_of_dictionaries
+from monai.utils import MAX_SEED, TraceKeys, ensure_tuple, get_seed, to_tuple_of_dictionaries
 
 logger = get_logger(__name__)
 
-__all__ = ["Compose", "OneOf", "RandomOrder", "evaluate_with_overrides"]
+__all__ = ["Compose", "OneOf", "RandomOrder", "evaluate_with_overrides", "SomeOf"]
 
 
 def evaluate_with_overrides(
@@ -521,3 +520,144 @@ def inverse(self, data):
                     self.transforms[o].inverse, data, self.map_items, self.unpack_items, self.log_stats
                 )
         return data
+
+
+class SomeOf(Compose):
+    """
+    ``SomeOf`` samples a different sequence of transforms to apply each time it is called.
+
+    It can be configured to sample either a fixed or a varying number of transforms on each call. Samples are drawn
+    uniformly, or from user-supplied transform weights. When the number of transforms varies per call, it is drawn
+    uniformly from a user-supplied range.
+
+    Args:
+        transforms: list of callables.
+        map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
+            Defaults to `True`.
+        unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
+            Defaults to `False`.
+        log_stats: whether to log the detailed information of data and applied transform when an error happens,
+            for NumPy array and PyTorch Tensor, log the data shape and value range,
+            for other metadata, log the values directly. Defaults to `False`.
+        num_transforms: a 2-tuple, int, or None. The 2-tuple specifies the minimum and maximum (inclusive) number of
+            transforms to sample at each iteration. If an int is given, the lower and upper bounds are set equal.
+            None sets it to `len(transforms)`. Defaults to `None`.
+        replace: whether to sample with replacement. Defaults to `False`.
+        weights: weights to use for sampling transforms. Will be normalized to sum to 1. Default: None (uniform).
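+
+    Example (an illustrative sketch; the listed array transforms are placeholders chosen for
+    demonstration and are not part of this class)::
+
+        from monai.transforms import RandFlip, RandGaussianNoise, RandRotate90, SomeOf
+
+        # apply one or two of the three augmentations on each call, sampled uniformly
+        aug = SomeOf(
+            [RandFlip(prob=1.0), RandGaussianNoise(prob=1.0), RandRotate90(prob=1.0)],
+            num_transforms=(1, 2),
+        )
+        augmented = aug(image)  # ``image``: a channel-first image array or tensor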
+    """
+
+    def __init__(
+        self,
+        transforms: Sequence[Callable] | Callable | None = None,
+        map_items: bool = True,
+        unpack_items: bool = False,
+        log_stats: bool = False,
+        *,
+        num_transforms: int | tuple[int, int] | None = None,
+        replace: bool = False,
+        weights: list[int] | None = None,
+    ) -> None:
+        super().__init__(transforms, map_items, unpack_items, log_stats)
+        self.min_num_transforms, self.max_num_transforms = self._ensure_valid_num_transforms(num_transforms)
+        self.replace = replace
+        self.weights = self._normalize_probabilities(weights)
+
+    def _ensure_valid_num_transforms(self, num_transforms: int | tuple[int, int] | None) -> tuple:
+        if (
+            not isinstance(num_transforms, tuple)
+            and not isinstance(num_transforms, list)
+            and not isinstance(num_transforms, int)
+            and num_transforms is not None
+        ):
+            raise ValueError(
+                f"Expected num_transforms to be of type int, list, tuple or None, but it is {type(num_transforms)}"
+            )
+
+        if num_transforms is None:
+            result = [len(self.transforms), len(self.transforms)]
+        elif isinstance(num_transforms, int):
+            n = min(len(self.transforms), num_transforms)
+            result = [n, n]
+        else:
+            if len(num_transforms) != 2:
+                raise ValueError(f"Expected len(num_transforms)=2, but it was {len(num_transforms)}")
+            if not isinstance(num_transforms[0], int) or not isinstance(num_transforms[1], int):
+                raise ValueError(
+                    f"Expected (int,int), but received ({type(num_transforms[0])}, {type(num_transforms[1])})"
+                )
+
+            result = [num_transforms[0], num_transforms[1]]
+
+        if result[0] < 0 or result[1] > len(self.transforms):
+            raise ValueError(f"num_transforms={num_transforms} is out of the bounds [0, {len(self.transforms)}].")
+
+        return ensure_tuple(result)
+
+    # Modified from OneOf
+    def _normalize_probabilities(self, weights):
+        if weights is None or len(self.transforms) == 0:
+            return None
+
+        weights = np.array(weights)
+
+        n_weights = len(weights)
+        if n_weights != len(self.transforms):
+            raise ValueError(f"Expected len(weights)={len(self.transforms)}, got: {n_weights}.")
+
+        if np.any(weights < 0):
+            raise ValueError(f"Probabilities must be greater than or equal to zero, got {weights}.")
+
+        if np.all(weights == 0):
+            raise ValueError(f"At least one probability must be greater than zero, got {weights}.")
+
+        weights = weights / weights.sum()
+
+        return ensure_tuple(list(weights))
+
+    def __call__(self, data):
+        if len(self.transforms) == 0:
+            return data
+
+        sample_size = self.R.randint(self.min_num_transforms, self.max_num_transforms + 1)
+        applied_order = self.R.choice(len(self.transforms), sample_size, replace=self.replace, p=self.weights).tolist()
+        for i in applied_order:
+            data = apply_transform(self.transforms[i], data, self.map_items, self.unpack_items, self.log_stats)
+
+        if isinstance(data, monai.data.MetaTensor):
+            self.push_transform(data, extra_info={"applied_order": applied_order})
+        elif isinstance(data, Mapping):
+            for key in data:  # the dictionary does not change size during iteration
+                if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
+                    self.push_transform(data, key, extra_info={"applied_order": applied_order})
+
+        return data
+
+    # From RandomOrder
+    def inverse(self, data):
+        if len(self.transforms) == 0:
+            return data
+
+        applied_order = None
+        if isinstance(data, monai.data.MetaTensor):
+            applied_order = self.pop_transform(data)[TraceKeys.EXTRA_INFO]["applied_order"]
+        elif isinstance(data, Mapping):
+            for key in data:
+                if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
+                    applied_order = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["applied_order"]
+        else:
+            raise RuntimeError(
+                f"Inverse only implemented for Mapping (dictionary) or MetaTensor data, got type {type(data)}."
+            )
+        if applied_order is None:
+            # no invertible transforms have been applied
+            return data
+
+        # loop backwards over the applied transforms
+        for o in reversed(applied_order):
+            transform = self.transforms[o]
+            if isinstance(transform, InvertibleTransform):
+                data = apply_transform(
+                    transform.inverse, data, self.map_items, self.unpack_items, self.log_stats
+                )
+
+        return data
diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py
index 8210ec924c..75806ce120 100644
--- a/monai/utils/__init__.py
+++ b/monai/utils/__init__.py
@@ -88,6 +88,7 @@
     star_zip_with,
     str2bool,
     str2list,
+    to_tuple_of_dictionaries,
     zip_with,
 )
 from .module import (
diff --git a/monai/utils/misc.py b/monai/utils/misc.py
index f22716a376..0b2df36a5b 100644
--- a/monai/utils/misc.py
+++ b/monai/utils/misc.py
@@ -41,6 +41,7 @@
     "ensure_tuple",
     "ensure_tuple_size",
     "ensure_tuple_rep",
+    "to_tuple_of_dictionaries",
     "fall_back_tuple",
     "is_scalar_tensor",
     "is_scalar",
diff --git a/tests/test_some_of.py b/tests/test_some_of.py
new file mode 100644
index 0000000000..0cc903bb2d
--- /dev/null
+++ b/tests/test_some_of.py
@@ -0,0 +1,210 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
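+
+# Tests for ``monai.transforms.compose.SomeOf``: fixed and ranged numbers of sampled transforms,
+# weighted sampling, argument validation, compound pipelines, and inverse/tracing behaviour.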
+
+from __future__ import annotations
+
+import unittest
+
+from parameterized import parameterized
+
+from monai.data import MetaTensor
+from monai.transforms import TraceableTransform, Transform
+from monai.transforms.compose import Compose, SomeOf
+from monai.utils import set_determinism
+from monai.utils.enums import TraceKeys
+from tests.test_one_of import NonInv
+from tests.test_random_order import InvC, InvD
+
+
+class A(Transform):
+    def __call__(self, x):
+        return 2 * x
+
+
+class B(Transform):
+    def __call__(self, x):
+        return 3 * x
+
+
+class C(Transform):
+    def __call__(self, x):
+        return 5 * x
+
+
+class D(Transform):
+    def __call__(self, x):
+        return 7 * x
+
+
+KEYS = ["x", "y"]
+TEST_COMPOUND = [
+    (SomeOf((A(), B(), C()), num_transforms=3), 2 * 3 * 5),
+    (Compose((SomeOf((A(), B(), C()), num_transforms=3), D())), 2 * 3 * 5 * 7),
+    (SomeOf((A(), B(), C(), Compose(D())), num_transforms=4), 2 * 3 * 5 * 7),
+    (SomeOf(()), 1),
+    (SomeOf(None), 1),
+]
+
+# Modified from RandomOrder
+TEST_INVERSES = [
+    (SomeOf((InvC(KEYS), InvD(KEYS))), True, True),
+    (Compose((SomeOf((InvC(KEYS), InvD(KEYS))), SomeOf((InvD(KEYS), InvC(KEYS))))), True, False),
+    (SomeOf((SomeOf((InvC(KEYS), InvD(KEYS))), SomeOf((InvD(KEYS), InvC(KEYS))))), True, False),
+    (SomeOf((Compose((InvC(KEYS), InvD(KEYS))), Compose((InvD(KEYS), InvC(KEYS))))), True, False),
+    (SomeOf((NonInv(KEYS), NonInv(KEYS))), False, False),
+    (SomeOf(()), False, False),
+]
+
+
+class TestSomeOf(unittest.TestCase):
+    def setUp(self):
+        set_determinism(seed=0)
+
+    def tearDown(self):
+        set_determinism(None)
+
+    def update_transform_count(self, counts, output):
+        op_count = 0
+
+        if output % 2 == 0:
+            counts[0] += 1
+            op_count += 1
+        if output % 3 == 0:
+            counts[1] += 1
+            op_count += 1
+        if output % 5 == 0:
+            counts[2] += 1
+            op_count += 1
+
+        return op_count
+
+    def test_fixed(self):
+        iterations = 10000
+        num_transforms = 3
+        transform_counts = 3 * [0]
+        subset_size_counts = 4 * [0]
+
+        s = SomeOf((A(), B(), C()), num_transforms=num_transforms)
+
+        for _ in range(iterations):
+            output = s(1)
+            subset_size = self.update_transform_count(transform_counts, output)
+            subset_size_counts[subset_size] += 1
+
+        for i in range(3):
+            self.assertEqual(transform_counts[i], iterations)
+
+        for i in range(3):
+            self.assertEqual(subset_size_counts[i], 0)
+
+        self.assertEqual(subset_size_counts[3], iterations)
+
+    def test_unfixed(self):
+        iterations = 10000
+        num_transforms = (0, 3)
+        transform_counts = 3 * [0]
+        subset_size_counts = 4 * [0]
+
+        s = SomeOf((A(), B(), C()), num_transforms=num_transforms)
+
+        for _ in range(iterations):
+            output = s(1)
+            subset_size = self.update_transform_count(transform_counts, output)
+            subset_size_counts[subset_size] += 1
+
+        for i in range(3):
+            self.assertAlmostEqual(transform_counts[i] / iterations, 0.5, delta=0.01)
+
+        for i in range(4):
+            self.assertAlmostEqual(subset_size_counts[i] / iterations, 0.25, delta=0.01)
+
+    def test_non_dict_metatensor(self):
+        data = MetaTensor(1)
+        s = SomeOf([A()], num_transforms=1)
+        out = s(data)
+        self.assertEqual(out, 2)
+        inv = s.inverse(out)  # A() is not invertible, nothing happens
+        self.assertEqual(inv, 2)
+
+    @parameterized.expand(TEST_COMPOUND)
+    def test_compound_pipeline(self, transform, expected_value):
+        output = transform(1)
+        self.assertEqual(output, expected_value)
+
+    # Modified from RandomOrder
+    @parameterized.expand(TEST_INVERSES)
+    def test_inverse(self, transform, invertible, use_metatensor):
+        data = {k: (i + 1) * 10.0 if not use_metatensor else MetaTensor((i + 1) * 10.0) for i, k in enumerate(KEYS)}
+        fwd_data1 = transform(data)
+        # calling the transform a second time should not affect the inverse of the first result
+        fwd_data2 = transform(data)
+
+        if invertible:
+            for k in KEYS:
+                t = (
+                    fwd_data1[TraceableTransform.trace_key(k)][-1]
+                    if not use_metatensor
+                    else fwd_data1[k].applied_operations[-1]
+                )
+                # make sure the SomeOf applied_order was stored
+                self.assertEqual(t[TraceKeys.CLASS_NAME], SomeOf.__name__)
+
+        # call the inverse
+        fwd_inv_data1 = transform.inverse(fwd_data1)
+        fwd_inv_data2 = transform.inverse(fwd_data2)
+
+        fwd_data = [fwd_data1, fwd_data2]
+        fwd_inv_data = [fwd_inv_data1, fwd_inv_data2]
+        for i, _fwd_inv_data in enumerate(fwd_inv_data):
+            if invertible:
+                for k in KEYS:
+                    # check transform was removed
+                    if not use_metatensor:
+                        self.assertTrue(
+                            len(_fwd_inv_data[TraceableTransform.trace_key(k)])
+                            < len(fwd_data[i][TraceableTransform.trace_key(k)])
+                        )
+                    # check data is same as original (and different from forward)
+                    self.assertEqual(_fwd_inv_data[k], data[k])
+                    self.assertNotEqual(_fwd_inv_data[k], fwd_data[i][k])
+            else:
+                # if not invertible, should not change the data
+                self.assertDictEqual(fwd_data[i], _fwd_inv_data)
+
+    def test_bad_inverse_data(self):
+        tr = SomeOf((A(), B(), C()), num_transforms=1, weights=(1, 2, 1))
+        self.assertRaises(RuntimeError, tr.inverse, [])
+
+    def test_normalize_weights(self):
+        tr = SomeOf((A(), B(), C()), num_transforms=1, weights=(1, 2, 1))
+        self.assertTupleEqual(tr.weights, (0.25, 0.5, 0.25))
+
+        tr = SomeOf((), num_transforms=1, weights=(1, 2, 1))
+        self.assertIsNone(tr.weights)
+
+    def test_no_weights_arg(self):
+        tr = SomeOf((A(), B(), C(), D()), num_transforms=1)
+        self.assertIsNone(tr.weights)
+
+    def test_bad_weights(self):
+        self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=1, weights=(1, 2))
+        self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=1, weights=(0, 0, 0))
+        self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=1, weights=(-1, 1, 1))
+
+    def test_bad_num_transforms(self):
+        self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=(-1, 2))
+        self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms="str")
+        self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=(1, 2, 3))
+        self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=("a", 1))
+
+
+if __name__ == "__main__":
+    unittest.main()
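For reference, the tracing and inverse behaviour exercised by test_inverse above can be sketched
outside the test suite roughly as follows (a minimal sketch assuming the existing invertible
dictionary transforms Flipd and RandRotate90d; the key, tensor shape, and transform choices are
illustrative only, not part of this changeset):

    import torch

    from monai.data import MetaTensor
    from monai.transforms import Flipd, RandRotate90d, SomeOf

    # sample one or two of the two invertible transforms on each call
    pipeline = SomeOf(
        [Flipd(keys=["img"]), RandRotate90d(keys=["img"], prob=1.0)],
        num_transforms=(1, 2),
    )
    data = {"img": MetaTensor(torch.arange(4.0).reshape(1, 2, 2))}  # channel-first image
    out = pipeline(data)  # SomeOf records "applied_order" in the MetaTensor's applied_operations
    restored = pipeline.inverse(out)  # the sampled transforms are undone in reverse order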