Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ Blocks
.. autoclass:: SubpixelUpSample

`LocalNet DownSample Block`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LocalNetDownSampleBlock
:members:

`LocalNet UpSample Block`
~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LocalNetUpSampleBlock
:members:

Expand All @@ -135,10 +135,15 @@ Blocks
:members:

`Warp`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~
.. autoclass:: Warp
:members:

`DVF2DDF`
~~~~~~~~~
.. autoclass:: DVF2DDF
:members:

Layers
------

Expand Down Expand Up @@ -201,7 +206,7 @@ Layers
~~~~~~~~~~~~~~~~
.. autoclass:: GaussianFilter
:members:

`BilateralFilter`
~~~~~~~~~~~~~~~~~
.. autoclass:: BilateralFilter
Expand Down
2 changes: 1 addition & 1 deletion monai/networks/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@
SEResNeXtBottleneck,
)
from .upsample import SubpixelUpsample, Subpixelupsample, SubpixelUpSample, Upsample, UpSample
from .warp import Warp
from .warp import DVF2DDF, Warp
39 changes: 38 additions & 1 deletion monai/networks/blocks/warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor:
"""
Args:
image: Tensor in shape (batch, num_channels, H, W[, D])
ddf: Tensor in the same spatial size as image, in shape (batch, spatial_dims, H, W[, D])
ddf: Tensor in the same spatial size as image, in shape (batch, ``spatial_dims``, H, W[, D])

Returns:
warped_image in the same shape as image (batch, num_channels, H, W[, D])
Expand Down Expand Up @@ -111,3 +111,40 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor:
)

return warped_image


class DVF2DDF(nn.Module):
"""
Layer calculates a dense velocity field (DVF) from a dense displacement field (DDF)
with scaling and squaring.

Adapted from:
DeepReg (https://github.com/DeepRegNet/DeepReg)

"""

def __init__(
self,
spatial_dims: int,
num_steps: int = 7,
mode: int = 1,
padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.ZEROS,
):
super(DVF2DDF, self).__init__()
if num_steps <= 0:
raise ValueError(f"expecting positive num_steps, got {num_steps}")
self.num_steps = num_steps
self.warp_layer = Warp(spatial_dims=spatial_dims, mode=mode, padding_mode=padding_mode)

def forward(self, dvf):
"""
Args:
dvf: dvf to be transformed, in shape (batch, ``spatial_dims``, H, W[,D])

Returns:

"""
ddf: torch.Tensor = dvf / (2 ** self.num_steps)
for _ in range(self.num_steps):
ddf = ddf + self.warp_layer(image=ddf, ddf=ddf)
return ddf
58 changes: 58 additions & 0 deletions tests/test_dvf2ddf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized
from torch import nn
from torch.optim import SGD

from monai.networks.blocks.warp import DVF2DDF
from monai.utils import set_determinism

TEST_CASES = [
[{"spatial_dims": 2, "num_steps": 1}, {"dvf": torch.zeros(1, 2, 2, 2)}, torch.zeros(1, 2, 2, 2)],
[
{"spatial_dims": 3, "num_steps": 1},
{"dvf": torch.ones(1, 3, 2, 2, 2)},
torch.tensor([[[1.0000, 0.7500], [0.7500, 0.6250]], [[0.7500, 0.6250], [0.6250, 0.5625]]])
.reshape(1, 1, 2, 2, 2)
.expand(-1, 3, -1, -1, -1),
],
[
{"spatial_dims": 3, "num_steps": 2},
{"dvf": torch.ones(1, 3, 2, 2, 2)},
torch.tensor([[[0.9175, 0.6618], [0.6618, 0.5306]], [[0.6618, 0.5306], [0.5306, 0.4506]]])
.reshape(1, 1, 2, 2, 2)
.expand(-1, 3, -1, -1, -1),
],
]


class TestDVF2DDF(unittest.TestCase):
def setUp(self):
set_determinism(0)

def tearDown(self):
set_determinism(None)

@parameterized.expand(TEST_CASES)
def test_value(self, input_param, input_data, expected_val):
layer = DVF2DDF(**input_param)
result = layer(**input_data)
np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4)

def test_gradient(self):
network = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=1)
dvf2ddf = DVF2DDF(spatial_dims=2, num_steps=1)
optimizer = SGD(network.parameters(), lr=0.01)
x = torch.ones((1, 1, 5, 5))
x = network(x)
x = dvf2ddf(x)
loss = torch.sum(x)
loss.backward()
optimizer.step()
np.testing.assert_allclose(network.weight.grad.cpu().numpy(), np.array([[[[22.471329]]], [[[22.552576]]]]))


if __name__ == "__main__":
unittest.main()