From 90b9d8574d01bf0dbf5a76566f7ba2e934e8f825 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Tue, 23 Feb 2021 15:11:55 +0000 Subject: [PATCH 1/2] Compose len Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/compose.py | 4 ++++ tests/test_compose.py | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 3e23377b36..0bea767fa9 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -231,6 +231,10 @@ def randomize(self, data: Optional[Any] = None) -> None: f'Transform "{tfm_name}" in Compose not randomized\n{tfm_name}.{type_error}.', RuntimeWarning ) + def __len__(self): + """Return number of transformations.""" + return sum(len(t) if isinstance(t, Compose) else 1 for t in self.transforms) + def __call__(self, input_): for _transform in self.transforms: input_ = apply_transform(_transform, input_) diff --git a/tests/test_compose.py b/tests/test_compose.py index 3585b3453c..2103cdaa36 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -156,6 +156,13 @@ def test_data_loader_2(self): self.assertAlmostEqual(out_1.cpu().item(), 0.131966779) set_determinism(None) + def test_len(self): + x = AddChannel() + t = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]) + + # test len + self.assertEqual(len(t), 8) + if __name__ == "__main__": unittest.main() From 777f219bc4d486f08c50da1a3205f382615fb168 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Tue, 23 Feb 2021 16:19:12 +0000 Subject: [PATCH 2/2] Compose.flatten() Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/compose.py | 18 +++++++++++++++++- tests/test_compose.py | 10 +++++++--- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 0bea767fa9..2d1fe4eccd 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -231,9 +231,25 @@ def randomize(self, data: Optional[Any] = None) -> None: f'Transform "{tfm_name}" in Compose not randomized\n{tfm_name}.{type_error}.', RuntimeWarning ) + def flatten(self): + """Return a Composition with a simple list of transforms, as opposed to any nested Compositions. + + e.g., `t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]).flatten()` + will result in the equivalent of `t1 = Compose([x, x, x, x, x, x, x, x])`. + + """ + new_transforms = [] + for t in self.transforms: + if isinstance(t, Compose): + new_transforms += t.flatten().transforms + else: + new_transforms.append(t) + + return Compose(new_transforms) + def __len__(self): """Return number of transformations.""" - return sum(len(t) if isinstance(t, Compose) else 1 for t in self.transforms) + return len(self.flatten().transforms) def __call__(self, input_): for _transform in self.transforms: diff --git a/tests/test_compose.py b/tests/test_compose.py index 2103cdaa36..c049044a97 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -156,12 +156,16 @@ def test_data_loader_2(self): self.assertAlmostEqual(out_1.cpu().item(), 0.131966779) set_determinism(None) - def test_len(self): + def test_flatten_and_len(self): x = AddChannel() - t = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]) + t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]) + + t2 = t1.flatten() + for t in t2.transforms: + self.assertNotIsInstance(t, Compose) # test len - self.assertEqual(len(t), 8) + self.assertEqual(len(t1), 8) if __name__ == "__main__":