Merged
40 commits
7579659
Merge remote-tracking branch 'Project-MONAI/master'
kate-sann5100 Jan 6, 2021
ce61c38
1412 add local normalized cross correlation
kate-sann5100 Jan 7, 2021
5cf91d0
1412 add unit test and documentation
kate-sann5100 Jan 7, 2021
9376195
1412 fix bug
kate-sann5100 Jan 7, 2021
ac36a9f
1412 reformat code
kate-sann5100 Jan 7, 2021
ed6c28b
1412 debug type check
kate-sann5100 Jan 7, 2021
43c2f35
1412 use separable filter for speed
kate-sann5100 Jan 9, 2021
9cc4438
Merge branch 'master' into 1412-local-normalized-cross-correlation
kate-sann5100 Jan 9, 2021
f76e3f0
1412 update Union import route
kate-sann5100 Jan 9, 2021
d3aba3f
1412 add global mutual information
kate-sann5100 Jan 9, 2021
dc3c036
1412 add global mutual information
kate-sann5100 Jan 9, 2021
efaca27
1412 add documentation and fix typing error
kate-sann5100 Jan 9, 2021
79a2819
1412 autostyle fix
kate-sann5100 Jan 9, 2021
68de8ec
1412 add unit test
kate-sann5100 Jan 9, 2021
0836408
Merge remote-tracking branch 'Project-MONAI/master' into 1412-global-…
kate-sann5100 Jan 10, 2021
b66a3fb
1412 add integration test
kate-sann5100 Jan 11, 2021
548bc1a
Merge branch 'master' into 1412-global-mutual-information
kate-sann5100 Jan 11, 2021
d2d2a87
1412 reformat code
kate-sann5100 Jan 11, 2021
a539fa7
Merge remote-tracking branch 'origin/1412-global-mutual-information' …
kate-sann5100 Jan 11, 2021
32bcdec
1412 autofix style
kate-sann5100 Jan 11, 2021
fd62c96
1412 debug
kate-sann5100 Jan 11, 2021
c67072e
Merge branch 'master' into 1412-global-mutual-information
kate-sann5100 Jan 11, 2021
5a8b1d7
1412 fix device bug
kate-sann5100 Jan 11, 2021
4ef09c3
Merge remote-tracking branch 'origin/1412-global-mutual-information' …
kate-sann5100 Jan 11, 2021
3d76ee1
1412 autofix style
kate-sann5100 Jan 11, 2021
30a5145
1412 fix typo
kate-sann5100 Jan 11, 2021
56f331d
Merge branch 'master' into 1412-global-mutual-information
kate-sann5100 Jan 11, 2021
89b89ad
Merge branch 'master' into 1412-global-mutual-information
wyli Jan 12, 2021
2d4d182
Merge branch 'master' into 1412-global-mutual-information
wyli Jan 12, 2021
5f4a7d9
1412 debug kernel_vol
kate-sann5100 Jan 12, 2021
bf97fc3
Merge remote-tracking branch 'origin/1412-global-mutual-information' …
kate-sann5100 Jan 12, 2021
1b33686
1412 autofix style
kate-sann5100 Jan 12, 2021
14cb5dc
Merge branch 'master' into 1412-global-mutual-information
kate-sann5100 Jan 12, 2021
8b818dc
Merge branch 'master' into 1412-global-mutual-information
wyli Jan 12, 2021
3245bd4
Merge remote-tracking branch 'Project-MONAI/master' into 1412-global-…
kate-sann5100 Jan 12, 2021
04db075
Merge remote-tracking branch 'origin/1412-global-mutual-information' …
kate-sann5100 Jan 12, 2021
ef881a2
1412 simplify simple network
kate-sann5100 Jan 12, 2021
6fa2469
1412 debug
kate-sann5100 Jan 12, 2021
401b89e
Merge branch 'master' into 1412-global-mutual-information
wyli Jan 12, 2021
f616813
remove temp. scripts
wyli Jan 12, 2021
5 changes: 5 additions & 0 deletions docs/source/losses.rst
@@ -70,3 +70,8 @@ Registration Losses
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LocalNormalizedCrossCorrelationLoss
:members:

+`GlobalMutualInformationLoss`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: GlobalMutualInformationLoss
+    :members:
2 changes: 1 addition & 1 deletion monai/losses/__init__.py
@@ -22,5 +22,5 @@
generalized_wasserstein_dice,
)
from .focal_loss import FocalLoss
-from .image_dissimilarity import LocalNormalizedCrossCorrelationLoss
+from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
from .tversky import TverskyLoss
34 changes: 14 additions & 20 deletions monai/losses/deform.py
@@ -17,7 +17,7 @@
from monai.utils import LossReduction


-def spatial_gradient(input: torch.Tensor, dim: int) -> torch.Tensor:
+def spatial_gradient(x: torch.Tensor, dim: int) -> torch.Tensor:
"""
Calculate gradients on single dimension of a tensor using central finite difference.
It moves the tensor along the dimension to calculate the approximate gradient
@@ -26,7 +26,7 @@ def spatial_gradient(input: torch.Tensor, dim: int) -> torch.Tensor:
DeepReg (https://github.com/DeepRegNet/DeepReg)

Args:
-        input: the shape should be BCH(WD).
+        x: the shape should be BCH(WD).
dim: dimension to calculate gradient along.
Returns:
gradient_dx: the shape should be BCH(WD)
@@ -36,17 +36,17 @@ def spatial_gradient(input: torch.Tensor, dim: int) -> torch.Tensor:
slice_2_e = slice(None, -2)
slice_all = slice(None)
slicing_s, slicing_e = [slice_all, slice_all], [slice_all, slice_all]
-    while len(slicing_s) < input.ndim:
+    while len(slicing_s) < x.ndim:
slicing_s = slicing_s + [slice_1]
slicing_e = slicing_e + [slice_1]
slicing_s[dim] = slice_2_s
slicing_e[dim] = slice_2_e
-    return (input[slicing_s] - input[slicing_e]) / 2.0
+    return (x[slicing_s] - x[slicing_e]) / 2.0


class BendingEnergyLoss(_Loss):
"""
-    Calculate the bending energy based on second-order differentiation of input using central finite difference.
+    Calculate the bending energy based on second-order differentiation of pred using central finite difference.

Adapted from:
DeepReg (https://github.com/DeepRegNet/DeepReg)
@@ -67,35 +67,29 @@ def __init__(
"""
super(BendingEnergyLoss, self).__init__(reduction=LossReduction(reduction).value)

-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward(self, pred: torch.Tensor) -> torch.Tensor:
"""
Args:
-            input: the shape should be BCH(WD)
+            pred: the shape should be BCH(WD)

Raises:
ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

"""
-        assert input.ndim in [3, 4, 5], f"expecting 3-d, 4-d or 5-d input, instead got input of shape {input.shape}"
-        if input.ndim == 3:
-            assert input.shape[-1] > 4, f"all spatial dimensions must > 4, got input of shape {input.shape}"
-        elif input.ndim == 4:
-            assert (
-                input.shape[-1] > 4 and input.shape[-2] > 4
-            ), f"all spatial dimensions must > 4, got input of shape {input.shape}"
-        elif input.ndim == 5:
-            assert (
-                input.shape[-1] > 4 and input.shape[-2] > 4 and input.shape[-3] > 4
-            ), f"all spatial dimensions must > 4, got input of shape {input.shape}"
+        if pred.ndim not in [3, 4, 5]:
+            raise ValueError(f"expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}")
+        for i in range(pred.ndim - 2):
+            if pred.shape[-i - 1] <= 4:
+                raise ValueError(f"all spatial dimensions must be > 4, got pred of shape {pred.shape}")

# first order gradient
-        first_order_gradient = [spatial_gradient(input, dim) for dim in range(2, input.ndim)]
+        first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)]

energy = torch.tensor(0)
for dim_1, g in enumerate(first_order_gradient):
dim_1 += 2
energy = spatial_gradient(g, dim_1) ** 2 + energy
-            for dim_2 in range(dim_1 + 1, input.ndim):
+            for dim_2 in range(dim_1 + 1, pred.ndim):
energy = 2 * spatial_gradient(g, dim_2) ** 2 + energy

if self.reduction == LossReduction.MEAN.value:
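A minimal sketch (not part of this PR) of the central finite difference implemented above, assuming a MONAI build that includes these changes. For intensities varying as f(x) = x ** 2, the interior gradient is 2x and the bending energy is (d2f/dx2) ** 2 = 4, matching the updated test expectations further below.

import torch
from monai.losses.deform import BendingEnergyLoss, spatial_gradient

# intensities vary as x ** 2 along the last spatial dim only
pred = (torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2).float()

grad = spatial_gradient(pred, dim=4)  # central difference along the last dim
print(grad[0, 0, 0, 0])  # tensor([2., 4., 6.]), i.e. 2x at the interior points

loss = BendingEnergyLoss()(pred)  # mean of squared second-order derivatives
print(loss)  # tensor(4.)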
137 changes: 111 additions & 26 deletions monai/losses/image_dissimilarity.py
@@ -8,7 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Union
+from typing import Tuple, Union

import torch
from torch.nn import functional as F
@@ -63,7 +63,7 @@ def __init__(
self,
in_channels: int,
ndim: int = 3,
-        kernel_size: int = 9,
+        kernel_size: int = 3,
kernel_type: str = "rectangular",
reduction: Union[LossReduction, str] = LossReduction.MEAN,
smooth_nr: float = 1e-7,
@@ -100,40 +100,44 @@ def __init__(
f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].'
)
self.kernel = kernel_dict[kernel_type](self.kernel_size)
-        self.kernel_vol = torch.sum(self.kernel) ** self.ndim
+        self.kernel_vol = self.get_kernel_vol()

self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)

-    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    def get_kernel_vol(self):
+        vol = self.kernel
+        for _ in range(self.ndim - 1):
+            vol = torch.matmul(vol.unsqueeze(-1), self.kernel.unsqueeze(0))
+        return torch.sum(vol)
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
-            input: the shape should be BNH[WD].
+            pred: the shape should be BNH[WD].
target: the shape should be BNH[WD].
Raises:
ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
"""
-        assert (
-            input.shape[1] == self.in_channels
-        ), f"expecting input with {self.in_channels} channels, got input of shape {input.shape}"
-        assert (
-            input.ndim - 2 == self.ndim
-        ), f"expecting input with {self.ndim} spatial dimensions, got input of shape {input.shape}"
-        assert (
-            target.shape == input.shape
-        ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})"
-
-        t2, p2, tp = target ** 2, input ** 2, target * input
+        if pred.shape[1] != self.in_channels:
+            raise ValueError(f"expecting pred with {self.in_channels} channels, got pred of shape {pred.shape}")
+        if pred.ndim - 2 != self.ndim:
+            raise ValueError(f"expecting pred with {self.ndim} spatial dimensions, got pred of shape {pred.shape}")
+        if target.shape != pred.shape:
+            raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})")
+
+        t2, p2, tp = target ** 2, pred ** 2, target * pred
+        kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred)
# sum over kernel
-        t_sum = separable_filtering(target, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True)
-        p_sum = separable_filtering(input, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True)
-        t2_sum = separable_filtering(t2, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True)
-        p2_sum = separable_filtering(p2, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True)
-        tp_sum = separable_filtering(tp, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True)
+        t_sum = separable_filtering(target, kernels=[kernel] * self.ndim)
+        p_sum = separable_filtering(pred, kernels=[kernel] * self.ndim)
+        t2_sum = separable_filtering(t2, kernels=[kernel] * self.ndim)
+        p2_sum = separable_filtering(p2, kernels=[kernel] * self.ndim)
+        tp_sum = separable_filtering(tp, kernels=[kernel] * self.ndim)

# average over kernel
-        t_avg = t_sum / self.kernel_vol
-        p_avg = p_sum / self.kernel_vol
+        t_avg = t_sum / kernel_vol
+        p_avg = p_sum / kernel_vol

# normalized cross correlation between t and p
# sum[(t - mean[t]) * (p - mean[p])] / std[t] / std[p]
@@ -151,9 +155,90 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# shape = (batch, 1, D, H, W)

if self.reduction == LossReduction.SUM.value:
-            return torch.sum(ncc).neg()  # sum over the batch and spatial ndims
+            return torch.sum(ncc).neg()  # sum over the batch, channel and spatial ndims
if self.reduction == LossReduction.NONE.value:
return ncc.neg()
if self.reduction == LossReduction.MEAN.value:
-            return torch.mean(ncc).neg()  # average over the batch and spatial ndims
+            return torch.mean(ncc).neg()  # average over the batch, channel and spatial ndims
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')


+class GlobalMutualInformationLoss(_Loss):
+    """
+    Differentiable global mutual information loss via Parzen windowing method.
+
+    Reference:
+        https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1
+    """
+
+    def __init__(
+        self,
+        num_bins: int = 23,
+        sigma_ratio: float = 0.5,
+        reduction: Union[LossReduction, str] = LossReduction.MEAN,
+        smooth_nr: float = 1e-7,
+        smooth_dr: float = 1e-7,
+    ) -> None:
+        """
+        Args:
+            num_bins: number of bins for intensity
+            sigma_ratio: a hyper param for gaussian function
+            reduction: {``"none"``, ``"mean"``, ``"sum"``}
+                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+
+                - ``"none"``: no reduction will be applied.
+                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
+                - ``"sum"``: the output will be summed.
+            smooth_nr: a small constant added to the numerator to avoid nan.
+            smooth_dr: a small constant added to the denominator to avoid nan.
+        """
+        super(GlobalMutualInformationLoss, self).__init__(reduction=LossReduction(reduction).value)
+        if num_bins <= 0:
+            raise ValueError(f"num_bins must be > 0, got {num_bins}")
+        bin_centers = torch.linspace(0.0, 1.0, num_bins)  # (num_bins,)
+        sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio
+        self.preterm = 1 / (2 * sigma ** 2)
+        self.bin_centers = bin_centers[None, None, ...]
+        self.smooth_nr = float(smooth_nr)
+        self.smooth_dr = float(smooth_dr)
+
+    def parzen_windowing(self, pred: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            pred: the shape should be B[NDHW].
+        """
+        pred = torch.clamp(pred, 0, 1)
+        pred = pred.reshape(pred.shape[0], -1, 1)  # (batch, num_sample, 1)
+        weight = torch.exp(
+            -self.preterm.to(pred) * (pred - self.bin_centers.to(pred)) ** 2
+        )  # (batch, num_sample, num_bin)
+        weight = weight / torch.sum(weight, dim=-1, keepdim=True)  # (batch, num_sample, num_bin)
+        probability = torch.mean(weight, dim=-2, keepdim=True)  # (batch, 1, num_bin)
+        return weight, probability
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            pred: the shape should be B[NDHW].
+            target: the shape should be same as the pred shape.
+        Raises:
+            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
+        """
+        if target.shape != pred.shape:
+            raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})")
+        wa, pa = self.parzen_windowing(pred)  # (batch, num_sample, num_bin), (batch, 1, num_bin)
+        wb, pb = self.parzen_windowing(target)  # (batch, num_sample, num_bin), (batch, 1, num_bin)
+        pab = torch.bmm(wa.permute(0, 2, 1), wb).div(wa.shape[1])  # (batch, num_bins, num_bins)
+
+        papb = torch.bmm(pa.permute(0, 2, 1), pb)  # (batch, num_bins, num_bins)
+        mi = torch.sum(
+            pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2)
+        )  # (batch)
+
+        if self.reduction == LossReduction.SUM.value:
+            return torch.sum(mi).neg()  # sum over the batch and channel ndims
+        if self.reduction == LossReduction.NONE.value:
+            return mi.neg()
+        if self.reduction == LossReduction.MEAN.value:
+            return torch.mean(mi).neg()  # average over the batch and channel ndims
+        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
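A short usage sketch (not part of the diff) for the two image dissimilarity losses defined in this file, assuming a MONAI build that includes this PR; the tensor shapes and values are illustrative assumptions.

import torch
from monai.losses import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss

moving = torch.rand(2, 1, 32, 32, 32)  # e.g. a warped moving image, intensities in [0, 1]
fixed = torch.rand(2, 1, 32, 32, 32)   # the fixed (target) image

lncc = LocalNormalizedCrossCorrelationLoss(in_channels=1, ndim=3, kernel_size=3)
print(lncc(fixed, fixed))   # ~ -1.0: perfectly correlated within every local window
print(lncc(moving, fixed))  # ~ 0.0: unrelated noise has little local correlation

gmi = GlobalMutualInformationLoss(num_bins=23, sigma_ratio=0.5)
print(gmi(fixed, fixed))    # clearly negative: identical images share maximal mutual information
print(gmi(moving, fixed))   # ~ 0.0: independent intensities share little information

Both losses are negated similarity measures, so minimising them during registration drives the moving image towards the fixed one. Note that GlobalMutualInformationLoss clamps its inputs to [0, 1] before Parzen windowing, so intensities should be normalised to that range beforehand.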
26 changes: 13 additions & 13 deletions tests/test_bending_energy.py
@@ -20,27 +20,27 @@
TEST_CASES = [
[
{},
-        {"input": torch.ones((1, 3, 5, 5, 5))},
+        {"pred": torch.ones((1, 3, 5, 5, 5))},
0.0,
],
[
{},
-        {"input": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5)},
+        {"pred": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5)},
0.0,
],
[
{},
-        {"input": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
+        {"pred": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
4.0,
],
[
{},
-        {"input": torch.arange(0, 5)[None, None, None, :].expand(1, 3, 5, 5) ** 2},
+        {"pred": torch.arange(0, 5)[None, None, None, :].expand(1, 3, 5, 5) ** 2},
4.0,
],
[
{},
-        {"input": torch.arange(0, 5)[None, None, :].expand(1, 3, 5) ** 2},
+        {"pred": torch.arange(0, 5)[None, None, :].expand(1, 3, 5) ** 2},
4.0,
],
]
@@ -55,24 +55,24 @@ def test_shape(self, input_param, input_data, expected_val):
def test_ill_shape(self):
loss = BendingEnergyLoss()
# not in 3-d, 4-d, 5-d
-        with self.assertRaisesRegex(AssertionError, ""):
+        with self.assertRaisesRegex(ValueError, ""):
loss.forward(torch.ones((1, 3)))
-        with self.assertRaisesRegex(AssertionError, ""):
+        with self.assertRaisesRegex(ValueError, ""):
loss.forward(torch.ones((1, 3, 5, 5, 5, 5)))
# spatial_dim < 5
-        with self.assertRaisesRegex(AssertionError, ""):
+        with self.assertRaisesRegex(ValueError, ""):
loss.forward(torch.ones((1, 3, 4, 5, 5)))
-        with self.assertRaisesRegex(AssertionError, ""):
+        with self.assertRaisesRegex(ValueError, ""):
loss.forward(torch.ones((1, 3, 5, 4, 5)))
-        with self.assertRaisesRegex(AssertionError, ""):
+        with self.assertRaisesRegex(ValueError, ""):
loss.forward(torch.ones((1, 3, 5, 5, 4)))

def test_ill_opts(self):
-        input = torch.rand(1, 3, 5, 5, 5)
+        pred = torch.rand(1, 3, 5, 5, 5)
with self.assertRaisesRegex(ValueError, ""):
-            BendingEnergyLoss(reduction="unknown")(input)
+            BendingEnergyLoss(reduction="unknown")(pred)
with self.assertRaisesRegex(ValueError, ""):
-            BendingEnergyLoss(reduction=None)(input)
+            BendingEnergyLoss(reduction=None)(pred)


if __name__ == "__main__":