From 33a1cbbf4778f97480feb0d1fbf20fe27f3a853a Mon Sep 17 00:00:00 2001 From: Puyang Wang Date: Mon, 31 Mar 2025 12:44:37 -0400 Subject: [PATCH 1/8] Fix: correctly apply fftshift to real-valued data inputs Correctly apply fftshift to real-valued data inputs Signed-off-by: Puyang Wang --- monai/networks/blocks/fft_utils_t.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/monai/networks/blocks/fft_utils_t.py b/monai/networks/blocks/fft_utils_t.py index 2aa4054f64..243e6f777f 100644 --- a/monai/networks/blocks/fft_utils_t.py +++ b/monai/networks/blocks/fft_utils_t.py @@ -187,20 +187,19 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T output2 = fftn_centered(im, spatial_dims=2, is_complex=True) """ # define spatial dims to perform ifftshift, fftshift, and fft - shift = list(range(-spatial_dims, 0)) + dims = list(range(-spatial_dims, 0)) if is_complex: if im.shape[-1] != 2: raise ValueError(f"img.shape[-1] is not 2 ({im.shape[-1]}).") - shift = list(range(-spatial_dims - 1, -1)) - dims = list(range(-spatial_dims, 0)) - - x = ifftshift(im, shift) + x = ifftshift(im, [d - 1 for d in dims]) + else: + x = ifftshift(im, dims) if is_complex: x = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x), dim=dims, norm="ortho")) else: x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) - out: Tensor = fftshift(x, shift) + out: Tensor = fftshift(x, [d - 1 for d in dims]) return out From 8e428de8e6a40ebfb45f85201d798c10778ce935 Mon Sep 17 00:00:00 2001 From: Puyang Wang Date: Mon, 31 Mar 2025 13:08:09 -0400 Subject: [PATCH 2/8] Update test_fft_utils.py Signed-off-by: Puyang Wang --- tests/data/test_fft_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/test_fft_utils.py b/tests/data/test_fft_utils.py index f09cb26ae4..bceccad979 100644 --- a/tests/data/test_fft_utils.py +++ b/tests/data/test_fft_utils.py @@ -21,7 +21,7 @@ # im = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]] res = [ - [[[0.0, 0.0], [0.0, 3.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]] + [[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 3.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]] ] TESTS = [] for p in TEST_NDARRAYS: From c1596fd0126e1e7d5c05eeb2e6e9f4ffadc3603b Mon Sep 17 00:00:00 2001 From: Puyang Wang Date: Mon, 31 Mar 2025 13:24:19 -0400 Subject: [PATCH 3/8] Update test_fft_utils.py Signed-off-by: Puyang Wang --- tests/data/test_fft_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/test_fft_utils.py b/tests/data/test_fft_utils.py index bceccad979..da038203bc 100644 --- a/tests/data/test_fft_utils.py +++ b/tests/data/test_fft_utils.py @@ -21,7 +21,7 @@ # im = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]] res = [ - [[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 3.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]] + [[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [3.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]] ] TESTS = [] for p in TEST_NDARRAYS: From 2779f2b4a09459303c0f2336b13dc4d8d1a119cc Mon Sep 17 00:00:00 2001 From: Puyang Wang Date: Tue, 1 Apr 2025 15:36:17 -0400 Subject: [PATCH 4/8] fewer lines Signed-off-by: Puyang Wang --- monai/networks/blocks/fft_utils_t.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/monai/networks/blocks/fft_utils_t.py b/monai/networks/blocks/fft_utils_t.py index 243e6f777f..5900458eaf 100644 --- a/monai/networks/blocks/fft_utils_t.py +++ b/monai/networks/blocks/fft_utils_t.py @@ -191,14 +191,11 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T if is_complex: if im.shape[-1] != 2: raise ValueError(f"img.shape[-1] is not 2 ({im.shape[-1]}).") - x = ifftshift(im, [d - 1 for d in dims]) + x = torch.view_as_complex(ifftshift(im, [d - 1 for d in dims])) else: x = ifftshift(im, dims) - if is_complex: - x = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x), dim=dims, norm="ortho")) - else: - x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) + x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) out: Tensor = fftshift(x, [d - 1 for d in dims]) From d1b13f2cc66eb0ccf5ac78e04cdd4c7f5445c768 Mon Sep 17 00:00:00 2001 From: Puyang Wang Date: Tue, 1 Apr 2025 16:59:01 -0400 Subject: [PATCH 5/8] fix fftn_centered_t and ifftn_centered_t Signed-off-by: Puyang Wang --- monai/networks/blocks/fft_utils_t.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/monai/networks/blocks/fft_utils_t.py b/monai/networks/blocks/fft_utils_t.py index 5900458eaf..4fa6768935 100644 --- a/monai/networks/blocks/fft_utils_t.py +++ b/monai/networks/blocks/fft_utils_t.py @@ -139,21 +139,17 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> output2 = ifftn_centered(ksp, spatial_dims=2, is_complex=True) """ # define spatial dims to perform ifftshift, fftshift, and ifft - shift = list(range(-spatial_dims, 0)) + dims = list(range(-spatial_dims, 0)) if is_complex: if ksp.shape[-1] != 2: raise ValueError(f"ksp.shape[-1] is not 2 ({ksp.shape[-1]}).") - shift = list(range(-spatial_dims - 1, -1)) - dims = list(range(-spatial_dims, 0)) + x = torch.view_as_complex(ksp) - x = ifftshift(ksp, shift) + x = ifftshift(ksp, dims) - if is_complex: - x = torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(x), dim=dims, norm="ortho")) - else: - x = torch.view_as_real(torch.fft.ifftn(x, dim=dims, norm="ortho")) + x = fftshift(torch.fft.ifftn(x, dim=dims, norm="ortho"), dims) - out: Tensor = fftshift(x, shift) + out: Tensor = torch.view_as_real(x) return out @@ -191,12 +187,12 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T if is_complex: if im.shape[-1] != 2: raise ValueError(f"img.shape[-1] is not 2 ({im.shape[-1]}).") - x = torch.view_as_complex(ifftshift(im, [d - 1 for d in dims])) - else: - x = ifftshift(im, dims) + x = torch.view_as_complex(im) + + x = ifftshift(im, dims) - x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) + x = fftshift(torch.fft.fftn(x, dim=dims, norm="ortho"), dims) - out: Tensor = fftshift(x, [d - 1 for d in dims]) + out: Tensor = torch.view_as_real(x) return out From 06b377b6c7597c64a7bbe560c5c72353b8d5930c Mon Sep 17 00:00:00 2001 From: Puyang Wang Date: Tue, 1 Apr 2025 17:22:55 -0400 Subject: [PATCH 6/8] fix Signed-off-by: Puyang Wang --- monai/networks/blocks/fft_utils_t.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/fft_utils_t.py b/monai/networks/blocks/fft_utils_t.py index 4fa6768935..dd552d3a1e 100644 --- a/monai/networks/blocks/fft_utils_t.py +++ b/monai/networks/blocks/fft_utils_t.py @@ -145,7 +145,7 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> raise ValueError(f"ksp.shape[-1] is not 2 ({ksp.shape[-1]}).") x = torch.view_as_complex(ksp) - x = ifftshift(ksp, dims) + x = ifftshift(x, dims) x = fftshift(torch.fft.ifftn(x, dim=dims, norm="ortho"), dims) From 6f6f94e269b8a51bf0c46208a20bf087f37ce92a Mon Sep 17 00:00:00 2001 From: Puyang Wang Date: Tue, 1 Apr 2025 17:42:22 -0400 Subject: [PATCH 7/8] roll back Signed-off-by: Puyang Wang --- monai/networks/blocks/fft_utils_t.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/monai/networks/blocks/fft_utils_t.py b/monai/networks/blocks/fft_utils_t.py index dd552d3a1e..91c4d1baed 100644 --- a/monai/networks/blocks/fft_utils_t.py +++ b/monai/networks/blocks/fft_utils_t.py @@ -143,13 +143,13 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> if is_complex: if ksp.shape[-1] != 2: raise ValueError(f"ksp.shape[-1] is not 2 ({ksp.shape[-1]}).") - x = torch.view_as_complex(ksp) + x = torch.view_as_complex(ifftshift(ksp, [d - 1 for d in dims])) + else: + x = ifftshift(ksp, dims) - x = ifftshift(x, dims) + x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) - x = fftshift(torch.fft.ifftn(x, dim=dims, norm="ortho"), dims) - - out: Tensor = torch.view_as_real(x) + out: Tensor = fftshift(x, [d - 1 for d in dims]) return out @@ -187,12 +187,12 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T if is_complex: if im.shape[-1] != 2: raise ValueError(f"img.shape[-1] is not 2 ({im.shape[-1]}).") - x = torch.view_as_complex(im) - - x = ifftshift(im, dims) + x = torch.view_as_complex(ifftshift(im, [d - 1 for d in dims])) + else: + x = ifftshift(im, dims) - x = fftshift(torch.fft.fftn(x, dim=dims, norm="ortho"), dims) + x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) - out: Tensor = torch.view_as_real(x) + out: Tensor = fftshift(x, [d - 1 for d in dims]) return out From feed7af0fb7be7fbb74fdc59ae2e013fb078f355 Mon Sep 17 00:00:00 2001 From: Puyang Wang Date: Tue, 1 Apr 2025 18:21:30 -0400 Subject: [PATCH 8/8] typo Signed-off-by: Puyang Wang --- monai/networks/blocks/fft_utils_t.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/fft_utils_t.py b/monai/networks/blocks/fft_utils_t.py index 91c4d1baed..18079ca8c3 100644 --- a/monai/networks/blocks/fft_utils_t.py +++ b/monai/networks/blocks/fft_utils_t.py @@ -147,7 +147,7 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> else: x = ifftshift(ksp, dims) - x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) + x = torch.view_as_real(torch.fft.ifftn(x, dim=dims, norm="ortho")) out: Tensor = fftshift(x, [d - 1 for d in dims])