diff --git a/monai/networks/blocks/fft_utils_t.py b/monai/networks/blocks/fft_utils_t.py index 2aa4054f64..18079ca8c3 100644 --- a/monai/networks/blocks/fft_utils_t.py +++ b/monai/networks/blocks/fft_utils_t.py @@ -139,21 +139,17 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> output2 = ifftn_centered(ksp, spatial_dims=2, is_complex=True) """ # define spatial dims to perform ifftshift, fftshift, and ifft - shift = list(range(-spatial_dims, 0)) + dims = list(range(-spatial_dims, 0)) if is_complex: if ksp.shape[-1] != 2: raise ValueError(f"ksp.shape[-1] is not 2 ({ksp.shape[-1]}).") - shift = list(range(-spatial_dims - 1, -1)) - dims = list(range(-spatial_dims, 0)) - - x = ifftshift(ksp, shift) - - if is_complex: - x = torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(x), dim=dims, norm="ortho")) + x = torch.view_as_complex(ifftshift(ksp, [d - 1 for d in dims])) else: - x = torch.view_as_real(torch.fft.ifftn(x, dim=dims, norm="ortho")) + x = ifftshift(ksp, dims) + + x = torch.view_as_real(torch.fft.ifftn(x, dim=dims, norm="ortho")) - out: Tensor = fftshift(x, shift) + out: Tensor = fftshift(x, [d - 1 for d in dims]) return out @@ -187,20 +183,16 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T output2 = fftn_centered(im, spatial_dims=2, is_complex=True) """ # define spatial dims to perform ifftshift, fftshift, and fft - shift = list(range(-spatial_dims, 0)) + dims = list(range(-spatial_dims, 0)) if is_complex: if im.shape[-1] != 2: raise ValueError(f"img.shape[-1] is not 2 ({im.shape[-1]}).") - shift = list(range(-spatial_dims - 1, -1)) - dims = list(range(-spatial_dims, 0)) - - x = ifftshift(im, shift) - - if is_complex: - x = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x), dim=dims, norm="ortho")) + x = torch.view_as_complex(ifftshift(im, [d - 1 for d in dims])) else: - x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) + x = ifftshift(im, dims) + + x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) - out: Tensor = fftshift(x, shift) + out: Tensor = fftshift(x, [d - 1 for d in dims]) return out diff --git a/tests/data/test_fft_utils.py b/tests/data/test_fft_utils.py index f09cb26ae4..da038203bc 100644 --- a/tests/data/test_fft_utils.py +++ b/tests/data/test_fft_utils.py @@ -21,7 +21,7 @@ # im = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]] res = [ - [[[0.0, 0.0], [0.0, 3.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]] + [[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [3.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]] ] TESTS = [] for p in TEST_NDARRAYS: