From e1a4c12f6aef4a745d8a6a2a85327f420c5e6a21 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Tue, 19 Jan 2021 15:43:26 +0000 Subject: [PATCH 1/7] 1442 use pull-grid only for above linear interpolation Signed-off-by: kate-sann5100 --- monai/networks/blocks/warp.py | 43 ++++++++++++++++++++++------------- tests/test_warp.py | 42 ++++------------------------------ 2 files changed, 31 insertions(+), 54 deletions(-) diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 56b289e394..07438a442d 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -17,24 +17,32 @@ class Warp(nn.Module): def __init__( self, spatial_dims: int, - mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR, + mode: int = 1, padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.ZEROS, ): """ Args: spatial_dims: {2, 3}. number of spatial dimensions - mode: {``"bilinear"``, ``"nearest"``} - Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + mode: interpolation mode to calculate output values, defaults to 1. + Possible values are:: + + - 0 or 'nearest' or InterpolationType.nearest + - 1 or 'linear' or InterpolationType.linear + - 2 or 'quadratic' or InterpolationType.quadratic + - 3 or 'cubic' or InterpolationType.cubic + - 4 or 'fourth' or InterpolationType.fourth + - etc. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ super(Warp, self).__init__() if spatial_dims not in [2, 3]: - raise ValueError(f"got unsupported spatial_dims = {spatial_dims}, only support 2-d and 3-d input") + raise ValueError(f"got unsupported spatial_dims={spatial_dims}, only support 2-d and 3-d input") self.spatial_dims = spatial_dims - self.mode: GridSampleMode = GridSampleMode(mode) + if mode < 0: + raise ValueError(f"do not support negative mode, got mode={mode}") + self.mode = mode self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) @staticmethod @@ -77,7 +85,17 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: grid = self.get_reference_grid(ddf) + ddf grid = grid.permute([0] + list(range(2, 2 + self.spatial_dims)) + [1]) # (batch, ..., self.spatial_dims) - if USE_COMPILED: + if self.mode <= 1: + grid = self.normalize_grid(grid) + index_ordering: List[int] = list(range(self.spatial_dims - 1, -1, -1)) + grid = grid[..., index_ordering] # z, y, x -> x, y, z + _interp_mode = "bilinear" if self.mode == 1 else "nearest" + warped_image = F.grid_sample( + image, grid, mode=_interp_mode, padding_mode=self.padding_mode.value, align_corners=True + ) + else: + if not USE_COMPILED: + raise ValueError(f"cannot perform {self.mode}-order interpolation without C compile.") _padding_mode = self.padding_mode.value if _padding_mode == "zeros": bound = 7 @@ -85,19 +103,12 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: bound = 0 else: bound = 1 - _interp_mode = self.mode.value warped_image: torch.Tensor = grid_pull( image, grid, bound=bound, extrapolate=True, - interpolation=1 if _interp_mode == "bilinear" else _interp_mode, - ) - else: - grid = self.normalize_grid(grid) - index_ordering: List[int] = list(range(self.spatial_dims - 1, -1, -1)) - grid = grid[..., index_ordering] # z, y, x -> x, y, z - warped_image = F.grid_sample( - image, grid, mode=self.mode.value, padding_mode=self.padding_mode.value, align_corners=True + interpolation=self.mode, ) + return warped_image diff --git a/tests/test_warp.py b/tests/test_warp.py index ba8bc9a994..0c6abb19e9 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -6,53 +6,19 @@ from monai.networks.blocks.warp import Warp -TEST_CASE = [ - [ - {"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"}, - {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)}, - torch.arange(4).reshape((1, 1, 2, 2)), - ], - [ - {"spatial_dims": 2, "mode": "nearest", "padding_mode": "zeros"}, - {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)}, - torch.tensor([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]) - .unsqueeze(0) - .unsqueeze(0), - ], - [ - {"spatial_dims": 2, "mode": "bilinear", "padding_mode": "border"}, - {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)}, - torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]) - .unsqueeze(0) - .unsqueeze(0), - ], - [ - {"spatial_dims": 2, "mode": "nearest", "padding_mode": "reflection"}, - {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)}, - torch.tensor([[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]) - .unsqueeze(0) - .unsqueeze(0), - ], - [ - {"spatial_dims": 3, "mode": "bilinear", "padding_mode": "zeros"}, - {"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 3, 2, 2, 2)}, - torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), - ], -] - TEST_CASES = [ [ - {"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"}, + {"spatial_dims": 2, "mode": 0, "padding_mode": "zeros"}, {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)}, torch.arange(4).reshape((1, 1, 2, 2)), ], [ - {"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"}, + {"spatial_dims": 2, "mode": 1, "padding_mode": "zeros"}, {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 2, 2, 2)}, torch.tensor([[[[3, 0], [0, 0]]]]), ], [ - {"spatial_dims": 3, "mode": "nearest", "padding_mode": "border"}, + {"spatial_dims": 3, "mode": 2, "padding_mode": "border"}, { "image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 3, 2, 2, 2) * -1, @@ -60,7 +26,7 @@ torch.tensor([[[[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]]), ], [ - {"spatial_dims": 3, "mode": "nearest", "padding_mode": "reflection"}, + {"spatial_dims": 3, "mode": 3, "padding_mode": "reflection"}, {"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 3, 2, 2, 2)}, torch.tensor([[[[[7, 6], [5, 4]], [[3, 2], [1, 0]]]]]), ], From 87fe559675323325ffa0800a165791dbfb051af5 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Tue, 19 Jan 2021 16:06:38 +0000 Subject: [PATCH 2/7] 1452 reformat Signed-off-by: kate-sann5100 --- monai/networks/blocks/warp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 07438a442d..6dfc810cde 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -6,7 +6,7 @@ from monai.config import USE_COMPILED from monai.networks.layers import grid_pull -from monai.utils import GridSampleMode, GridSamplePadMode +from monai.utils import GridSamplePadMode class Warp(nn.Module): From b0b7fc75c6fa9433afd578cac613e054189e005b Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Tue, 19 Jan 2021 16:26:39 +0000 Subject: [PATCH 3/7] 1452 adjust test Signed-off-by: kate-sann5100 --- tests/test_warp.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_warp.py b/tests/test_warp.py index 0c6abb19e9..ec9e9768a3 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -4,9 +4,10 @@ import torch from parameterized import parameterized +from monai.config import USE_COMPILED from monai.networks.blocks.warp import Warp -TEST_CASES = [ +LOW_POWER_TEST_CASES = [ [ {"spatial_dims": 2, "mode": 0, "padding_mode": "zeros"}, {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)}, @@ -17,6 +18,9 @@ {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 2, 2, 2)}, torch.tensor([[[[3, 0], [0, 0]]]]), ], +] + +HIGH_POWER_TEST_CASES = [ [ {"spatial_dims": 3, "mode": 2, "padding_mode": "border"}, { @@ -32,6 +36,10 @@ ], ] +TEST_CASES = LOW_POWER_TEST_CASES +if USE_COMPILED: + TEST_CASES += HIGH_POWER_TEST_CASES + class TestWarp(unittest.TestCase): @parameterized.expand(TEST_CASES) From 74245d0de077c3857f894b07e89f80137a39823e Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Tue, 19 Jan 2021 16:38:08 +0000 Subject: [PATCH 4/7] 1452 reformat Signed-off-by: kate-sann5100 --- monai/networks/blocks/warp.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 6dfc810cde..42b237d84e 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -85,15 +85,7 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: grid = self.get_reference_grid(ddf) + ddf grid = grid.permute([0] + list(range(2, 2 + self.spatial_dims)) + [1]) # (batch, ..., self.spatial_dims) - if self.mode <= 1: - grid = self.normalize_grid(grid) - index_ordering: List[int] = list(range(self.spatial_dims - 1, -1, -1)) - grid = grid[..., index_ordering] # z, y, x -> x, y, z - _interp_mode = "bilinear" if self.mode == 1 else "nearest" - warped_image = F.grid_sample( - image, grid, mode=_interp_mode, padding_mode=self.padding_mode.value, align_corners=True - ) - else: + if self.mode > 1: if not USE_COMPILED: raise ValueError(f"cannot perform {self.mode}-order interpolation without C compile.") _padding_mode = self.padding_mode.value @@ -110,5 +102,13 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: extrapolate=True, interpolation=self.mode, ) + else: + grid = self.normalize_grid(grid) + index_ordering: List[int] = list(range(self.spatial_dims - 1, -1, -1)) + grid = grid[..., index_ordering] # z, y, x -> x, y, z + _interp_mode = "bilinear" if self.mode == 1 else "nearest" + warped_image = F.grid_sample( + image, grid, mode=_interp_mode, padding_mode=self.padding_mode.value, align_corners=True + ) return warped_image From 282d8913fe64426db477e7f59c8b47d6a48226d6 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 27 Jan 2021 14:22:51 +0000 Subject: [PATCH 5/7] 1452 raise ValueError for mode > 1 Signed-off-by: kate-sann5100 --- monai/networks/blocks/warp.py | 35 ++++++++++++++++++----------------- tests/test_warp.py | 4 ++-- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 42b237d84e..8de78e1f93 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -48,7 +48,7 @@ def __init__( @staticmethod def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor: mesh_points = [torch.arange(0, dim) for dim in ddf.shape[2:]] - grid = torch.stack(torch.meshgrid(*mesh_points[::-1]), dim=0) # (spatial_dims, ...) + grid = torch.stack(torch.meshgrid(*mesh_points), dim=0) # (spatial_dims, ...) grid = torch.stack([grid] * ddf.shape[0], dim=0) # (batch, spatial_dims, ...) grid = grid.to(ddf) return grid @@ -86,22 +86,23 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: grid = grid.permute([0] + list(range(2, 2 + self.spatial_dims)) + [1]) # (batch, ..., self.spatial_dims) if self.mode > 1: - if not USE_COMPILED: - raise ValueError(f"cannot perform {self.mode}-order interpolation without C compile.") - _padding_mode = self.padding_mode.value - if _padding_mode == "zeros": - bound = 7 - elif _padding_mode == "border": - bound = 0 - else: - bound = 1 - warped_image: torch.Tensor = grid_pull( - image, - grid, - bound=bound, - extrapolate=True, - interpolation=self.mode, - ) + raise ValueError(f"{self.mode}-order interpolation not yet implemented.") + # if not USE_COMPILED: + # raise ValueError(f"cannot perform {self.mode}-order interpolation without C compile.") + # _padding_mode = self.padding_mode.value + # if _padding_mode == "zeros": + # bound = 7 + # elif _padding_mode == "border": + # bound = 0 + # else: + # bound = 1 + # warped_image: torch.Tensor = grid_pull( + # image, + # grid, + # bound=bound, + # extrapolate=True, + # interpolation=self.mode, + # ) else: grid = self.normalize_grid(grid) index_ordering: List[int] = list(range(self.spatial_dims - 1, -1, -1)) diff --git a/tests/test_warp.py b/tests/test_warp.py index ec9e9768a3..075f313148 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -37,8 +37,8 @@ ] TEST_CASES = LOW_POWER_TEST_CASES -if USE_COMPILED: - TEST_CASES += HIGH_POWER_TEST_CASES +# if USE_COMPILED: +# TEST_CASES += HIGH_POWER_TEST_CASES class TestWarp(unittest.TestCase): From a9496a107116a9220c62a25616fcf072282ed501 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 27 Jan 2021 14:28:50 +0000 Subject: [PATCH 6/7] 1452 reformat Signed-off-by: kate-sann5100 --- monai/networks/blocks/warp.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 8de78e1f93..60e23f6750 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -4,8 +4,6 @@ from torch import nn from torch.nn import functional as F -from monai.config import USE_COMPILED -from monai.networks.layers import grid_pull from monai.utils import GridSamplePadMode From c98b41d483b0d6f7e7ea205e597f698035f4dfed Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 27 Jan 2021 14:34:04 +0000 Subject: [PATCH 7/7] 1452 reformat Signed-off-by: kate-sann5100 --- tests/test_warp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_warp.py b/tests/test_warp.py index 075f313148..69ae997e38 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -4,7 +4,6 @@ import torch from parameterized import parameterized -from monai.config import USE_COMPILED from monai.networks.blocks.warp import Warp LOW_POWER_TEST_CASES = [