diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 1534ecf6b2..4351199aee 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -66,7 +66,7 @@ def __init__( kernel_size: int = 3, kernel_type: str = "rectangular", reduction: Union[LossReduction, str] = LossReduction.MEAN, - smooth_nr: float = 1e-5, + smooth_nr: float = 0.0, smooth_dr: float = 1e-5, ) -> None: """ @@ -96,6 +96,7 @@ def __init__( _kernel = look_up_option(kernel_type, kernel_dict) self.kernel = _kernel(self.kernel_size) + self.kernel.requires_grad = False self.kernel_vol = self.get_kernel_vol() self.smooth_nr = float(smooth_nr) @@ -120,14 +121,15 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != pred.shape: raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})") - t2, p2, tp = target**2, pred**2, target * pred + t2, p2, tp = target * target, pred * pred, target * pred kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred) + kernels = [kernel] * self.ndim # sum over kernel - t_sum = separable_filtering(target, kernels=[kernel.to(pred)] * self.ndim) - p_sum = separable_filtering(pred, kernels=[kernel.to(pred)] * self.ndim) - t2_sum = separable_filtering(t2, kernels=[kernel.to(pred)] * self.ndim) - p2_sum = separable_filtering(p2, kernels=[kernel.to(pred)] * self.ndim) - tp_sum = separable_filtering(tp, kernels=[kernel.to(pred)] * self.ndim) + t_sum = separable_filtering(target, kernels=kernels) + p_sum = separable_filtering(pred, kernels=kernels) + t2_sum = separable_filtering(t2, kernels=kernels) + p2_sum = separable_filtering(p2, kernels=kernels) + tp_sum = separable_filtering(tp, kernels=kernels) # average over kernel t_avg = t_sum / kernel_vol @@ -143,11 +145,13 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # = sum[t*p] - sum[t] * mean[p] = cross # the following is actually squared ncc 
cross = tp_sum - p_avg * t_sum - t_var = t2_sum - t_avg * t_sum # std[t] ** 2 - p_var = p2_sum - p_avg * p_sum # std[p] ** 2 - t_var = torch.max(t_var, torch.zeros_like(t_var)) - p_var = torch.max(p_var, torch.zeros_like(p_var)) - ncc: torch.Tensor = (cross * cross + self.smooth_nr) / (t_var * p_var + self.smooth_dr) + t_var = torch.max( + t2_sum - t_avg * t_sum, torch.as_tensor(self.smooth_dr, dtype=t2_sum.dtype, device=t2_sum.device) + ) + p_var = torch.max( + p2_sum - p_avg * p_sum, torch.as_tensor(self.smooth_dr, dtype=p2_sum.dtype, device=p2_sum.device) + ) + ncc: torch.Tensor = (cross * cross + self.smooth_nr) / (t_var * p_var) if self.reduction == LossReduction.SUM.value: return torch.sum(ncc).neg() # sum over the batch, channel and spatial ndims diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 5b925258b6..7a28e86301 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -82,13 +82,23 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa else: self._padding_mode = GridSamplePadMode(padding_mode).value - @staticmethod - def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor: + self.ref_grid = None + + def get_reference_grid(self, ddf: torch.Tensor) -> torch.Tensor: + if ( + self.ref_grid is not None + and self.ref_grid.shape[0] == ddf.shape[0] + and self.ref_grid.shape[2:] == ddf.shape[2:] + and self.ref_grid.dtype == ddf.dtype # a cached grid is only reusable for the same dtype and device as ddf + and self.ref_grid.device == ddf.device + ): + return self.ref_grid # type: ignore mesh_points = [torch.arange(0, dim) for dim in ddf.shape[2:]] grid = torch.stack(meshgrid_ij(*mesh_points), dim=0) # (spatial_dims, ...) grid = torch.stack([grid] * ddf.shape[0], dim=0) # (batch, spatial_dims, ...) 
- grid = grid.to(ddf) - return grid + self.ref_grid = grid.to(ddf) + self.ref_grid.requires_grad = False + return self.ref_grid def forward(self, image: torch.Tensor, ddf: torch.Tensor): """ @@ -105,7 +115,8 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor): ddf_shape = (image.shape[0], spatial_dims) + tuple(image.shape[2:]) if ddf.shape != ddf_shape: raise ValueError( - f"Given input {spatial_dims}-d image shape {image.shape}, " f"the input DDF shape must be {ddf_shape}." + f"Given input {spatial_dims}-d image shape {image.shape}, the input DDF shape must be {ddf_shape}, " + f"Got {ddf.shape} instead." ) grid = self.get_reference_grid(ddf) + ddf grid = grid.permute([0] + list(range(2, 2 + spatial_dims)) + [1]) # (batch, ..., spatial_dims) diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index 394e514f43..e6052824a9 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -36,6 +36,14 @@ }, -1.0, ], + [ + {"spatial_dims": 1, "kernel_type": "triangular", "smooth_dr": 0.1}, + { + "pred": torch.zeros(1, 2, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device), + "target": torch.zeros(1, 2, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device), + }, + 0.0, + ], [ {"spatial_dims": 2, "kernel_type": "rectangular"}, {