From 7040b4747b95946f38c3148fc0e936828242a277 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 17 Sep 2021 11:30:54 +0800 Subject: [PATCH 1/3] [DLMED] enhance tensor transforms Signed-off-by: Nic Ma --- monai/transforms/intensity/array.py | 6 ++++-- monai/transforms/spatial/array.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index a1423c8ee5..6a9dd500a8 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -955,7 +955,7 @@ def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = "z self.mode = mode self.img_t: torch.Tensor = torch.tensor(0.0) - def __call__(self, img: NdarrayOrTensor) -> torch.Tensor: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: array containing input data. Must be real and in shape [channels, spatial1, spatial2, ...]. @@ -969,7 +969,9 @@ def __call__(self, img: NdarrayOrTensor) -> torch.Tensor: # add one to transform axis because a batch axis will be added at dimension 0 savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode) # convert to Tensor and add Batch axis expected by HilbertTransform - out: torch.Tensor = savgol_filter(self.img_t.unsqueeze(0)).squeeze(0) + out = savgol_filter(self.img_t.unsqueeze(0)).squeeze(0) + out, *_ = convert_to_dst_type(out, dst=img) + return out diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 8b1fb854f2..7e9f7f407c 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -571,7 +571,7 @@ class Zoom(Transform): """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH] def __init__( self, From 8eb4dd6b2a8d8267565bf4704a3f1c151864be62 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 17 Sep 2021 23:37:21 +0800 Subject: [PATCH 2/3] [DLMED] fix tests Signed-off-by: Nic Ma --- tests/test_savitzky_golay_smooth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py index 45d0ea3e4d..71405466c4 100644 --- a/tests/test_savitzky_golay_smooth.py +++ b/tests/test_savitzky_golay_smooth.py @@ -62,7 +62,7 @@ class TestSavitzkyGolaySmooth(unittest.TestCase): @parameterized.expand([TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH]) def test_value(self, arguments, image, expected_data, atol): for p in TEST_NDARRAYS: - result = SavitzkyGolaySmooth(**arguments)(p(image)) + result = SavitzkyGolaySmooth(**arguments)(p(image.astype(np.float32))) torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-7, atol=atol) @@ -70,7 +70,7 @@ class TestSavitzkyGolaySmoothREP(unittest.TestCase): @parameterized.expand([TEST_CASE_SINGLE_VALUE_REP]) def test_value(self, arguments, image, expected_data, atol): for p in TEST_NDARRAYS: - result = SavitzkyGolaySmooth(**arguments)(p(image)) + result = SavitzkyGolaySmooth(**arguments)(p(image.astype(np.float32))) torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-7, atol=atol) From 439034040e7a6fe5176a24c3d810f4fcc02999ba Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 18 Sep 2021 00:12:49 +0800 Subject: [PATCH 3/3] [DLMED] fix mypy Signed-off-by: Nic Ma --- monai/transforms/intensity/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 6a9dd500a8..a8babfc659 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -969,8 +969,8 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: # add one to transform axis because a batch axis will be added at dimension 0 savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode) # convert to Tensor and add Batch axis expected by HilbertTransform - out = savgol_filter(self.img_t.unsqueeze(0)).squeeze(0) - out, *_ = convert_to_dst_type(out, dst=img) + smoothed = savgol_filter(self.img_t.unsqueeze(0)).squeeze(0) + out, *_ = convert_to_dst_type(smoothed, dst=img) return out