From adcf47ba42e21bc907c1305aca767d2fd7b1c8b9 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Mon, 18 Jan 2021 21:58:27 +0000 Subject: [PATCH 1/3] 1452 add Warp layer and test Signed-off-by: kate-sann5100 --- monai/networks/blocks/__init__.py | 1 + monai/networks/blocks/warp.py | 96 +++++++++++++++++++++++++++++++ tests/test_warp.py | 91 +++++++++++++++++++++++++++++ 3 files changed, 188 insertions(+) create mode 100644 monai/networks/blocks/warp.py create mode 100644 tests/test_warp.py diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index c33feb4e2b..13c45d5c06 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -27,3 +27,4 @@ SEResNeXtBottleneck, ) from .upsample import SubpixelUpsample, Subpixelupsample, SubpixelUpSample, Upsample, UpSample +from .warp import Warp \ No newline at end of file diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py new file mode 100644 index 0000000000..08cdde4212 --- /dev/null +++ b/monai/networks/blocks/warp.py @@ -0,0 +1,96 @@ +from typing import List, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from monai.config import USE_COMPILED +from monai.networks.layers import grid_pull +from monai.utils import GridSampleMode, GridSamplePadMode + + +class Warp(nn.Module): + def __init__( + self, + spatial_dims: int, + mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR, + padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.ZEROS, + ): + """ + warp an image with given DDF. + supports spatially 2D or 3D (num_channels, H, W[, D]). + Args: + spatial_dims: number of spatial dimensions + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + """ + super(Warp, self).__init__() + if spatial_dims not in [2, 3]: + raise ValueError(f"got unsupported spatial_dims = {spatial_dims}, only support 2-d and 3-d input") + self.spatial_dims = spatial_dims + self.mode: GridSampleMode = GridSampleMode(mode) + self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) + + @staticmethod + def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor: + mesh_points = [torch.arange(0, dim) for dim in ddf.shape[2:]] + grid = torch.stack(torch.meshgrid(*mesh_points[::-1]), dim=0) # (spatial_dims, ...) + grid = torch.stack([grid] * ddf.shape[0], dim=0) # (batch, spatial_dims, ...) + grid = grid.to(ddf) + return grid + + @staticmethod + def normalize_grid(image: torch.Tensor, grid: torch.Tensor) -> torch.Tensor: + # (batch, ..., self.spatial_dims) + for i, (img_dim, ddf_dim) in enumerate(zip(image.shape[2:], grid.shape[1:-1])): + grid[..., i] = (1 - ddf_dim + grid[..., i] * 2) / (img_dim - 1) + return grid + + def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Tensor in shape (batch, num_channels, H, W[, D]) + ddf: Tensor in shape (batch, num_channels, H, W[, D]) + + Returns: + warped_image: Tensor in the same shape as ddf (batch, num_channels, H, W[, D]) + """ + if len(image.shape) != 2 + self.spatial_dims: + raise ValueError(f"expecting {self.spatial_dims + 2}-d input, " f"got input in shape {image.shape}") + if len(ddf.shape) != 2 + self.spatial_dims or ddf.shape[1] != self.spatial_dims: + raise ValueError( + f"expecting {self.spatial_dims + 2}-d ddf with {self.spatial_dims} channels, " + f"got ddf in shape {ddf.shape}" + ) + + grid = self.get_reference_grid(ddf) + ddf + grid = grid.permute([0] + list(range(2, 2 + self.spatial_dims)) + [1]) # (batch, ..., self.spatial_dims) + + if USE_COMPILED: + _padding_mode = self.padding_mode.value + if _padding_mode == "zeros": + bound = 7 + elif _padding_mode == "border": + bound = 0 + else: + bound = 1 + _interp_mode = self.mode.value + warped_image: torch.Tensor = grid_pull( + image, + grid, + bound=bound, + extrapolate=True, + interpolation=1 if _interp_mode == "bilinear" else _interp_mode, + ) + else: + grid = self.normalize_grid(image, grid) + index_ordering: List[int] = list(range(self.spatial_dims - 1, -1, -1)) + grid = grid[..., index_ordering] # z, y, x -> x, y, z + warped_image = F.grid_sample( + image, grid, mode=self.mode.value, padding_mode=self.padding_mode.value, align_corners=True + ) + return warped_image diff --git a/tests/test_warp.py b/tests/test_warp.py new file mode 100644 index 0000000000..07a02c7070 --- /dev/null +++ b/tests/test_warp.py @@ -0,0 +1,91 @@ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.blocks.warp import Warp + +TEST_CASES = [ + [ + {"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"}, + {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)}, + torch.arange(4).reshape((1, 1, 2, 2)), + ], + [ + {"spatial_dims": 2, "mode": "nearest", "padding_mode": "zeros"}, + {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)}, + torch.tensor([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]) + .unsqueeze(0) + .unsqueeze(0), + ], + [ + {"spatial_dims": 2, "mode": "bilinear", "padding_mode": "border"}, + {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)}, + torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]) + .unsqueeze(0) + .unsqueeze(0), + ], + [ + {"spatial_dims": 2, "mode": "nearest", "padding_mode": "reflection"}, + {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)}, + torch.tensor([[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]) + .unsqueeze(0) + .unsqueeze(0), + ], + [ + {"spatial_dims": 3, "mode": "bilinear", "padding_mode": "zeros"}, + {"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 3, 4, 4, 4)}, + torch.tensor( + [ + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + ] + ) + .unsqueeze(0) + .unsqueeze(0), + ], + [ + {"spatial_dims": 3, "mode": "bilinear", "padding_mode": "border"}, + {"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 3, 4, 4, 4)}, + torch.tensor( + [ + [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0], [2.0, 2.0, 3.0, 3.0]], + [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0], [2.0, 2.0, 3.0, 3.0]], + [[4.0, 4.0, 5.0, 5.0], [4.0, 4.0, 5.0, 5.0], [6.0, 6.0, 7.0, 7.0], [6.0, 6.0, 7.0, 7.0]], + [[4.0, 4.0, 5.0, 5.0], [4.0, 4.0, 5.0, 5.0], [6.0, 6.0, 7.0, 7.0], [6.0, 6.0, 7.0, 7.0]], + ] + ) + .unsqueeze(0) + .unsqueeze(0), + ], +] + + +class TestWarp(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_resample(self, input_param, input_data, expected_val): + warp_layer = Warp(**input_param) + result = warp_layer(**input_data) + np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) + + def test_ill_shape(self): + warp_layer = Warp(spatial_dims=2) + with self.assertRaisesRegex(ValueError, ""): + warp_layer( + image=torch.arange(4).reshape((1, 1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 2, 2) + ) + with self.assertRaisesRegex(ValueError, ""): + warp_layer( + image=torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 1, 2, 2) + ) + + def test_ill_opts(self): + with self.assertRaisesRegex(ValueError, ""): + Warp(spatial_dims=4) + + +if __name__ == "__main__": + unittest.main() From 691233f0fb0e40741d51f7219ba1e78a49d0e18b Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Mon, 18 Jan 2021 22:09:07 +0000 Subject: [PATCH 2/3] 1452 add documentation Signed-off-by: kate-sann5100 --- docs/source/networks.rst | 4 ++++ monai/networks/blocks/__init__.py | 2 +- monai/networks/blocks/warp.py | 10 ++++++---- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 6a05d72b66..b97b36f5f4 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -134,6 +134,10 @@ Blocks .. autoclass:: LocalNetFeatureExtractorBlock :members: +`Warp` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: Warp + :members: Layers ------ diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index 13c45d5c06..8ac06f8776 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -27,4 +27,4 @@ SEResNeXtBottleneck, ) from .upsample import SubpixelUpsample, Subpixelupsample, SubpixelUpSample, Upsample, UpSample -from .warp import Warp \ No newline at end of file +from .warp import Warp diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 08cdde4212..944d286ef7 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -10,6 +10,10 @@ class Warp(nn.Module): + """ + Warp an image with given DDF. + """ + def __init__( self, spatial_dims: int, @@ -17,10 +21,8 @@ def __init__( padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.ZEROS, ): """ - warp an image with given DDF. - supports spatially 2D or 3D (num_channels, H, W[, D]). Args: - spatial_dims: number of spatial dimensions + spatial_dims: {2, 3}. number of spatial dimensions mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample @@ -57,7 +59,7 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: ddf: Tensor in shape (batch, num_channels, H, W[, D]) Returns: - warped_image: Tensor in the same shape as ddf (batch, num_channels, H, W[, D]) + warped_image in the same shape as ddf (batch, num_channels, H, W[, D]) """ if len(image.shape) != 2 + self.spatial_dims: raise ValueError(f"expecting {self.spatial_dims + 2}-d input, " f"got input in shape {image.shape}") From 70a4847a4df6e3d1a546b572666ef02364f322d9 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Mon, 18 Jan 2021 23:36:20 +0000 Subject: [PATCH 3/3] 1452 enforce same image and ddf size, add test case Signed-off-by: kate-sann5100 --- monai/networks/blocks/warp.py | 17 +++++++---- tests/test_warp.py | 53 +++++++++++++++++++---------------- 2 files changed, 40 insertions(+), 30 deletions(-) diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 944d286ef7..56b289e394 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -46,20 +46,20 @@ def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor: return grid @staticmethod - def normalize_grid(image: torch.Tensor, grid: torch.Tensor) -> torch.Tensor: + def normalize_grid(grid: torch.Tensor) -> torch.Tensor: # (batch, ..., self.spatial_dims) - for i, (img_dim, ddf_dim) in enumerate(zip(image.shape[2:], grid.shape[1:-1])): - grid[..., i] = (1 - ddf_dim + grid[..., i] * 2) / (img_dim - 1) + for i, dim in enumerate(grid.shape[1:-1]): + grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1 return grid def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: """ Args: image: Tensor in shape (batch, num_channels, H, W[, D]) - ddf: Tensor in shape (batch, num_channels, H, W[, D]) + ddf: Tensor in the same spatial size as image, in shape (batch, spatial_dims, H, W[, D]) Returns: - warped_image in the same shape as ddf (batch, num_channels, H, W[, D]) + warped_image in the same shape as image (batch, num_channels, H, W[, D]) """ if len(image.shape) != 2 + self.spatial_dims: raise ValueError(f"expecting {self.spatial_dims + 2}-d input, " f"got input in shape {image.shape}") @@ -68,6 +68,11 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: f"expecting {self.spatial_dims + 2}-d ddf with {self.spatial_dims} channels, " f"got ddf in shape {ddf.shape}" ) + if image.shape[0] != ddf.shape[0] or image.shape[2:] != ddf.shape[2:]: + raise ValueError( + "expecting image and ddf of same batch size and spatial size, " + f"got image of shape {image.shape}, ddf of shape {ddf.shape}" + ) grid = self.get_reference_grid(ddf) + ddf grid = grid.permute([0] + list(range(2, 2 + self.spatial_dims)) + [1]) # (batch, ..., self.spatial_dims) @@ -89,7 +94,7 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: interpolation=1 if _interp_mode == "bilinear" else _interp_mode, ) else: - grid = self.normalize_grid(image, grid) + grid = self.normalize_grid(grid) index_ordering: List[int] = list(range(self.spatial_dims - 1, -1, -1)) grid = grid[..., index_ordering] # z, y, x -> x, y, z warped_image = F.grid_sample( diff --git a/tests/test_warp.py b/tests/test_warp.py index 07a02c7070..ba8bc9a994 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -6,7 +6,7 @@ from monai.networks.blocks.warp import Warp -TEST_CASES = [ +TEST_CASE = [ [ {"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"}, {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)}, @@ -35,31 +35,34 @@ ], [ {"spatial_dims": 3, "mode": "bilinear", "padding_mode": "zeros"}, - {"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 3, 4, 4, 4)}, - torch.tensor( - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] - ) - .unsqueeze(0) - .unsqueeze(0), + {"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 3, 2, 2, 2)}, + torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), ], +] + +TEST_CASES = [ [ - {"spatial_dims": 3, "mode": "bilinear", "padding_mode": "border"}, - {"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 3, 4, 4, 4)}, - torch.tensor( - [ - [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0], [2.0, 2.0, 3.0, 3.0]], - [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0], [2.0, 2.0, 3.0, 3.0]], - [[4.0, 4.0, 5.0, 5.0], [4.0, 4.0, 5.0, 5.0], [6.0, 6.0, 7.0, 7.0], [6.0, 6.0, 7.0, 7.0]], - [[4.0, 4.0, 5.0, 5.0], [4.0, 4.0, 5.0, 5.0], [6.0, 6.0, 7.0, 7.0], [6.0, 6.0, 7.0, 7.0]], - ] - ) - .unsqueeze(0) - .unsqueeze(0), + {"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"}, + {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)}, + torch.arange(4).reshape((1, 1, 2, 2)), + ], + [ + {"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"}, + {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 2, 2, 2)}, + torch.tensor([[[[3, 0], [0, 0]]]]), + ], + [ + {"spatial_dims": 3, "mode": "nearest", "padding_mode": "border"}, + { + "image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), + "ddf": torch.ones(1, 3, 2, 2, 2) * -1, + }, + torch.tensor([[[[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]]), + ], + [ + {"spatial_dims": 3, "mode": "nearest", "padding_mode": "reflection"}, + {"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 3, 2, 2, 2)}, + torch.tensor([[[[[7, 6], [5, 4]], [[3, 2], [1, 0]]]]]), ], ] @@ -81,6 +84,8 @@ def test_ill_shape(self): warp_layer( image=torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 1, 2, 2) ) + with self.assertRaisesRegex(ValueError, ""): + warp_layer(image=torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 3, 3)) def test_ill_opts(self): with self.assertRaisesRegex(ValueError, ""):