Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion monai/config/deviceconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def get_system_info() -> OrderedDict:
elif output["System"] == "Darwin":
_dict_append(output, "Mac version", lambda: platform.mac_ver()[0])
else:
linux_ver = re.search(r'PRETTY_NAME="(.*)"', open("/etc/os-release", "r").read())
with open("/etc/os-release", "r") as rel_f:
linux_ver = re.search(r'PRETTY_NAME="(.*)"', rel_f.read())
if linux_ver:
_dict_append(output, "Linux version", lambda: linux_ver.group(1))

Expand Down
85 changes: 47 additions & 38 deletions monai/networks/blocks/warp.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import warnings
from typing import List, Optional, Union

import torch
from torch import nn
from torch.nn import functional as F

from monai.utils import GridSamplePadMode
from monai.config.deviceconfig import USE_COMPILED
from monai.networks.layers.spatial_transforms import grid_pull
from monai.utils import GridSampleMode, GridSamplePadMode

__all__ = ["Warp", "DVF2DDF"]


class Warp(nn.Module):
Expand All @@ -14,7 +19,7 @@ class Warp(nn.Module):

def __init__(
self,
mode: int = 1,
mode=1,
padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.ZEROS,
):
"""
Expand All @@ -33,10 +38,32 @@ def __init__(
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
"""
super(Warp, self).__init__()
if mode < 0:
raise ValueError(f"do not support negative mode, got mode={mode}")
self.mode = mode
self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode)
# resolves _interp_mode for different methods
if USE_COMPILED:
self._interp_mode = mode
else:
warnings.warn("monai.networks.blocks.Warp: Using PyTorch native grid_sample.")
self._interp_mode = GridSampleMode.BILINEAR.value # works for both 4D and 5D tensors
if mode == 0:
self._interp_mode = GridSampleMode.NEAREST.value
elif mode == 1:
self._interp_mode = GridSampleMode.BILINEAR.value
elif mode == 3:
self._interp_mode = GridSampleMode.BICUBIC.value # torch.functional.grid_sample only supports 4D
else:
warnings.warn(f"Order-{mode} interpolation is not supported, using linear interpolation.")

# resolves _padding_mode for different methods
padding_mode = GridSamplePadMode(padding_mode).value
if USE_COMPILED:
if padding_mode == GridSamplePadMode.ZEROS.value:
self._padding_mode = 7
elif padding_mode == GridSamplePadMode.BORDER.value:
self._padding_mode = 0
else:
self._padding_mode = 1 # reflection
else:
self._padding_mode = padding_mode # type: ignore

@staticmethod
def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor:
Expand All @@ -46,14 +73,7 @@ def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor:
grid = grid.to(ddf)
return grid

@staticmethod
def normalize_grid(grid: torch.Tensor) -> torch.Tensor:
# (batch, ..., spatial_dims)
for i, dim in enumerate(grid.shape[1:-1]):
grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1
return grid

def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor:
def forward(self, image: torch.Tensor, ddf: torch.Tensor):
"""
Args:
image: Tensor in shape (batch, num_channels, H, W[, D])
Expand All @@ -73,34 +93,23 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor:
grid = self.get_reference_grid(ddf) + ddf
grid = grid.permute([0] + list(range(2, 2 + spatial_dims)) + [1]) # (batch, ..., spatial_dims)

if self.mode > 1:
raise ValueError(f"{self.mode}-order interpolation not yet implemented.")
# if not USE_COMPILED:
# raise ValueError(f"cannot perform {self.mode}-order interpolation without C compile.")
# _padding_mode = self.padding_mode.value
# if _padding_mode == "zeros":
# bound = 7
# elif _padding_mode == "border":
# bound = 0
# else:
# bound = 1
# warped_image: torch.Tensor = grid_pull(
# image,
# grid,
# bound=bound,
# extrapolate=True,
# interpolation=self.mode,
# )
else:
grid = self.normalize_grid(grid)
if not USE_COMPILED: # pytorch native grid_sample
for i, dim in enumerate(grid.shape[1:-1]):
grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1
index_ordering: List[int] = list(range(spatial_dims - 1, -1, -1))
grid = grid[..., index_ordering] # z, y, x -> x, y, z
_interp_mode = "bilinear" if self.mode == 1 else "nearest"
warped_image = F.grid_sample(
image, grid, mode=_interp_mode, padding_mode=self.padding_mode.value, align_corners=True
return F.grid_sample(
image, grid, mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True
)

return warped_image
# using csrc resampling
return grid_pull(
image,
grid,
bound=self._padding_mode,
extrapolate=True,
interpolation=self._interp_mode,
)


class DVF2DDF(nn.Module):
Expand Down
6 changes: 6 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ long_description = file:README.md
long_description_content_type = text/markdown; charset=UTF-8
platforms = OS Independent
license = Apache License 2.0
license_files =
LICENSE
project_urls =
Documentation=https://docs.monai.io/
Bug Tracker=https://github.com/Project-MONAI/MONAI/issues
Source Code=https://github.com/Project-MONAI/MONAI

[options]
python_requires = >= 3.6
Expand Down
2 changes: 1 addition & 1 deletion tests/test_detect_envelope.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def test_no_fft_module_error(self):
@SkipIfAtLeastPyTorchVersion((1, 7))
class TestDetectEnvelopeInvalidPyTorch(unittest.TestCase):
def test_invalid_pytorch_error(self):
with self.assertRaisesRegexp(InvalidPyTorchVersionError, "version"):
with self.assertRaisesRegex(InvalidPyTorchVersionError, "version"):
DetectEnvelope()


Expand Down
58 changes: 51 additions & 7 deletions tests/test_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import numpy as np
import torch
from parameterized import parameterized
from torch.autograd import gradcheck

from monai.config.deviceconfig import USE_COMPILED
from monai.networks.blocks.warp import Warp

LOW_POWER_TEST_CASES = [
LOW_POWER_TEST_CASES = [ # run with BUILD_MONAI=1 to test csrc/resample, BUILD_MONAI=0 to test native grid_sample
[
{"mode": 0, "padding_mode": "zeros"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)},
Expand All @@ -17,31 +19,63 @@
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 2, 2, 2)},
torch.tensor([[[[3, 0], [0, 0]]]]),
],
[
{"mode": 1, "padding_mode": "border"},
{
"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),
"ddf": torch.ones(1, 3, 2, 2, 2) * -1,
},
torch.tensor([[[[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]]),
],
[
{"mode": 1, "padding_mode": "reflection"},
{
"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),
"ddf": torch.ones(1, 3, 2, 2, 2) * -1,
},
torch.tensor([[[[[7.0, 6.0], [5.0, 4.0]], [[3.0, 2.0], [1.0, 0.0]]]]]),
],
]

HIGH_POWER_TEST_CASES = [
CPP_TEST_CASES = [ # high order, BUILD_MONAI=1 to test csrc/resample
[
{"mode": 2, "padding_mode": "border"},
{
"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),
"ddf": torch.ones(1, 3, 2, 2, 2) * -1,
},
torch.tensor([[[[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]]),
torch.tensor([[[[[0.0000, 0.1250], [0.2500, 0.3750]], [[0.5000, 0.6250], [0.7500, 0.8750]]]]]),
],
[
{"mode": 2, "padding_mode": "reflection"},
{
"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),
"ddf": torch.ones(1, 3, 2, 2, 2) * -1,
},
torch.tensor([[[[[5.2500, 4.7500], [4.2500, 3.7500]], [[3.2500, 2.7500], [2.2500, 1.7500]]]]]),
],
[
{"mode": 2, "padding_mode": "zeros"},
{
"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),
"ddf": torch.ones(1, 3, 2, 2, 2) * -1,
},
torch.tensor([[[[[0.0000, 0.0020], [0.0039, 0.0410]], [[0.0078, 0.0684], [0.0820, 0.6699]]]]]),
],
[
{"mode": 3, "padding_mode": "reflection"},
{"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 3, 2, 2, 2)},
torch.tensor([[[[[7, 6], [5, 4]], [[3, 2], [1, 0]]]]]),
torch.tensor([[[[[4.6667, 4.3333], [4.0000, 3.6667]], [[3.3333, 3.0000], [2.6667, 2.3333]]]]]),
],
]

TEST_CASES = LOW_POWER_TEST_CASES
# if USE_COMPILED:
# TEST_CASES += HIGH_POWER_TEST_CASES
if USE_COMPILED:
TEST_CASES += CPP_TEST_CASES


class TestWarp(unittest.TestCase):
@parameterized.expand(TEST_CASES)
@parameterized.expand(TEST_CASES, skip_on_empty=True)
def test_resample(self, input_param, input_data, expected_val):
warp_layer = Warp(**input_param)
result = warp_layer(**input_data)
Expand All @@ -60,6 +94,16 @@ def test_ill_shape(self):
with self.assertRaisesRegex(ValueError, ""):
warp_layer(image=torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 3, 3))

def test_grad(self):
for m in [0, 1, 2, 3]:
for p in ["zeros", "border"]:
warp_layer = Warp(mode=m, padding_mode=p)
input_image = torch.rand((2, 3, 20, 20), dtype=torch.float64) * 10.0
ddf = torch.rand((2, 2, 20, 20), dtype=torch.float64) * 2.0
input_image.requires_grad = True
ddf.requires_grad = False # Jacobian mismatch for output 0 with respect to input 1
gradcheck(warp_layer, (input_image, ddf), atol=1e-2, eps=1e-2)


if __name__ == "__main__":
unittest.main()