Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
c106bd3
Issue #690: add debug prints
mathpluscode Mar 7, 2021
fac630b
Issue #609: hardcode update_freq for tensorboard
mathpluscode Mar 22, 2021
a667dd6
do not use zero boundary
mathpluscode Mar 22, 2021
80e2ac8
Merge branch '708-log-more-metrics-in-tensorboard' into 690-nan-inf-loss
mathpluscode Mar 23, 2021
e0f73c7
Issue #690: modify lncc implementation
mathpluscode Mar 23, 2021
b141b3c
Issue #690: add EPS to numerator for LNCC
mathpluscode Mar 23, 2021
952ae52
Issue #690: add debug print
mathpluscode Mar 23, 2021
9f997d1
Issue #690: use tf.debugging.enable_check_numerics
mathpluscode Mar 23, 2021
8ff1898
Issue #690: attempt to skip some ops
mathpluscode Mar 24, 2021
1330286
Issue #690: check input image shapes
mathpluscode Mar 24, 2021
ef40478
Merge branch '708-log-more-metrics-in-tensorboard' into 690-nan-inf-loss
mathpluscode Mar 24, 2021
e2354ff
Issue #690: add debug check
mathpluscode Mar 25, 2021
f31f532
Issue #690: fix op name
mathpluscode Mar 25, 2021
0ad5f56
Issue #690: add check into graph
mathpluscode Mar 25, 2021
40be7c3
Issue #690: add check into lncc graph
mathpluscode Mar 25, 2021
e034571
Issue #690: add more checks into lncc
mathpluscode Mar 25, 2021
eebfbdd
Issue #690: assert denom >= 0
mathpluscode Mar 25, 2021
4dc9a08
Issue #690: assert based on nan
mathpluscode Mar 25, 2021
b2f6338
Issue #690: do proper debugging
mathpluscode Mar 25, 2021
e0219d8
Issue #690: fix typo
mathpluscode Mar 25, 2021
90aa434
Issue #690: add assert into graph
mathpluscode Mar 25, 2021
6c27061
Issue #690: revert change to LNCC and increase EPS
mathpluscode Mar 26, 2021
de96cad
Issue #690: add additional metrics
mathpluscode Mar 26, 2021
64e6561
Issue #690: remove arg
mathpluscode Mar 26, 2021
15631da
Issue #690: fit eagerly
mathpluscode Mar 26, 2021
1cbf23c
Issue #690: print min max and save array
mathpluscode Mar 26, 2021
5ac34c1
Issue #690: catch err
mathpluscode Mar 26, 2021
c40cb34
add print into grapj
mathpluscode Mar 26, 2021
06b76b0
Issue #690: hack ncc
mathpluscode Mar 26, 2021
259bbd9
Issue #690: do not compile eagerly
mathpluscode Mar 26, 2021
9e7203b
Issue #690: add metrics
mathpluscode Mar 26, 2021
b764112
Issue #690: fix typo
mathpluscode Mar 26, 2021
35beea7
Issue #690: fix typo
mathpluscode Mar 26, 2021
4c47097
Issue #690: do not use separable filter
mathpluscode Mar 26, 2021
357837a
Issue #690: correct lncc
mathpluscode Mar 26, 2021
23645a0
Issue #690: fix bug
mathpluscode Mar 26, 2021
01ba18a
Issue #690: divide filters
mathpluscode Mar 26, 2021
d1c63bf
Issue #690: simplify code
mathpluscode Mar 26, 2021
fc98187
Issue #690: modify filters
mathpluscode Mar 26, 2021
d696cb4
Issue #690: change back to 1d kernel
mathpluscode Mar 26, 2021
e910bda
Issue #690: clip variance and remove debug code
mathpluscode Mar 26, 2021
8de5810
Merge remote-tracking branch 'origin/main' into 690-nan-inf-loss
mathpluscode Mar 26, 2021
3a09c82
Issue #690: fix tests and remove debug changes
mathpluscode Mar 26, 2021
c8d62ae
Issue #690: change code so that no need to add test
mathpluscode Mar 26, 2021
493300b
Issue #690: fix pylint
mathpluscode Mar 26, 2021
c129e58
Issue #690: update changelog and fix test
mathpluscode Mar 26, 2021
2b7c72b
Merge branch 'main' into 690-nan-inf-loss
YipengHu Mar 27, 2021
adc2304
Merge branch 'main' into 690-nan-inf-loss
YipengHu Mar 27, 2021
c226933
Issue #690: modify LNCC implementation to be more stable
mathpluscode Mar 27, 2021
2756bff
Merge branch 'main' into 690-nan-inf-loss
mathpluscode Mar 27, 2021
ac3a46d
Issue #690: add reference for LNCC
mathpluscode Mar 27, 2021
27b1936
Merge remote-tracking branch 'origin/690-nan-inf-loss' into 690-nan-i…
mathpluscode Mar 27, 2021
6ad3312
Merge branch 'main' into 690-nan-inf-loss
mathpluscode Mar 28, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ compatible with the updates.

### Changed

- Increased all EPS to 1e-5.
- Clarify the suggestion in doc to use all-zero masks for missing labels.
- Moved contributor list to a separate page.
- Changed `no-test` flag to `full` for demo scripts.
Expand All @@ -37,6 +38,7 @@ compatible with the updates.

### Fixed

- Fixed LNCC loss regarding INF values.
- Removed loss weight checks to be more robust.
- Fixed import error under python 3.6.
- Fixed the residual module in local net architecture, compatible for previous
Expand Down
3 changes: 3 additions & 0 deletions deepreg/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Module defining global constants."""

EPS = 1.0e-5
6 changes: 3 additions & 3 deletions deepreg/dataset/loader/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,10 +379,10 @@ def validate_images_and_labels(
for arr, name in zip(
[moving_image, fixed_image], ["moving_image", "fixed_image"]
):
if len(arr.shape) != 3:
if len(arr.shape) != 3 or min(arr.shape) <= 0:
raise ValueError(
f"Sample {image_indices}'s {name}' shape should be 3D. "
f"Got {arr.shape}."
f"Sample {image_indices}'s {name}' shape should be 3D"
f" and non-empty, got {arr.shape}."
)
# when data are labeled
if moving_label is not None and fixed_label is not None:
Expand Down
73 changes: 32 additions & 41 deletions deepreg/loss/image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Provide different loss or metrics classes for images."""
import tensorflow as tf

from deepreg.constant import EPS
from deepreg.loss.util import NegativeLossMixin
from deepreg.loss.util import gaussian_kernel1d_size as gaussian_kernel1d
from deepreg.loss.util import (
Expand All @@ -10,8 +11,6 @@
)
from deepreg.registry import REGISTRY

EPS = tf.keras.backend.epsilon()


@REGISTRY.register_loss(name="ssd")
class SumSquaredDifference(tf.keras.losses.Loss):
Expand Down Expand Up @@ -156,27 +155,33 @@ class LocalNormalizedCrossCorrelation(tf.keras.losses.Loss):

E[t] = sum_i(w_i * t_i) / sum_i(w_i)

Here, we assume sum_i(w_i) == 1, i.e. the weights have been normalized.

Similarly, the discrete variance in the window V[t] is

V[t] = E[t**2] - E[t] ** 2
V[t] = E[(t - E[t])**2]

The local squared zero-normalized cross-correlation is therefore

E[ (t-E[t]) * (p-E[p]) ] ** 2 / V[t] / V[p]

where the expectation in numerator is

E[ (t-E[t]) * (p-E[p]) ] = E[t * p] - E[t] * E[p]

Different kernel corresponds to different weights.
When calculating the variance, we choose to subtract the mean first and then
calculate the variance, instead of calculating E[t**2] - E[t] ** 2. The reason
is that when E[t**2] and E[t] ** 2 are both very large or very small, the
subtraction may have a large rounding error, making the result inaccurate.
Also, the result is not guaranteed to be >= 0. For more details, please read
"Algorithms for computing the sample variance: Analysis and recommendations." page 1.

For now, y_true and y_pred have to be at least 4d tensor, including batch axis.

Reference:

- Zero-normalized cross-correlation (ZNCC):
https://en.wikipedia.org/wiki/Cross-correlation
- Code: https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
- https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights
- Chan, Tony F., Gene H. Golub, and Randall J. LeVeque.
"Algorithms for computing the sample variance: Analysis and recommendations."
The American Statistician 37.3 (1983): 242-247.
"""

kernel_fn_dict = dict(
Expand Down Expand Up @@ -212,13 +217,8 @@ def __init__(
self.kernel_size = kernel_size

# (kernel_size, )
# sum of the kernel weights would be one
self.kernel = self.kernel_fn(kernel_size=self.kernel_size)
# E[1] = sum_i(w_i), ()
self.kernel_vol = tf.reduce_sum(
self.kernel[:, None, None]
* self.kernel[None, :, None]
* self.kernel[None, None, :]
)

def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
"""
Expand All @@ -230,38 +230,29 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
or (batch, dim1, dim2, dim3, ch)
:return: shape = (batch,)
"""
# adjust
# adjust shape to be (batch, dim1, dim2, dim3, ch)
if len(y_true.shape) == 4:
y_true = tf.expand_dims(y_true, axis=4)
y_pred = tf.expand_dims(y_pred, axis=4)
assert len(y_true.shape) == len(y_pred.shape) == 5

# t = y_true, p = y_pred
# (batch, dim1, dim2, dim3, ch)
t2 = y_true * y_true
p2 = y_pred * y_pred
tp = y_true * y_pred

# sum over kernel
# (batch, dim1, dim2, dim3, 1)
t_sum = separable_filter(y_true, kernel=self.kernel) # E[t] * E[1]
p_sum = separable_filter(y_pred, kernel=self.kernel) # E[p] * E[1]
t2_sum = separable_filter(t2, kernel=self.kernel) # E[tt] * E[1]
p2_sum = separable_filter(p2, kernel=self.kernel) # E[pp] * E[1]
tp_sum = separable_filter(tp, kernel=self.kernel) # E[tp] * E[1]

# average over kernel
# (batch, dim1, dim2, dim3, 1)
t_avg = t_sum / self.kernel_vol # E[t]
p_avg = p_sum / self.kernel_vol # E[p]

# shape = (batch, dim1, dim2, dim3, 1)
cross = tp_sum - p_avg * t_sum # E[tp] * E[1] - E[p] * E[t] * E[1]
t_var = t2_sum - t_avg * t_sum # V[t] * E[1]
p_var = p2_sum - p_avg * p_sum # V[p] * E[1]

# (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
ncc = (cross * cross + EPS) / (t_var * p_var + EPS)
t_mean = separable_filter(y_true, kernel=self.kernel)
p_mean = separable_filter(y_pred, kernel=self.kernel)

t = y_true - t_mean
p = y_pred - p_mean

# the variance estimate can be biased, but as both num and denom are biased
# the bias cancels out in the ratio
# https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights
cross = separable_filter(t * p, kernel=self.kernel)
t_var = separable_filter(t * t, kernel=self.kernel)
p_var = separable_filter(p * p, kernel=self.kernel)

num = cross * cross
denom = t_var * p_var
ncc = (num + EPS) / (denom + EPS)

return tf.reduce_mean(ncc, axis=[1, 2, 3, 4])

Expand Down
3 changes: 1 addition & 2 deletions deepreg/loss/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@

import tensorflow as tf

from deepreg.constant import EPS
from deepreg.loss.util import NegativeLossMixin, cauchy_kernel1d
from deepreg.loss.util import gaussian_kernel1d_sigma as gaussian_kernel1d
from deepreg.loss.util import separable_filter
from deepreg.registry import REGISTRY

EPS = tf.keras.backend.epsilon()


class MultiScaleLoss(tf.keras.losses.Loss):
"""
Expand Down
23 changes: 14 additions & 9 deletions deepreg/loss/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,25 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
return -super().call(y_true=y_true, y_pred=y_pred)


EPS = tf.keras.backend.epsilon()


def rectangular_kernel1d(kernel_size: int) -> tf.Tensor:
"""
Return a the 1D filter for separable convolution equivalent to a 3-D rectangular
kernel for LocalNormalizedCrossCorrelation.
Return the 1D rectangular kernel for LocalNormalizedCrossCorrelation.

Sum of the weights is 1.

:param kernel_size: scalar, size of the 1-D kernel
:return: kernel_weights, of shape (kernel_size, )
"""

kernel = tf.ones(shape=(kernel_size,), dtype=tf.float32)
kernel = tf.ones(shape=(kernel_size,), dtype=tf.float32) / float(kernel_size)
return kernel


def triangular_kernel1d(kernel_size: int) -> tf.Tensor:
"""
1D triangular kernel.
Return the 1D triangular kernel for LocalNormalizedCrossCorrelation.

Sum of the weights is 1.

Assume kernel_size is odd; the kernel is obtained by smoothing
a kernel whose center part is zero.
Expand Down Expand Up @@ -73,13 +73,17 @@ def triangular_kernel1d(kernel_size: int) -> tf.Tensor:
kernel = tf.nn.conv1d(
kernel[None, :, None], filters=filters, stride=[1, 1, 1], padding="SAME"
)
kernel = kernel / tf.reduce_sum(kernel)

return kernel[0, :, 0]


def gaussian_kernel1d_size(kernel_size: int) -> tf.Tensor:
"""
Return a the 1D filter for separable convolution equivalent to a 3-D Gaussian
kernel for LocalNormalizedCrossCorrelation.
Return the 1D Gaussian kernel for LocalNormalizedCrossCorrelation.

Sum of the weights is 1.

:param kernel_size: scalar, size of the 1-D kernel
:return: filters, of shape (kernel_size, )
"""
Expand All @@ -88,6 +92,7 @@ def gaussian_kernel1d_size(kernel_size: int) -> tf.Tensor:

grid = tf.range(0, kernel_size, dtype=tf.float32)
filters = tf.exp(-tf.square(grid - mean) / (2 * sigma ** 2))
filters = filters / tf.reduce_sum(filters)

return filters

Expand Down
6 changes: 3 additions & 3 deletions deepreg/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class RegistrationModel(tf.keras.Model):

def __init__(
self,
moving_image_size: tuple,
fixed_image_size: tuple,
moving_image_size: Tuple,
fixed_image_size: Tuple,
index_size: int,
labeled: bool,
batch_size: int,
Expand Down Expand Up @@ -61,6 +61,7 @@ def __init__(
self.config = config
self.num_devices = num_devices
self.global_batch_size = num_devices * batch_size
assert self.global_batch_size > 0

self._inputs = None # save inputs of self._model as dict
self._outputs = None # save outputs of self._model as dict
Expand Down Expand Up @@ -222,7 +223,6 @@ def _build_loss(self, name: str, inputs_dict: dict):

# add loss
self._model.add_loss(weighted_loss)

# add metric
self._model.add_metric(
loss_value, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"
Expand Down
2 changes: 1 addition & 1 deletion test/unit/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def mock_sample_index_generator():
fixed_label=None,
image_indices=[1],
)
assert "Sample [1]'s moving_image' shape should be 3D. " in str(err_info.value)
assert "Sample [1]'s moving_image' shape should be 3D" in str(err_info.value)
with pytest.raises(ValueError) as err_info:
generator.validate_images_and_labels(
fixed_image=dummy_array,
Expand Down
13 changes: 7 additions & 6 deletions test/unit/test_loss_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import tensorflow as tf

import deepreg.loss.label as label
from deepreg.constant import EPS


class TestMultiScaleLoss:
Expand Down Expand Up @@ -91,11 +92,11 @@ def y_pred(self):
@pytest.mark.parametrize(
"binary,background_weight,scales,expected",
[
(True, 0.0, None, -np.log(1.0e-7)),
(False, 0.0, None, -0.6 * np.log(0.3)),
(False, 0.2, None, -0.48 * np.log(0.3) - 0.08 * np.log(0.7)),
(False, 0.2, [0, 0], -0.48 * np.log(0.3) - 0.08 * np.log(0.7)),
(False, 0.2, [0, 1], 0.5239637),
(True, 0.0, None, -np.log(EPS)),
(False, 0.0, None, -0.6 * np.log(0.3 + EPS)),
(False, 0.2, None, -0.48 * np.log(0.3 + EPS) - 0.08 * np.log(0.7 + EPS)),
(False, 0.2, [0, 0], -0.48 * np.log(0.3 + EPS) - 0.08 * np.log(0.7 + EPS)),
(False, 0.2, [0, 1], 0.5239465),
],
)
def test_call(self, y_true, y_pred, binary, background_weight, scales, expected):
Expand Down Expand Up @@ -135,7 +136,7 @@ def y_pred(self):
(True, None, 0),
(False, None, 0.25),
(False, [0, 0], 0.25),
(False, [0, 1], 0.17484076),
(False, [0, 1], 0.17485845),
],
)
def test_call(self, y_true, y_pred, binary, scales, expected):
Expand Down
3 changes: 3 additions & 0 deletions test/unit/test_loss_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def test_gaussian_kernel1d_size(kernel_size):

grid = tf.range(0, kernel_size, dtype=tf.float32)
expected = tf.exp(-tf.square(grid - mean) / (2 * sigma ** 2))
expected = expected / tf.reduce_sum(expected)

got = gaussian_kernel1d_size(kernel_size)
assert is_equal_tf(got, expected)
Expand All @@ -75,6 +76,7 @@ def test_rectangular_kernel1d(kernel_size):
:return:
"""
expected = tf.ones(shape=(kernel_size,), dtype=tf.float32)
expected = expected / tf.reduce_sum(expected)
got = rectangular_kernel1d(kernel_size)
assert is_equal_tf(got, expected)

Expand All @@ -91,6 +93,7 @@ def test_triangular_kernel1d(kernel_size):
for it_k in range(kernel_size // 2):
expected[it_k] = it_k + 1
expected[-it_k - 1] = it_k + 1
expected = expected / tf.reduce_sum(expected)

got = triangular_kernel1d(kernel_size)
assert is_equal_tf(got, expected)
Expand Down
6 changes: 4 additions & 2 deletions test/unit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import numpy as np
import tensorflow as tf

from deepreg.constant import EPS


def is_equal_np(
x: Union[np.ndarray, List], y: Union[np.ndarray, List], atol: float = 1.0e-7
x: Union[np.ndarray, List], y: Union[np.ndarray, List], atol: float = EPS
) -> bool:
"""
Check if two numpy arrays are identical.
Expand All @@ -23,7 +25,7 @@ def is_equal_np(
def is_equal_tf(
x: Union[tf.Tensor, np.ndarray, List],
y: Union[tf.Tensor, np.ndarray, List],
atol: float = 1.0e-7,
atol: float = EPS,
) -> bool:
"""
Check if two tf tensors are identical.
Expand Down