Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ compatible with the updates.

### Changed

- Increased all EPS to 1e-5.
- Clarify the suggestion in doc to use all-zero masks for missing labels.
- Moved contributor list to a separate page.
- Changed `no-test` flag to `full` for demo scripts.
Expand All @@ -38,7 +37,6 @@ compatible with the updates.

### Fixed

- Fixed LNCC loss regarding INF values.
- Removed loss weight checks to be more robust.
- Fixed import error under python 3.6.
- Fixed the residual module in local net architecture, compatible for previous
Expand Down
3 changes: 0 additions & 3 deletions deepreg/constant.py

This file was deleted.

6 changes: 3 additions & 3 deletions deepreg/dataset/loader/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,10 +379,10 @@ def validate_images_and_labels(
for arr, name in zip(
[moving_image, fixed_image], ["moving_image", "fixed_image"]
):
if len(arr.shape) != 3 or min(arr.shape) <= 0:
if len(arr.shape) != 3:
raise ValueError(
f"Sample {image_indices}'s {name}' shape should be 3D"
f" and non-empty, got {arr.shape}."
f"Sample {image_indices}'s {name}' shape should be 3D. "
f"Got {arr.shape}."
)
# when data are labeled
if moving_label is not None and fixed_label is not None:
Expand Down
73 changes: 41 additions & 32 deletions deepreg/loss/image.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Provide different loss or metrics classes for images."""
import tensorflow as tf

from deepreg.constant import EPS
from deepreg.loss.util import NegativeLossMixin
from deepreg.loss.util import gaussian_kernel1d_size as gaussian_kernel1d
from deepreg.loss.util import (
Expand All @@ -11,6 +10,8 @@
)
from deepreg.registry import REGISTRY

EPS = tf.keras.backend.epsilon()


@REGISTRY.register_loss(name="ssd")
class SumSquaredDifference(tf.keras.losses.Loss):
Expand Down Expand Up @@ -155,33 +156,27 @@ class LocalNormalizedCrossCorrelation(tf.keras.losses.Loss):

E[t] = sum_i(w_i * t_i) / sum_i(w_i)

Here, we assume sum_i(w_i) == 1, means the weights have been normalized.

Similarly, the discrete variance in the window V[t] is

V[t] = E[(t - E[t])**2]
V[t] = E[t**2] - E[t] ** 2

The local squared zero-normalized cross-correlation is therefore

E[ (t-E[t]) * (p-E[p]) ] ** 2 / V[t] / V[p]

When calculating variance, we choose to subtract the mean first then calculte
variance instead of calculating E[t**2] - E[t] ** 2, the reason is that when
E[t**2] and E[t] ** 2 are both very large or very small, the subtraction may
have large rounding error and makes the result inaccurate. Also, it is not
guaranteed that the result >= 0. For more details, please read "Algorithms for
computing the sample variance: Analysis and recommendations." page 1.
where the expectation in numerator is

E[ (t-E[t]) * (p-E[p]) ] = E[t * p] - E[t] * E[p]

Different kernel corresponds to different weights.

For now, y_true and y_pred have to be at least 4d tensor, including batch axis.

Reference:

- Zero-normalized cross-correlation (ZNCC):
https://en.wikipedia.org/wiki/Cross-correlation
- https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights
- Chan, Tony F., Gene H. Golub, and Randall J. LeVeque.
"Algorithms for computing the sample variance: Analysis and recommendations."
The American Statistician 37.3 (1983): 242-247.
- Code: https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
"""

kernel_fn_dict = dict(
Expand Down Expand Up @@ -217,8 +212,13 @@ def __init__(
self.kernel_size = kernel_size

# (kernel_size, )
# sum of the kernel weights would be one
self.kernel = self.kernel_fn(kernel_size=self.kernel_size)
# E[1] = sum_i(w_i), ()
self.kernel_vol = tf.reduce_sum(
self.kernel[:, None, None]
* self.kernel[None, :, None]
* self.kernel[None, None, :]
)

def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
"""
Expand All @@ -230,29 +230,38 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
or (batch, dim1, dim2, dim3, ch)
:return: shape = (batch,)
"""
# adjust shape to be (batch, dim1, dim2, dim3, ch)
# adjust
if len(y_true.shape) == 4:
y_true = tf.expand_dims(y_true, axis=4)
y_pred = tf.expand_dims(y_pred, axis=4)
assert len(y_true.shape) == len(y_pred.shape) == 5

# t = y_true, p = y_pred
t_mean = separable_filter(y_true, kernel=self.kernel)
p_mean = separable_filter(y_pred, kernel=self.kernel)

t = y_true - t_mean
p = y_pred - p_mean

# the variance can be biased but as both num and denom are biased
# it got cancelled
# https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights
cross = separable_filter(t * p, kernel=self.kernel)
t_var = separable_filter(t * t, kernel=self.kernel)
p_var = separable_filter(p * p, kernel=self.kernel)

num = cross * cross
denom = t_var * p_var
ncc = (num + EPS) / (denom + EPS)
# (batch, dim1, dim2, dim3, ch)
t2 = y_true * y_true
p2 = y_pred * y_pred
tp = y_true * y_pred

# sum over kernel
# (batch, dim1, dim2, dim3, 1)
t_sum = separable_filter(y_true, kernel=self.kernel) # E[t] * E[1]
p_sum = separable_filter(y_pred, kernel=self.kernel) # E[p] * E[1]
t2_sum = separable_filter(t2, kernel=self.kernel) # E[tt] * E[1]
p2_sum = separable_filter(p2, kernel=self.kernel) # E[pp] * E[1]
tp_sum = separable_filter(tp, kernel=self.kernel) # E[tp] * E[1]

# average over kernel
# (batch, dim1, dim2, dim3, 1)
t_avg = t_sum / self.kernel_vol # E[t]
p_avg = p_sum / self.kernel_vol # E[p]

# shape = (batch, dim1, dim2, dim3, 1)
cross = tp_sum - p_avg * t_sum # E[tp] * E[1] - E[p] * E[t] * E[1]
t_var = t2_sum - t_avg * t_sum # V[t] * E[1]
p_var = p2_sum - p_avg * p_sum # V[p] * E[1]

# (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
ncc = (cross * cross + EPS) / (t_var * p_var + EPS)

return tf.reduce_mean(ncc, axis=[1, 2, 3, 4])

Expand Down
3 changes: 2 additions & 1 deletion deepreg/loss/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

import tensorflow as tf

from deepreg.constant import EPS
from deepreg.loss.util import NegativeLossMixin, cauchy_kernel1d
from deepreg.loss.util import gaussian_kernel1d_sigma as gaussian_kernel1d
from deepreg.loss.util import separable_filter
from deepreg.registry import REGISTRY

EPS = tf.keras.backend.epsilon()


class MultiScaleLoss(tf.keras.losses.Loss):
"""
Expand Down
23 changes: 9 additions & 14 deletions deepreg/loss/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,25 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
return -super().call(y_true=y_true, y_pred=y_pred)


EPS = tf.keras.backend.epsilon()


def rectangular_kernel1d(kernel_size: int) -> tf.Tensor:
"""
Return a the 1D rectangular kernel for LocalNormalizedCrossCorrelation.

Sum of the weights is 1.
Return a the 1D filter for separable convolution equivalent to a 3-D rectangular
kernel for LocalNormalizedCrossCorrelation.

:param kernel_size: scalar, size of the 1-D kernel
:return: kernel_weights, of shape (kernel_size, )
"""

kernel = tf.ones(shape=(kernel_size,), dtype=tf.float32) / float(kernel_size)
kernel = tf.ones(shape=(kernel_size,), dtype=tf.float32)
return kernel


def triangular_kernel1d(kernel_size: int) -> tf.Tensor:
"""
Return a the 1D triangular kernel for LocalNormalizedCrossCorrelation.

Sum of the weights is 1.
1D triangular kernel.

Assume kernel_size is odd, it will be a smoothed from
a kernel which center part is zero.
Expand Down Expand Up @@ -73,17 +73,13 @@ def triangular_kernel1d(kernel_size: int) -> tf.Tensor:
kernel = tf.nn.conv1d(
kernel[None, :, None], filters=filters, stride=[1, 1, 1], padding="SAME"
)
kernel = kernel / tf.reduce_sum(kernel)

return kernel[0, :, 0]


def gaussian_kernel1d_size(kernel_size: int) -> tf.Tensor:
"""
Return a the 1D Gaussian kernel for LocalNormalizedCrossCorrelation.

Sum of the weights is 1.

Return a the 1D filter for separable convolution equivalent to a 3-D Gaussian
kernel for LocalNormalizedCrossCorrelation.
:param kernel_size: scalar, size of the 1-D kernel
:return: filters, of shape (kernel_size, )
"""
Expand All @@ -92,7 +88,6 @@ def gaussian_kernel1d_size(kernel_size: int) -> tf.Tensor:

grid = tf.range(0, kernel_size, dtype=tf.float32)
filters = tf.exp(-tf.square(grid - mean) / (2 * sigma ** 2))
filters = filters / tf.reduce_sum(filters)

return filters

Expand Down
6 changes: 3 additions & 3 deletions deepreg/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class RegistrationModel(tf.keras.Model):

def __init__(
self,
moving_image_size: Tuple,
fixed_image_size: Tuple,
moving_image_size: tuple,
fixed_image_size: tuple,
index_size: int,
labeled: bool,
batch_size: int,
Expand Down Expand Up @@ -61,7 +61,6 @@ def __init__(
self.config = config
self.num_devices = num_devices
self.global_batch_size = num_devices * batch_size
assert self.global_batch_size > 0

self._inputs = None # save inputs of self._model as dict
self._outputs = None # save outputs of self._model as dict
Expand Down Expand Up @@ -223,6 +222,7 @@ def _build_loss(self, name: str, inputs_dict: dict):

# add loss
self._model.add_loss(weighted_loss)

# add metric
self._model.add_metric(
loss_value, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"
Expand Down
2 changes: 1 addition & 1 deletion test/unit/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def mock_sample_index_generator():
fixed_label=None,
image_indices=[1],
)
assert "Sample [1]'s moving_image' shape should be 3D" in str(err_info.value)
assert "Sample [1]'s moving_image' shape should be 3D. " in str(err_info.value)
with pytest.raises(ValueError) as err_info:
generator.validate_images_and_labels(
fixed_image=dummy_array,
Expand Down
13 changes: 6 additions & 7 deletions test/unit/test_loss_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import tensorflow as tf

import deepreg.loss.label as label
from deepreg.constant import EPS


class TestMultiScaleLoss:
Expand Down Expand Up @@ -92,11 +91,11 @@ def y_pred(self):
@pytest.mark.parametrize(
"binary,background_weight,scales,expected",
[
(True, 0.0, None, -np.log(EPS)),
(False, 0.0, None, -0.6 * np.log(0.3 + EPS)),
(False, 0.2, None, -0.48 * np.log(0.3 + EPS) - 0.08 * np.log(0.7 + EPS)),
(False, 0.2, [0, 0], -0.48 * np.log(0.3 + EPS) - 0.08 * np.log(0.7 + EPS)),
(False, 0.2, [0, 1], 0.5239465),
(True, 0.0, None, -np.log(1.0e-7)),
(False, 0.0, None, -0.6 * np.log(0.3)),
(False, 0.2, None, -0.48 * np.log(0.3) - 0.08 * np.log(0.7)),
(False, 0.2, [0, 0], -0.48 * np.log(0.3) - 0.08 * np.log(0.7)),
(False, 0.2, [0, 1], 0.5239637),
],
)
def test_call(self, y_true, y_pred, binary, background_weight, scales, expected):
Expand Down Expand Up @@ -136,7 +135,7 @@ def y_pred(self):
(True, None, 0),
(False, None, 0.25),
(False, [0, 0], 0.25),
(False, [0, 1], 0.17485845),
(False, [0, 1], 0.17484076),
],
)
def test_call(self, y_true, y_pred, binary, scales, expected):
Expand Down
3 changes: 0 additions & 3 deletions test/unit/test_loss_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def test_gaussian_kernel1d_size(kernel_size):

grid = tf.range(0, kernel_size, dtype=tf.float32)
expected = tf.exp(-tf.square(grid - mean) / (2 * sigma ** 2))
expected = expected / tf.reduce_sum(expected)

got = gaussian_kernel1d_size(kernel_size)
assert is_equal_tf(got, expected)
Expand All @@ -76,7 +75,6 @@ def test_rectangular_kernel1d(kernel_size):
:return:
"""
expected = tf.ones(shape=(kernel_size,), dtype=tf.float32)
expected = expected / tf.reduce_sum(expected)
got = rectangular_kernel1d(kernel_size)
assert is_equal_tf(got, expected)

Expand All @@ -93,7 +91,6 @@ def test_triangular_kernel1d(kernel_size):
for it_k in range(kernel_size // 2):
expected[it_k] = it_k + 1
expected[-it_k - 1] = it_k + 1
expected = expected / tf.reduce_sum(expected)

got = triangular_kernel1d(kernel_size)
assert is_equal_tf(got, expected)
Expand Down
6 changes: 2 additions & 4 deletions test/unit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import numpy as np
import tensorflow as tf

from deepreg.constant import EPS


def is_equal_np(
x: Union[np.ndarray, List], y: Union[np.ndarray, List], atol: float = EPS
x: Union[np.ndarray, List], y: Union[np.ndarray, List], atol: float = 1.0e-7
) -> bool:
"""
Check if two numpy arrays are identical.
Expand All @@ -25,7 +23,7 @@ def is_equal_np(
def is_equal_tf(
x: Union[tf.Tensor, np.ndarray, List],
y: Union[tf.Tensor, np.ndarray, List],
atol: float = EPS,
atol: float = 1.0e-7,
) -> bool:
"""
Check if two tf tensors are identical.
Expand Down