From 6ee87a05870e2c681d264e28b437998212b605ce Mon Sep 17 00:00:00 2001 From: "axel.vlaminck" Date: Wed, 17 Jan 2024 00:31:15 +0100 Subject: [PATCH 1/4] Add test for passing applied operations through ImageFilter Signed-off-by: axel.vlaminck --- tests/test_image_filter.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index 841a5d5cd5..5f53871940 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -17,6 +17,7 @@ import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms import ImageFilter, ImageFilterd, RandImageFilter, RandImageFilterd @@ -115,6 +116,14 @@ def test_call_3d(self, filter_name): out_tensor = filter(SAMPLE_IMAGE_3D) self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:]) + def test_pass_applied_operations(self): + "Test that applied operations are passed through" + applied_operations = ["op1", "op2"] + image = MetaTensor(SAMPLE_IMAGE_2D, applied_operations=applied_operations) + filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(image) + self.assertEqual(out_tensor.applied_operations, applied_operations) + class TestImageFilterDict(unittest.TestCase): @parameterized.expand(SUPPORTED_FILTERS) From c13338b72e779780f288e17073435cee90a58983 Mon Sep 17 00:00:00 2001 From: "axel.vlaminck" Date: Wed, 17 Jan 2024 00:32:12 +0100 Subject: [PATCH 2/4] Add support for tracking applied operations in ImageFilter Signed-off-by: axel.vlaminck --- monai/transforms/utility/array.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 2322f2123f..5e377d1f4e 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1562,17 +1562,22 @@ def __init__(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | self.filter_size = filter_size self.additional_args_for_filter = kwargs - def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> NdarrayOrTensor: + def __call__( + self, img: NdarrayOrTensor, meta_dict: dict | None = None, applied_operations: list | None = None + ) -> NdarrayOrTensor: """ Args: img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]] meta_dict: An optional dictionary with metadata + applied_operations: An optional list of operations that have been applied to the data Returns: A MetaTensor with the same shape as `img` and identical metadata """ if isinstance(img, MetaTensor): meta_dict = img.meta + applied_operations = img.applied_operations + img_, prev_type, device = convert_data_type(img, torch.Tensor) ndim = img_.ndim - 1 # assumes channel first format @@ -1582,8 +1587,8 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> Ndarr self.filter = ApplyFilter(self.filter) img_ = self._apply_filter(img_) - if meta_dict: - img_ = MetaTensor(img_, meta=meta_dict) + if meta_dict or applied_operations is not None: + img_ = MetaTensor(img_, meta=meta_dict, applied_operations=applied_operations) else: img_, *_ = convert_data_type(img_, prev_type, device) return img_ From 5e2a8a7c0db4959a0cf5af5607fe051ed35168eb Mon Sep 17 00:00:00 2001 From: "axel.vlaminck" Date: Thu, 18 Jan 2024 13:03:39 +0100 Subject: [PATCH 3/4] Create metatensor when passing a empty meta dict Signed-off-by: axel.vlaminck --- monai/transforms/utility/array.py | 2 +- tests/test_image_filter.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 5e377d1f4e..5dfbcb0e91 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1587,7 +1587,7 @@ def __call__( self.filter = ApplyFilter(self.filter) img_ = self._apply_filter(img_) - if meta_dict or applied_operations is not None: + if meta_dict is not None or applied_operations is not None: img_ = MetaTensor(img_, meta=meta_dict, applied_operations=applied_operations) else: img_, *_ = convert_data_type(img_, prev_type, device) diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index 5f53871940..02a9f74d81 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -124,6 +124,12 @@ def test_pass_applied_operations(self): out_tensor = filter(image) self.assertEqual(out_tensor.applied_operations, applied_operations) + def test_pass_empty_metadata_dict(self): + "Test that applied operations are passed through" + image = MetaTensor(SAMPLE_IMAGE_2D, meta={}) + filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(image) + self.assertTrue(isinstance(out_tensor, MetaTensor)) class TestImageFilterDict(unittest.TestCase): @parameterized.expand(SUPPORTED_FILTERS) From 940b5cb5d7b100365fde266a3286516b5ef74ec8 Mon Sep 17 00:00:00 2001 From: "axel.vlaminck" Date: Thu, 18 Jan 2024 16:25:25 +0100 Subject: [PATCH 4/4] fix formatting Signed-off-by: axel.vlaminck --- tests/test_image_filter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index 02a9f74d81..985ea95e79 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -131,6 +131,7 @@ def test_pass_empty_metadata_dict(self): out_tensor = filter(image) self.assertTrue(isinstance(out_tensor, MetaTensor)) + class TestImageFilterDict(unittest.TestCase): @parameterized.expand(SUPPORTED_FILTERS) def test_init_from_string_dict(self, filter_name):