diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 3e23377b36..2d1fe4eccd 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -231,6 +231,26 @@ def randomize(self, data: Optional[Any] = None) -> None: f'Transform "{tfm_name}" in Compose not randomized\n{tfm_name}.{type_error}.', RuntimeWarning ) + def flatten(self): + """Return a Composition with a simple list of transforms, as opposed to any nested Compositions. + + e.g., `t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]).flatten()` + will result in the equivalent of `t1 = Compose([x, x, x, x, x, x, x, x])`. + + """ + new_transforms = [] + for t in self.transforms: + if isinstance(t, Compose): + new_transforms += t.flatten().transforms + else: + new_transforms.append(t) + + return Compose(new_transforms) + + def __len__(self): + """Return number of transformations.""" + return len(self.flatten().transforms) + def __call__(self, input_): for _transform in self.transforms: input_ = apply_transform(_transform, input_) diff --git a/tests/test_compose.py b/tests/test_compose.py index 3585b3453c..c049044a97 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -156,6 +156,17 @@ def test_data_loader_2(self): self.assertAlmostEqual(out_1.cpu().item(), 0.131966779) set_determinism(None) + def test_flatten_and_len(self): + x = AddChannel() + t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]) + + t2 = t1.flatten() + for t in t2.transforms: + self.assertNotIsInstance(t, Compose) + + # test len + self.assertEqual(len(t1), 8) + if __name__ == "__main__": unittest.main()