diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 10a115eff8..8a570e42c4 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -32,7 +32,7 @@ class Warp(nn.Module): Warp an image with given dense displacement field (DDF). """ - def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.BORDER.value): + def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.BORDER.value, jitter=False): """ For pytorch native APIs, the possible values are: @@ -47,6 +47,11 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa - padding_mode: ``"zeros"``, ``"border"``, ``"reflection"``, 0, 1, ... See also: :py:class:`monai.networks.layers.grid_pull` + + - jitter: bool, default=False + Define reference grid on non-integer values + Reference: B. Likar and F. Pernus. A heirarchical approach to elastic registration + based on mutual information. Image and Vision Computing, 19:33-44, 2001. """ super().__init__() # resolves _interp_mode for different methods @@ -84,8 +89,9 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa self._padding_mode = GridSamplePadMode(padding_mode).value self.ref_grid = None + self.jitter = jitter - def get_reference_grid(self, ddf: torch.Tensor) -> torch.Tensor: + def get_reference_grid(self, ddf: torch.Tensor, jitter: bool = False, seed: int = 0) -> torch.Tensor: if ( self.ref_grid is not None and self.ref_grid.shape[0] == ddf.shape[0] @@ -96,6 +102,11 @@ def get_reference_grid(self, ddf: torch.Tensor) -> torch.Tensor: grid = torch.stack(meshgrid_ij(*mesh_points), dim=0) # (spatial_dims, ...) grid = torch.stack([grid] * ddf.shape[0], dim=0) # (batch, spatial_dims, ...) self.ref_grid = grid.to(ddf) + if jitter: + # Define reference grid on non-integer values + with torch.random.fork_rng(enabled=seed): + torch.random.manual_seed(seed) + grid += torch.rand_like(grid) self.ref_grid.requires_grad = False return self.ref_grid @@ -117,7 +128,7 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor): f"Given input {spatial_dims}-d image shape {image.shape}, the input DDF shape must be {ddf_shape}, " f"Got {ddf.shape} instead." ) - grid = self.get_reference_grid(ddf) + ddf + grid = self.get_reference_grid(ddf, jitter=self.jitter) + ddf grid = grid.permute([0] + list(range(2, 2 + spatial_dims)) + [1]) # (batch, ..., spatial_dims) if not USE_COMPILED: # pytorch native grid_sample