diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index b75a18dec1..165d9b732f 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -256,12 +256,9 @@ def inverse(self, data): # and then remove the OneOf transform self.pop_transform(data, key) if index is None: - raise RuntimeError("No invertible transforms have been applied") + # no invertible transforms have been applied + return data - # if applied transform is not InvertibleTransform, throw error _transform = self.transforms[index] - if not isinstance(_transform, InvertibleTransform): - raise RuntimeError(f"Applied OneOf transform is not invertible (applied index: {index}).") - # apply the inverse - return _transform.inverse(data) + return _transform.inverse(data) if isinstance(_transform, InvertibleTransform) else data diff --git a/tests/test_one_of.py b/tests/test_one_of.py index a7cd09f10b..29d13d7d0c 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -12,9 +12,18 @@ import unittest from copy import deepcopy +import numpy as np from parameterized import parameterized -from monai.transforms import InvertibleTransform, OneOf, TraceableTransform, Transform +from monai.transforms import ( + InvertibleTransform, + OneOf, + RandScaleIntensityd, + RandShiftIntensityd, + Resized, + TraceableTransform, + Transform, +) from monai.transforms.compose import Compose from monai.transforms.transform import MapTransform from monai.utils.enums import TraceKeys @@ -139,32 +148,52 @@ def _match(a, b): _match(p, f) @parameterized.expand(TEST_INVERSES) - def test_inverse(self, transform, should_be_ok): + def test_inverse(self, transform, invertible): data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)} fwd_data = transform(data) - if not should_be_ok: - with self.assertRaises(RuntimeError): - transform.inverse(fwd_data) - return - - for k in KEYS: - t = fwd_data[TraceableTransform.trace_key(k)][-1] - # make sure the OneOf index was stored - self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__) - # make sure index exists and is in bounds - self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform)) + + if invertible: + for k in KEYS: + t = fwd_data[TraceableTransform.trace_key(k)][-1] + # make sure the OneOf index was stored + self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__) + # make sure index exists and is in bounds + self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform)) # call the inverse fwd_inv_data = transform.inverse(fwd_data) - for k in KEYS: - # check transform was removed - self.assertTrue( - len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)]) - ) - # check data is same as original (and different from forward) - self.assertEqual(fwd_inv_data[k], data[k]) - self.assertNotEqual(fwd_inv_data[k], fwd_data[k]) + if invertible: + for k in KEYS: + # check transform was removed + self.assertTrue( + len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)]) + ) + # check data is same as original (and different from forward) + self.assertEqual(fwd_inv_data[k], data[k]) + self.assertNotEqual(fwd_inv_data[k], fwd_data[k]) + else: + # if not invertible, should not change the data + self.assertDictEqual(fwd_data, fwd_inv_data) + + def test_inverse_compose(self): + transform = Compose( + [ + Resized(keys="img", spatial_size=[100, 100, 100]), + OneOf( + [ + RandScaleIntensityd(keys="img", factors=0.5, prob=1.0), + RandShiftIntensityd(keys="img", offsets=0.5, prob=1.0), + ] + ), + ] + ) + transform.set_random_state(seed=0) + result = transform({"img": np.ones((1, 101, 102, 103))}) + + result = transform.inverse(result) + # invert to the original spatial shape + self.assertTupleEqual(result["img"].shape, (1, 101, 102, 103)) def test_one_of(self): p = OneOf((A(), B(), C()), (1, 2, 1))