From f6cb8e5e1e474eaf5f39d76b83d8df053f935b74 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 9 Sep 2022 14:19:50 +0000 Subject: [PATCH 1/2] Update splitdimd with the option of list output Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/utility/dictionary.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index cde6bd8cc2..2039b3a5a5 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -384,6 +384,7 @@ def __init__( dim: int = 0, keepdim: bool = True, update_meta: bool = True, + list_output: bool = False, allow_missing_keys: bool = False, ) -> None: """ @@ -399,15 +400,34 @@ def __init__( dimension will be squeezed. update_meta: if `True`, copy `[key]_meta_dict` for each output and update affine to reflect the cropped image + list_output: it `True`, the output will be a list of dictionaries with the same keys as original. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) self.output_postfixes = output_postfixes self.splitter = SplitDim(dim, keepdim, update_meta) + self.list_output = list_output + if self.list_output is None and self.output_postfixes is not None: + raise ValueError("`output_postfixes` should not be provided when `list_output` is `True`.") - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__( + self, data: Mapping[Hashable, torch.Tensor] + ) -> Union[Dict[Hashable, torch.Tensor], List[Dict[Hashable, torch.Tensor]]]: d = dict(data) - for key in self.key_iterator(d): + all_keys = list(set(self.key_iterator(d))) + + if self.list_output: + output = [] + results = [self.splitter(d[key]) for key in all_keys] + for row in zip(*results): + new_dict = {k: v for k, v in zip(all_keys, row)} + # fill in the extra keys with unmodified data + for k in set(d.keys()).difference(set(all_keys)): + new_dict[k] = deepcopy(d[k]) + output.append(new_dict) + return output + + for key in all_keys: rets = self.splitter(d[key]) postfixes: Sequence = list(range(len(rets))) if self.output_postfixes is None else self.output_postfixes if len(postfixes) != len(rets): From 3564b38bd641ceb185f5548dd2cc80ccffb23e87 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 9 Sep 2022 14:20:06 +0000 Subject: [PATCH 2/2] Upadate unittest to include list_output Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_splitdimd.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/tests/test_splitdimd.py b/tests/test_splitdimd.py index 1e39439b86..ee8cc043e4 100644 --- a/tests/test_splitdimd.py +++ b/tests/test_splitdimd.py @@ -25,7 +25,8 @@ for p in TEST_NDARRAYS: for keepdim in (True, False): for update_meta in (True, False): - TESTS.append((keepdim, p, update_meta)) + for list_output in (True, False): + TESTS.append((keepdim, p, update_meta, list_output)) class TestSplitDimd(unittest.TestCase): @@ -39,14 +40,18 @@ def setUpClass(cls): cls.data: MetaTensor = loader(data) @parameterized.expand(TESTS) - def test_correct(self, keepdim, im_type, update_meta): + def test_correct(self, keepdim, im_type, update_meta, list_output): data = deepcopy(self.data) data["i"] = im_type(data["i"]) arr = data["i"] for dim in range(arr.ndim): - out = SplitDimd("i", dim=dim, keepdim=keepdim, update_meta=update_meta)(data) - self.assertIsInstance(out, dict) - self.assertEqual(len(out.keys()), len(data.keys()) + arr.shape[dim]) + out = SplitDimd("i", dim=dim, keepdim=keepdim, update_meta=update_meta, list_output=list_output)(data) + if list_output: + self.assertIsInstance(out, list) + self.assertEqual(len(out), arr.shape[dim]) + else: + self.assertIsInstance(out, dict) + self.assertEqual(len(out.keys()), len(data.keys()) + arr.shape[dim]) # if updating metadata, pick some random points and # check same world coordinates between input and output if update_meta: @@ -55,14 +60,20 @@ def test_correct(self, keepdim, im_type, update_meta): split_im_idx = idx[dim] split_idx = deepcopy(idx) split_idx[dim] = 0 - split_im = out[f"i_{split_im_idx}"] + if list_output: + split_im = out[split_im_idx]["i"] + else: + split_im = out[f"i_{split_im_idx}"] if isinstance(data, MetaTensor) and isinstance(split_im, MetaTensor): # idx[1:] to remove channel and then add 1 for 4th element real_world = data.affine @ torch.tensor(idx[1:] + [1]).double() real_world2 = split_im.affine @ torch.tensor(split_idx[1:] + [1]).double() assert_allclose(real_world, real_world2) - out = out["i_0"] + if list_output: + out = out[0]["i"] + else: + out = out["i_0"] expected_ndim = arr.ndim if keepdim else arr.ndim - 1 self.assertEqual(out.ndim, expected_ndim) # assert is a shallow copy