From 5632890e2f0e89b5049e230502f58699bb05cfe6 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Mon, 8 Feb 2021 15:28:18 +0000 Subject: [PATCH 1/3] 1559 add DVF2DDF Signed-off-by: kate-sann5100 --- docs/source/networks.rst | 7 +++++- monai/networks/blocks/__init__.py | 2 +- monai/networks/blocks/warp.py | 37 ++++++++++++++++++++++++++++++- tests/test_dvf2ddf.py | 37 +++++++++++++++++++++++++++++++ 4 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 tests/test_dvf2ddf.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 7c22964835..0e5e0a1fb9 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -135,10 +135,15 @@ Blocks :members: `Warp` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~ .. autoclass:: Warp :members: +`DVF2DDF` +~~~~~~~~~~ +.. autoclass:: DVF2DDF + :members: + Layers ------ diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index 8ac06f8776..4a2e31928e 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -27,4 +27,4 @@ SEResNeXtBottleneck, ) from .upsample import SubpixelUpsample, Subpixelupsample, SubpixelUpSample, Upsample, UpSample -from .warp import Warp +from .warp import DVF2DDF, Warp diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 60e23f6750..67f388b859 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -62,7 +62,7 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: """ Args: image: Tensor in shape (batch, num_channels, H, W[, D]) - ddf: Tensor in the same spatial size as image, in shape (batch, spatial_dims, H, W[, D]) + ddf: Tensor in the same spatial size as image, in shape (batch, ``spatial_dims``, H, W[, D]) Returns: warped_image in the same shape as image (batch, num_channels, H, W[, D]) @@ -111,3 +111,38 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: ) return warped_image + + +class DVF2DDF(nn.Module): + """ + Layer calculates DVF from DDF with scaling and squaring. + + Adapted from: + DeepReg (https://github.com/DeepRegNet/DeepReg) + """ + + def __init__( + self, + spatial_dims: int, + num_steps: int = 7, + mode: int = 1, + padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.ZEROS, + ): + super(DVF2DDF, self).__init__() + if num_steps <= 0: + raise ValueError(f"expecting positive num_steps, got {num_steps}") + self.num_steps = num_steps + self.warp_layer = Warp(spatial_dims=spatial_dims, mode=mode, padding_mode=padding_mode) + + def forward(self, dvf): + """ + Args: + dvf: dvf to be transformed, in shape (batch, ``spatial_dims``, H, W[,D]) + + Returns: + + """ + ddf: torch.Tensor = dvf / (2 ** self.num_steps) + for _ in range(self.num_steps): + ddf += self.warp_layer(image=ddf, ddf=ddf) + return ddf diff --git a/tests/test_dvf2ddf.py b/tests/test_dvf2ddf.py new file mode 100644 index 0000000000..9f806d5732 --- /dev/null +++ b/tests/test_dvf2ddf.py @@ -0,0 +1,37 @@ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.blocks.warp import DVF2DDF + +TEST_CASES = [ + [{"spatial_dims": 2, "num_steps": 1}, {"dvf": torch.zeros(1, 2, 2, 2)}, torch.zeros(1, 2, 2, 2)], + [ + {"spatial_dims": 3, "num_steps": 1}, + {"dvf": torch.ones(1, 3, 2, 2, 2)}, + torch.tensor([[[1.0000, 0.7500], [0.7500, 0.6250]], [[0.7500, 0.6250], [0.6250, 0.5625]]]) + .reshape(1, 1, 2, 2, 2) + .expand(-1, 3, -1, -1, -1), + ], + [ + {"spatial_dims": 3, "num_steps": 2}, + {"dvf": torch.ones(1, 3, 2, 2, 2)}, + torch.tensor([[[0.9175, 0.6618], [0.6618, 0.5306]], [[0.6618, 0.5306], [0.5306, 0.4506]]]) + .reshape(1, 1, 2, 2, 2) + .expand(-1, 3, -1, -1, -1), + ], +] + + +class TestDVF2DDF(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_value(self, input_param, input_data, expected_val): + layer = DVF2DDF(**input_param) + result = layer(**input_data) + np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() From ad981e9779f70ab086ed75643ee58fbdb6912fc5 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 11 Feb 2021 17:46:34 +0000 Subject: [PATCH 2/3] 1559 add gradient test Signed-off-by: kate-sann5100 --- monai/networks/blocks/warp.py | 2 +- tests/test_dvf2ddf.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 67f388b859..5534cf3bdc 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -144,5 +144,5 @@ def forward(self, dvf): """ ddf: torch.Tensor = dvf / (2 ** self.num_steps) for _ in range(self.num_steps): - ddf += self.warp_layer(image=ddf, ddf=ddf) + ddf = ddf + self.warp_layer(image=ddf, ddf=ddf) return ddf diff --git a/tests/test_dvf2ddf.py b/tests/test_dvf2ddf.py index 9f806d5732..4be0669237 100644 --- a/tests/test_dvf2ddf.py +++ b/tests/test_dvf2ddf.py @@ -3,6 +3,8 @@ import numpy as np import torch from parameterized import parameterized +from torch import nn +from torch.optim import SGD from monai.networks.blocks.warp import DVF2DDF @@ -32,6 +34,17 @@ def test_value(self, input_param, input_data, expected_val): result = layer(**input_data) np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) + def test_gradient(self): + network = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=1) + dvf2ddf = DVF2DDF(spatial_dims=2, num_steps=1) + optimizer = SGD(network.parameters(), lr=0.01) + x = torch.ones((1, 1, 5, 5)) + x = network(x) + x = dvf2ddf(x) + loss = torch.sum(x) + loss.backward() + optimizer.step() + if __name__ == "__main__": unittest.main() From f28383aca7391e65969d54c73f0c1a05326c2705 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 18 Feb 2021 09:41:01 +0000 Subject: [PATCH 3/3] update docs; test cases Signed-off-by: Wenqi Li --- docs/source/networks.rst | 10 +++++----- monai/networks/blocks/warp.py | 4 +++- tests/test_dvf2ddf.py | 8 ++++++++ 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 0e5e0a1fb9..e0ac0f2d75 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -120,12 +120,12 @@ Blocks .. autoclass:: SubpixelUpSample `LocalNet DownSample Block` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LocalNetDownSampleBlock :members: `LocalNet UpSample Block` -~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LocalNetUpSampleBlock :members: @@ -135,12 +135,12 @@ Blocks :members: `Warp` -~~~~~~~ +~~~~~~ .. autoclass:: Warp :members: `DVF2DDF` -~~~~~~~~~~ +~~~~~~~~~ .. autoclass:: DVF2DDF :members: @@ -206,7 +206,7 @@ Layers ~~~~~~~~~~~~~~~~ .. autoclass:: GaussianFilter :members: - + `BilateralFilter` ~~~~~~~~~~~~~~~~~ .. autoclass:: BilateralFilter diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 5534cf3bdc..eb4c09fa72 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -115,10 +115,12 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: class DVF2DDF(nn.Module): """ - Layer calculates DVF from DDF with scaling and squaring. + Layer calculates a dense velocity field (DVF) from a dense displacement field (DDF) + with scaling and squaring. Adapted from: DeepReg (https://github.com/DeepRegNet/DeepReg) + """ def __init__( diff --git a/tests/test_dvf2ddf.py b/tests/test_dvf2ddf.py index 4be0669237..0ee8ba6c30 100644 --- a/tests/test_dvf2ddf.py +++ b/tests/test_dvf2ddf.py @@ -7,6 +7,7 @@ from torch.optim import SGD from monai.networks.blocks.warp import DVF2DDF +from monai.utils import set_determinism TEST_CASES = [ [{"spatial_dims": 2, "num_steps": 1}, {"dvf": torch.zeros(1, 2, 2, 2)}, torch.zeros(1, 2, 2, 2)], @@ -28,6 +29,12 @@ class TestDVF2DDF(unittest.TestCase): + def setUp(self): + set_determinism(0) + + def tearDown(self): + set_determinism(None) + @parameterized.expand(TEST_CASES) def test_value(self, input_param, input_data, expected_val): layer = DVF2DDF(**input_param) @@ -44,6 +51,7 @@ def test_gradient(self): loss = torch.sum(x) loss.backward() optimizer.step() + np.testing.assert_allclose(network.weight.grad.cpu().numpy(), np.array([[[[22.471329]]], [[[22.552576]]]])) if __name__ == "__main__":