diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index be77a1d975..213be56b5c 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -131,7 +131,8 @@ def get_system_info() -> OrderedDict: elif output["System"] == "Darwin": _dict_append(output, "Mac version", lambda: platform.mac_ver()[0]) else: - linux_ver = re.search(r'PRETTY_NAME="(.*)"', open("/etc/os-release", "r").read()) + with open("/etc/os-release", "r") as rel_f: + linux_ver = re.search(r'PRETTY_NAME="(.*)"', rel_f.read()) if linux_ver: _dict_append(output, "Linux version", lambda: linux_ver.group(1)) diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 35d2a88f12..1013540288 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -1,10 +1,15 @@ +import warnings from typing import List, Optional, Union import torch from torch import nn from torch.nn import functional as F -from monai.utils import GridSamplePadMode +from monai.config.deviceconfig import USE_COMPILED +from monai.networks.layers.spatial_transforms import grid_pull +from monai.utils import GridSampleMode, GridSamplePadMode + +__all__ = ["Warp", "DVF2DDF"] class Warp(nn.Module): @@ -14,7 +19,7 @@ class Warp(nn.Module): def __init__( self, - mode: int = 1, + mode=1, padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.ZEROS, ): """ @@ -33,10 +38,32 @@ def __init__( See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ super(Warp, self).__init__() - if mode < 0: - raise ValueError(f"do not support negative mode, got mode={mode}") - self.mode = mode - self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) + # resolves _interp_mode for different methods + if USE_COMPILED: + self._interp_mode = mode + else: + warnings.warn("monai.networks.blocks.Warp: Using PyTorch native grid_sample.") + self._interp_mode = GridSampleMode.BILINEAR.value # works for both 4D and 5D tensors + if mode == 0: + self._interp_mode = GridSampleMode.NEAREST.value + elif mode == 1: + self._interp_mode = GridSampleMode.BILINEAR.value + elif mode == 3: + self._interp_mode = GridSampleMode.BICUBIC.value # torch.functional.grid_sample only supports 4D + else: + warnings.warn(f"Order-{mode} interpolation is not supported, using linear interpolation.") + + # resolves _padding_mode for different methods + padding_mode = GridSamplePadMode(padding_mode).value + if USE_COMPILED: + if padding_mode == GridSamplePadMode.ZEROS.value: + self._padding_mode = 7 + elif padding_mode == GridSamplePadMode.BORDER.value: + self._padding_mode = 0 + else: + self._padding_mode = 1 # reflection + else: + self._padding_mode = padding_mode # type: ignore @staticmethod def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor: @@ -46,14 +73,7 @@ def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor: grid = grid.to(ddf) return grid - @staticmethod - def normalize_grid(grid: torch.Tensor) -> torch.Tensor: - # (batch, ..., spatial_dims) - for i, dim in enumerate(grid.shape[1:-1]): - grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1 - return grid - - def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, ddf: torch.Tensor): """ Args: image: Tensor in shape (batch, num_channels, H, W[, D]) @@ -73,34 +93,23 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: grid = self.get_reference_grid(ddf) + ddf grid = grid.permute([0] + list(range(2, 2 + spatial_dims)) + [1]) # (batch, ..., spatial_dims) - if self.mode > 1: - raise ValueError(f"{self.mode}-order interpolation not yet implemented.") - # if not USE_COMPILED: - # raise ValueError(f"cannot perform {self.mode}-order interpolation without C compile.") - # _padding_mode = self.padding_mode.value - # if _padding_mode == "zeros": - # bound = 7 - # elif _padding_mode == "border": - # bound = 0 - # else: - # bound = 1 - # warped_image: torch.Tensor = grid_pull( - # image, - # grid, - # bound=bound, - # extrapolate=True, - # interpolation=self.mode, - # ) - else: - grid = self.normalize_grid(grid) + if not USE_COMPILED: # pytorch native grid_sample + for i, dim in enumerate(grid.shape[1:-1]): + grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1 index_ordering: List[int] = list(range(spatial_dims - 1, -1, -1)) grid = grid[..., index_ordering] # z, y, x -> x, y, z - _interp_mode = "bilinear" if self.mode == 1 else "nearest" - warped_image = F.grid_sample( - image, grid, mode=_interp_mode, padding_mode=self.padding_mode.value, align_corners=True + return F.grid_sample( + image, grid, mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True ) - return warped_image + # using csrc resampling + return grid_pull( + image, + grid, + bound=self._padding_mode, + extrapolate=True, + interpolation=self._interp_mode, + ) class DVF2DDF(nn.Module): diff --git a/setup.cfg b/setup.cfg index 15e6a6d127..f06c56d001 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,12 @@ long_description = file:README.md long_description_content_type = text/markdown; charset=UTF-8 platforms = OS Independent license = Apache License 2.0 +license_files = + LICENSE +project_urls = + Documentation=https://docs.monai.io/ + Bug Tracker=https://github.com/Project-MONAI/MONAI/issues + Source Code=https://github.com/Project-MONAI/MONAI [options] python_requires = >= 3.6 diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py index 47b3a66305..ded0290de2 100644 --- a/tests/test_detect_envelope.py +++ b/tests/test_detect_envelope.py @@ -156,7 +156,7 @@ def test_no_fft_module_error(self): @SkipIfAtLeastPyTorchVersion((1, 7)) class TestDetectEnvelopeInvalidPyTorch(unittest.TestCase): def test_invalid_pytorch_error(self): - with self.assertRaisesRegexp(InvalidPyTorchVersionError, "version"): + with self.assertRaisesRegex(InvalidPyTorchVersionError, "version"): DetectEnvelope() diff --git a/tests/test_warp.py b/tests/test_warp.py index 613b6fb4ab..a2af441a5b 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -3,10 +3,12 @@ import numpy as np import torch from parameterized import parameterized +from torch.autograd import gradcheck +from monai.config.deviceconfig import USE_COMPILED from monai.networks.blocks.warp import Warp -LOW_POWER_TEST_CASES = [ +LOW_POWER_TEST_CASES = [ # run with BUILD_MONAI=1 to test csrc/resample, BUILD_MONAI=0 to test native grid_sample [ {"mode": 0, "padding_mode": "zeros"}, {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)}, @@ -17,31 +19,63 @@ {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 2, 2, 2)}, torch.tensor([[[[3, 0], [0, 0]]]]), ], + [ + {"mode": 1, "padding_mode": "border"}, + { + "image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), + "ddf": torch.ones(1, 3, 2, 2, 2) * -1, + }, + torch.tensor([[[[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]]), + ], + [ + {"mode": 1, "padding_mode": "reflection"}, + { + "image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), + "ddf": torch.ones(1, 3, 2, 2, 2) * -1, + }, + torch.tensor([[[[[7.0, 6.0], [5.0, 4.0]], [[3.0, 2.0], [1.0, 0.0]]]]]), + ], ] -HIGH_POWER_TEST_CASES = [ +CPP_TEST_CASES = [ # high order, BUILD_MONAI=1 to test csrc/resample [ {"mode": 2, "padding_mode": "border"}, { "image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 3, 2, 2, 2) * -1, }, - torch.tensor([[[[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]]), + torch.tensor([[[[[0.0000, 0.1250], [0.2500, 0.3750]], [[0.5000, 0.6250], [0.7500, 0.8750]]]]]), + ], + [ + {"mode": 2, "padding_mode": "reflection"}, + { + "image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), + "ddf": torch.ones(1, 3, 2, 2, 2) * -1, + }, + torch.tensor([[[[[5.2500, 4.7500], [4.2500, 3.7500]], [[3.2500, 2.7500], [2.2500, 1.7500]]]]]), + ], + [ + {"mode": 2, "padding_mode": "zeros"}, + { + "image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), + "ddf": torch.ones(1, 3, 2, 2, 2) * -1, + }, + torch.tensor([[[[[0.0000, 0.0020], [0.0039, 0.0410]], [[0.0078, 0.0684], [0.0820, 0.6699]]]]]), ], [ {"mode": 3, "padding_mode": "reflection"}, {"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 3, 2, 2, 2)}, - torch.tensor([[[[[7, 6], [5, 4]], [[3, 2], [1, 0]]]]]), + torch.tensor([[[[[4.6667, 4.3333], [4.0000, 3.6667]], [[3.3333, 3.0000], [2.6667, 2.3333]]]]]), ], ] TEST_CASES = LOW_POWER_TEST_CASES -# if USE_COMPILED: -# TEST_CASES += HIGH_POWER_TEST_CASES +if USE_COMPILED: + TEST_CASES += CPP_TEST_CASES class TestWarp(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TEST_CASES, skip_on_empty=True) def test_resample(self, input_param, input_data, expected_val): warp_layer = Warp(**input_param) result = warp_layer(**input_data) @@ -60,6 +94,16 @@ def test_ill_shape(self): with self.assertRaisesRegex(ValueError, ""): warp_layer(image=torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 3, 3)) + def test_grad(self): + for m in [0, 1, 2, 3]: + for p in ["zeros", "border"]: + warp_layer = Warp(mode=m, padding_mode=p) + input_image = torch.rand((2, 3, 20, 20), dtype=torch.float64) * 10.0 + ddf = torch.rand((2, 2, 20, 20), dtype=torch.float64) * 2.0 + input_image.requires_grad = True + ddf.requires_grad = False # Jacobian mismatch for output 0 with respect to input 1 + gradcheck(warp_layer, (input_image, ddf), atol=1e-2, eps=1e-2) + if __name__ == "__main__": unittest.main()