From 3cdd179578cc4e1bad4ad620db28c4b514d0a816 Mon Sep 17 00:00:00 2001
From: Yunguan Fu
Date: Thu, 1 Apr 2021 01:27:41 +0100
Subject: [PATCH 1/4] Issue #690: fix negative variance issue and add accurate test

---
 CHANGELOG.md                        |  2 ++
 deepreg/constant.py                 |  3 +++
 deepreg/dataset/loader/interface.py |  6 +++---
 deepreg/loss/image.py               | 31 +++++++++++++++++++++--------
 deepreg/loss/label.py               |  3 +--
 deepreg/loss/util.py                | 13 +++++-------
 deepreg/model/network.py            |  4 ++--
 test/unit/test_interface.py         |  2 +-
 test/unit/test_loss_image.py        | 31 +++++++++++++++++++++++++++++
 test/unit/test_loss_label.py        | 13 ++++++------
 test/unit/util.py                   |  6 ++++--
 11 files changed, 82 insertions(+), 32 deletions(-)
 create mode 100644 deepreg/constant.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d0089f17f..720ca54f1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,7 @@ compatible with the updates.
 
 ### Changed
 
+- Increased all EPS to 1e-5.
 - Clarify the suggestion in doc to use all-zero masks for missing labels.
 - Moved contributor list to a separate page.
 - Changed `no-test` flag to `full` for demo scripts.
@@ -37,6 +38,7 @@ compatible with the updates.
 
 ### Fixed
 
+- Fixed LNCC loss regarding INF values.
 - Removed loss weight checks to be more robust.
 - Fixed import error under python 3.6.
 - Fixed the residual module in local net architecture, compatible for previous
diff --git a/deepreg/constant.py b/deepreg/constant.py
new file mode 100644
index 000000000..72746ee3b
--- /dev/null
+++ b/deepreg/constant.py
@@ -0,0 +1,3 @@
+"""Module defining global constants."""
+
+EPS = 1.0e-5
diff --git a/deepreg/dataset/loader/interface.py b/deepreg/dataset/loader/interface.py
index 1c89d5e53..997448195 100644
--- a/deepreg/dataset/loader/interface.py
+++ b/deepreg/dataset/loader/interface.py
@@ -379,10 +379,10 @@ def validate_images_and_labels(
         for arr, name in zip(
             [moving_image, fixed_image], ["moving_image", "fixed_image"]
         ):
-            if len(arr.shape) != 3:
+            if len(arr.shape) != 3 or min(arr.shape) <= 0:
                 raise ValueError(
-                    f"Sample {image_indices}'s {name}' shape should be 3D. "
-                    f"Got {arr.shape}."
+                    f"Sample {image_indices}'s {name} shape should be 3D"
+                    f" and non-empty, got {arr.shape}."
                 )
         # when data are labeled
         if moving_label is not None and fixed_label is not None:
diff --git a/deepreg/loss/image.py b/deepreg/loss/image.py
index 44510bdde..aa202e4d6 100644
--- a/deepreg/loss/image.py
+++ b/deepreg/loss/image.py
@@ -1,6 +1,7 @@
 """Provide different loss or metrics classes for images."""
 import tensorflow as tf
 
+from deepreg.constant import EPS
 from deepreg.loss.util import NegativeLossMixin
 from deepreg.loss.util import gaussian_kernel1d_size as gaussian_kernel1d
 from deepreg.loss.util import (
@@ -10,8 +11,6 @@
 )
 from deepreg.registry import REGISTRY
 
-EPS = tf.keras.backend.epsilon()
-
 
 @REGISTRY.register_loss(name="ssd")
 class SumSquaredDifference(tf.keras.losses.Loss):
@@ -220,21 +219,20 @@ def __init__(
             * self.kernel[None, None, :]
         )
 
-    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
+    def calc_ncc(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
         """
-        Return loss for a batch.
+        Return NCC for a batch.
 
         :param y_true: shape = (batch, dim1, dim2, dim3)
-            or (batch, dim1, dim2, dim3, ch)
+            or (batch, dim1, dim2, dim3, 1)
         :param y_pred: shape = (batch, dim1, dim2, dim3)
-            or (batch, dim1, dim2, dim3, ch)
-        :return: shape = (batch,)
+            or (batch, dim1, dim2, dim3, 1)
+        :return: shape = (batch, dim1, dim2, dim3, 1)
         """
         # adjust
         if len(y_true.shape) == 4:
             y_true = tf.expand_dims(y_true, axis=4)
             y_pred = tf.expand_dims(y_pred, axis=4)
-        assert len(y_true.shape) == len(y_pred.shape) == 5
 
         # t = y_true, p = y_pred
         # (batch, dim1, dim2, dim3, ch)
@@ -260,9 +258,26 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
         t_var = t2_sum - t_avg * t_sum  # V[t] * E[1]
         p_var = p2_sum - p_avg * p_sum  # V[p] * E[1]
 
+        # ensure variance >= 0
+        t_var = tf.maximum(t_var, 0)
+        p_var = tf.maximum(p_var, 0)
+
         # (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
         ncc = (cross * cross + EPS) / (t_var * p_var + EPS)
 
+        return ncc
+
+    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
+        """
+        Return loss for a batch.
+
+        :param y_true: shape = (batch, dim1, dim2, dim3)
+            or (batch, dim1, dim2, dim3, ch)
+        :param y_pred: shape = (batch, dim1, dim2, dim3)
+            or (batch, dim1, dim2, dim3, ch)
+        :return: shape = (batch,)
+        """
+        ncc = self.calc_ncc(y_true=y_true, y_pred=y_pred)
         return tf.reduce_mean(ncc, axis=[1, 2, 3, 4])
 
     def get_config(self) -> dict:
diff --git a/deepreg/loss/label.py b/deepreg/loss/label.py
index 0e0d94bad..37429f51a 100644
--- a/deepreg/loss/label.py
+++ b/deepreg/loss/label.py
@@ -4,13 +4,12 @@
 
 import tensorflow as tf
 
+from deepreg.constant import EPS
 from deepreg.loss.util import NegativeLossMixin, cauchy_kernel1d
 from deepreg.loss.util import gaussian_kernel1d_sigma as gaussian_kernel1d
 from deepreg.loss.util import separable_filter
 from deepreg.registry import REGISTRY
 
-EPS = tf.keras.backend.epsilon()
-
 
 class MultiScaleLoss(tf.keras.losses.Loss):
     """
diff --git a/deepreg/loss/util.py b/deepreg/loss/util.py
index 6f453d73c..05083461f 100644
--- a/deepreg/loss/util.py
+++ b/deepreg/loss/util.py
@@ -27,13 +27,9 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
         return -super().call(y_true=y_true, y_pred=y_pred)
 
 
-EPS = tf.keras.backend.epsilon()
-
-
 def rectangular_kernel1d(kernel_size: int) -> tf.Tensor:
     """
-    Return a the 1D filter for separable convolution equivalent to a 3-D rectangular
-    kernel for LocalNormalizedCrossCorrelation.
+    Return the 1D rectangular kernel for LocalNormalizedCrossCorrelation.
 
     :param kernel_size: scalar, size of the 1-D kernel
     :return: kernel_weights, of shape (kernel_size, )
@@ -45,7 +41,7 @@ def triangular_kernel1d(kernel_size: int) -> tf.Tensor:
     """
-    1D triangular kernel.
+    Return the 1D triangular kernel for LocalNormalizedCrossCorrelation.
 
     Assume kernel_size is odd, it will be a smoothed from a kernel which center part is zero.
@@ -73,13 +69,14 @@ def triangular_kernel1d(kernel_size: int) -> tf.Tensor:
     kernel = tf.nn.conv1d(
         kernel[None, :, None], filters=filters, stride=[1, 1, 1], padding="SAME"
     )
+
     return kernel[0, :, 0]
 
 
 def gaussian_kernel1d_size(kernel_size: int) -> tf.Tensor:
     """
-    Return a the 1D filter for separable convolution equivalent to a 3-D Gaussian
-    kernel for LocalNormalizedCrossCorrelation.
+    Return the 1D Gaussian kernel for LocalNormalizedCrossCorrelation.
+
     :param kernel_size: scalar, size of the 1-D kernel
     :return: filters, of shape (kernel_size, )
     """
diff --git a/deepreg/model/network.py b/deepreg/model/network.py
index 227a20d4a..de1140d05 100644
--- a/deepreg/model/network.py
+++ b/deepreg/model/network.py
@@ -30,8 +30,8 @@ class RegistrationModel(tf.keras.Model):
 
     def __init__(
         self,
-        moving_image_size: tuple,
-        fixed_image_size: tuple,
+        moving_image_size: Tuple,
+        fixed_image_size: Tuple,
         index_size: int,
         labeled: bool,
         batch_size: int,
diff --git a/test/unit/test_interface.py b/test/unit/test_interface.py
index 0315938d5..22778f610 100644
--- a/test/unit/test_interface.py
+++ b/test/unit/test_interface.py
@@ -356,7 +356,7 @@ def mock_sample_index_generator():
             fixed_label=None,
             image_indices=[1],
         )
-    assert "Sample [1]'s moving_image' shape should be 3D. " in str(err_info.value)
+    assert "Sample [1]'s moving_image shape should be 3D" in str(err_info.value)
     with pytest.raises(ValueError) as err_info:
         generator.validate_images_and_labels(
             fixed_image=dummy_array,
diff --git a/test/unit/test_loss_image.py b/test/unit/test_loss_image.py
index 9d90baf93..139078206 100644
--- a/test/unit/test_loss_image.py
+++ b/test/unit/test_loss_image.py
@@ -95,6 +95,37 @@ def test_zero_info(self, y_true, y_pred, shape, kernel_type, expected):
         )
         assert is_equal_tf(got, expected)
 
+    @pytest.mark.parametrize(
+        "kernel_size",
+        [3, 5, 7],
+    )
+    def test_exact_value(self, kernel_size):
+        """
+        Test the exact value at the center of a cube.
+
+        :param kernel_size: size of the kernel and the cube.
+        """
+        mid = kernel_size // 2 + 1
+        y_true = tf.random.uniform(shape=(1, kernel_size, kernel_size, kernel_size, 1))
+        y_pred = tf.random.uniform(shape=(1, kernel_size, kernel_size, kernel_size, 1))
+
+        loss = image.LocalNormalizedCrossCorrelation(kernel_size=kernel_size)
+        got = loss.calc_ncc(y_true=y_true, y_pred=y_pred)
+        got = got[0, mid, mid, mid, 0]
+
+        y_true_mean = tf.reduce_mean(y_true)
+        y_true_std = tf.math.reduce_std(y_true)
+
+        y_pred_mean = tf.reduce_mean(y_pred)
+        y_pred_std = tf.math.reduce_std(y_pred)
+
+        num = tf.reduce_mean((y_true - y_true_mean) * (y_pred - y_pred_mean))
+        denom = y_true_std * y_pred_std
+
+        expected = (num / denom) ** 2
+
+        assert is_equal_tf(got, expected)
+
     def test_error(self):
         y = np.ones(shape=(3, 3, 3, 3))
         with pytest.raises(ValueError) as err_info:
diff --git a/test/unit/test_loss_label.py b/test/unit/test_loss_label.py
index ba43f2eca..60007d685 100644
--- a/test/unit/test_loss_label.py
+++ b/test/unit/test_loss_label.py
@@ -12,6 +12,7 @@
 import tensorflow as tf
 
 import deepreg.loss.label as label
+from deepreg.constant import EPS
 
 
 class TestMultiScaleLoss:
@@ -91,11 +92,11 @@ def y_pred(self):
     @pytest.mark.parametrize(
         "binary,background_weight,scales,expected",
         [
-            (True, 0.0, None, -np.log(1.0e-7)),
-            (False, 0.0, None, -0.6 * np.log(0.3)),
-            (False, 0.2, None, -0.48 * np.log(0.3) - 0.08 * np.log(0.7)),
-            (False, 0.2, [0, 0], -0.48 * np.log(0.3) - 0.08 * np.log(0.7)),
-            (False, 0.2, [0, 1], 0.5239637),
+            (True, 0.0, None, -np.log(EPS)),
+            (False, 0.0, None, -0.6 * np.log(0.3 + EPS)),
+            (False, 0.2, None, -0.48 * np.log(0.3 + EPS) - 0.08 * np.log(0.7 + EPS)),
+            (False, 0.2, [0, 0], -0.48 * np.log(0.3 + EPS) - 0.08 * np.log(0.7 + EPS)),
+            (False, 0.2, [0, 1], 0.5239465),
         ],
     )
     def test_call(self, y_true, y_pred, binary, background_weight, scales, expected):
@@ -135,7 +136,7 @@ def y_pred(self):
             (True, None, 0),
             (False, None, 0.25),
             (False, [0, 0], 0.25),
-            (False, [0, 1], 0.17484076),
+            (False, [0, 1], 0.17485845),
         ],
     )
     def test_call(self, y_true, y_pred, binary, scales, expected):
diff --git a/test/unit/util.py b/test/unit/util.py
index 02866779a..c40128d69 100644
--- a/test/unit/util.py
+++ b/test/unit/util.py
@@ -3,9 +3,11 @@
 import numpy as np
 import tensorflow as tf
 
+from deepreg.constant import EPS
+
 
 def is_equal_np(
-    x: Union[np.ndarray, List], y: Union[np.ndarray, List], atol: float = 1.0e-7
+    x: Union[np.ndarray, List], y: Union[np.ndarray, List], atol: float = EPS
 ) -> bool:
     """
     Check if two numpy arrays are identical.
@@ -23,7 +25,7 @@ def is_equal_np(
 def is_equal_tf(
     x: Union[tf.Tensor, np.ndarray, List],
     y: Union[tf.Tensor, np.ndarray, List],
-    atol: float = 1.0e-7,
+    atol: float = EPS,
 ) -> bool:
     """
     Check if two tf tensors are identical.

From 99f7a990c6984a5a5f048372622cbd8039c62023 Mon Sep 17 00:00:00 2001
From: Yunguan Fu
Date: Thu, 1 Apr 2021 23:42:44 +0100
Subject: [PATCH 2/4] Issue #690: fix pytest

---
 test/unit/test_loss_image.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/unit/test_loss_image.py b/test/unit/test_loss_image.py
index 139078206..63dcfa0da 100644
--- a/test/unit/test_loss_image.py
+++ b/test/unit/test_loss_image.py
@@ -105,7 +105,7 @@ def test_exact_value(self, kernel_size):
 
         :param kernel_size: size of the kernel and the cube.
         """
-        mid = kernel_size // 2 + 1
+        mid = kernel_size // 2
         y_true = tf.random.uniform(shape=(1, kernel_size, kernel_size, kernel_size, 1))
         y_pred = tf.random.uniform(shape=(1, kernel_size, kernel_size, kernel_size, 1))

From 0234ee96c4ccafc9322039288503cd744e47cd64 Mon Sep 17 00:00:00 2001
From: Yunguan Fu
Date: Fri, 2 Apr 2021 00:57:09 +0100
Subject: [PATCH 3/4] Issue #690: make smooth value configurable and add extensive tests for LNCC

---
 deepreg/loss/image.py        |  56 ++++++++++----
 test/unit/test_loss_image.py | 141 +++++++++++++++++++++++++++++------
 test/unit/util.py            |  21 +++++-
 3 files changed, 176 insertions(+), 42 deletions(-)

diff --git a/deepreg/loss/image.py b/deepreg/loss/image.py
index aa202e4d6..c38dc7fdb 100644
--- a/deepreg/loss/image.py
+++ b/deepreg/loss/image.py
@@ -188,6 +188,8 @@ def __init__(
         self,
         kernel_size: int = 9,
         kernel_type: str = "rectangular",
+        smooth_nr: float = EPS,
+        smooth_dr: float = EPS,
         reduction: str = tf.keras.losses.Reduction.SUM,
         name: str = "LocalNormalizedCrossCorrelation",
     ):
@@ -196,6 +198,8 @@ def __init__(
         Init.
 
         :param kernel_size: int. Kernel size or kernel sigma for kernel_type='gauss'.
         :param kernel_type: str, rectangular, triangular or gaussian
+        :param smooth_nr: small constant added to numerator in case of zero covariance.
+        :param smooth_dr: small constant added to denominator in case of zero variance.
         :param reduction: using SUM reduction over batch axis,
             calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
         :param name: name of the loss
@@ -209,6 +213,8 @@ def __init__(
         self.kernel_fn = self.kernel_fn_dict[kernel_type]
         self.kernel_type = kernel_type
         self.kernel_size = kernel_size
+        self.smooth_nr = smooth_nr
+        self.smooth_dr = smooth_dr
 
         # (kernel_size, )
         self.kernel = self.kernel_fn(kernel_size=self.kernel_size)
@@ -223,19 +229,21 @@ def calc_ncc(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
         """
         Return NCC for a batch.
 
-        :param y_true: shape = (batch, dim1, dim2, dim3)
-            or (batch, dim1, dim2, dim3, 1)
-        :param y_pred: shape = (batch, dim1, dim2, dim3)
-            or (batch, dim1, dim2, dim3, 1)
+        The kernel should not be normalized, as normalizing it leads to computation
+        with small values and reduced precision.
+
+        Here both numerator and denominator are actually multiplied by kernel volume,
+        which helps the precision as well.
+        However, when the variance is zero, the obtained value might be negative due to
+        machine error. Therefore a hard-coded clipping is added to
+        remove the negative values.
+
+        :param y_true: shape = (batch, dim1, dim2, dim3, 1)
+        :param y_pred: shape = (batch, dim1, dim2, dim3, 1)
         :return: shape = (batch, dim1, dim2, dim3, 1)
         """
-        # adjust
-        if len(y_true.shape) == 4:
-            y_true = tf.expand_dims(y_true, axis=4)
-            y_pred = tf.expand_dims(y_pred, axis=4)
 
         # t = y_true, p = y_pred
-        # (batch, dim1, dim2, dim3, ch)
+        # (batch, dim1, dim2, dim3, 1)
         t2 = y_true * y_true
         p2 = y_pred * y_pred
         tp = y_true * y_pred
@@ -263,7 +271,7 @@ def calc_ncc(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
         p_var = tf.maximum(p_var, 0)
 
         # (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
-        ncc = (cross * cross + EPS) / (t_var * p_var + EPS)
+        ncc = (cross * cross + self.smooth_nr) / (t_var * p_var + self.smooth_dr)
 
         return ncc
 
@@ -271,20 +279,40 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
         """
         Return loss for a batch.
 
+        TODO: support channel axis dimension > 1.
+
         :param y_true: shape = (batch, dim1, dim2, dim3)
-            or (batch, dim1, dim2, dim3, ch)
+            or (batch, dim1, dim2, dim3, 1)
         :param y_pred: shape = (batch, dim1, dim2, dim3)
-            or (batch, dim1, dim2, dim3, ch)
+            or (batch, dim1, dim2, dim3, 1)
         :return: shape = (batch,)
         """
+        # sanity checks
+        if len(y_true.shape) == 4:
+            y_true = tf.expand_dims(y_true, axis=4)
+        if y_true.shape[4] != 1:
+            raise ValueError(
+                "Last dimension of y_true is not one. " f"y_true.shape = {y_true.shape}"
+            )
+        if len(y_pred.shape) == 4:
+            y_pred = tf.expand_dims(y_pred, axis=4)
+        if y_pred.shape[4] != 1:
+            raise ValueError(
+                "Last dimension of y_pred is not one. " f"y_pred.shape = {y_pred.shape}"
+            )
+
         ncc = self.calc_ncc(y_true=y_true, y_pred=y_pred)
         return tf.reduce_mean(ncc, axis=[1, 2, 3, 4])
 
     def get_config(self) -> dict:
         """Return the config dictionary for recreating this class."""
         config = super().get_config()
-        config["kernel_size"] = self.kernel_size
-        config["kernel_type"] = self.kernel_type
+        config.update(
+            kernel_size=self.kernel_size,
+            kernel_type=self.kernel_type,
+            smooth_nr=self.smooth_nr,
+            smooth_dr=self.smooth_dr,
+        )
         return config
diff --git a/test/unit/test_loss_image.py b/test/unit/test_loss_image.py
index 63dcfa0da..19330ee6f 100644
--- a/test/unit/test_loss_image.py
+++ b/test/unit/test_loss_image.py
@@ -5,12 +5,14 @@ in image.py should be better converted into tf tensor type beforehand.
""" from test.unit.util import is_equal_tf +from typing import Tuple import numpy as np import pytest import tensorflow as tf import deepreg.loss.image as image +from deepreg.constant import EPS class TestSumSquaredDistance: @@ -76,69 +78,160 @@ def test_kernel_fn(kernel_size, name): class TestLocalNormalizedCrossCorrelation: @pytest.mark.parametrize( - "y_true,y_pred,shape,kernel_type,expected", + ("y_true_shape", "y_pred_shape"), [ - (0.6, 0.3, (12, 12, 12, 12), "rectangular", 1.0), - (0.6, 0.3, (12, 12, 12, 12, 1), "rectangular", 1.0), - (0.0, 1.0, (12, 12, 12, 12, 1), "rectangular", 1.0), - (0.6, 0.3, (12, 12, 12, 12, 1), "gaussian", 1.0), - (0.6, 0.3, (12, 12, 12, 12, 1), "triangular", 1.0), + ((2, 3, 4, 5), (2, 3, 4, 5)), + ((2, 3, 4, 5), (2, 3, 4, 5, 1)), + ((2, 3, 4, 5, 1), (2, 3, 4, 5)), + ((2, 3, 4, 5, 1), (2, 3, 4, 5, 1)), ], ) - def test_zero_info(self, y_true, y_pred, shape, kernel_type, expected): - y_true = y_true * tf.ones(shape=shape) - y_pred = y_pred * tf.ones(shape=shape) - expected = expected * tf.ones(shape=(shape[0],)) - got = image.LocalNormalizedCrossCorrelation(kernel_type=kernel_type).call( + def test_input_shape(self, y_true_shape: Tuple, y_pred_shape: Tuple): + """ + Test input with / without channel axis. + + :param y_true_shape: input shape for y_true. + :param y_pred_shape: input shape for y_pred. + """ + y_true = tf.ones(shape=y_true_shape) + y_pred = tf.ones(shape=y_pred_shape) + got = image.LocalNormalizedCrossCorrelation().call( + y_true, + y_pred, + ) + assert got.shape == y_true_shape[:1] + + @pytest.mark.parametrize( + ("y_true_shape", "y_pred_shape", "name"), + [ + ((2, 3, 4, 5), (2, 3, 4, 5, 6), "y_pred"), + ((2, 3, 4, 5, 6), (2, 3, 4, 5), "y_true"), + ], + ) + def test_input_shape_err(self, y_true_shape: Tuple, y_pred_shape: Tuple, name: str): + """ + Current LNCC does not support image having channel dimension > 1. + + :param y_true_shape: input shape for y_true. + :param y_pred_shape: input shape for y_pred. + :param name: name of the tensor having error. + """ + y_true = tf.ones(shape=y_true_shape) + y_pred = tf.ones(shape=y_pred_shape) + with pytest.raises(ValueError) as err_info: + image.LocalNormalizedCrossCorrelation().call(y_true, y_pred) + assert f"Last dimension of {name} is not one." in str(err_info.value) + + @pytest.mark.parametrize("value", [0.0, 0.5, 1.0]) + @pytest.mark.parametrize( + ("smooth_nr", "smooth_dr", "expected"), + [ + (1e-5, 1e-5, 1), + (0, 1e-5, 0), + (1e-5, 0, np.inf), + (0, 0, np.nan), + (1e-7, 1e-7, 1), + ], + ) + def test_smooth( + self, + value: float, + smooth_nr: float, + smooth_dr: float, + expected: float, + ): + """ + Test values in extreme cases where variances are all zero. + + :param value: value for input. + :param smooth_nr: constant for numerator. + :param smooth_dr: constant for denominator. + :param expected: target value. 
+ """ + kernel_size = 5 + mid = kernel_size // 2 + shape = (1, kernel_size, kernel_size, kernel_size, 1) + y_true = tf.ones(shape=shape) * value + y_pred = tf.ones(shape=shape) * value + + got = image.LocalNormalizedCrossCorrelation( + kernel_size=kernel_size, + smooth_nr=smooth_nr, + smooth_dr=smooth_dr, + ).calc_ncc( y_true, y_pred, ) + got = got[0, mid, mid, mid, 0] + expected = tf.constant(expected) assert is_equal_tf(got, expected) + @pytest.mark.parametrize( + "kernel_type", + ["rectangular", "gaussian", "triangular"], + ) @pytest.mark.parametrize( "kernel_size", [3, 5, 7], ) - def test_exact_value(self, kernel_size): + def test_exact_value(self, kernel_type, kernel_size): """ Test the exact value at the center of a cube. + :param kernel_type: name of kernel. :param kernel_size: size of the kernel and the cube. """ + # init mid = kernel_size // 2 + tf.random.set_seed(0) y_true = tf.random.uniform(shape=(1, kernel_size, kernel_size, kernel_size, 1)) y_pred = tf.random.uniform(shape=(1, kernel_size, kernel_size, kernel_size, 1)) + loss = image.LocalNormalizedCrossCorrelation( + kernel_type=kernel_type, kernel_size=kernel_size + ) - loss = image.LocalNormalizedCrossCorrelation(kernel_size=kernel_size) + # obtained value got = loss.calc_ncc(y_true=y_true, y_pred=y_pred) - got = got[0, mid, mid, mid, 0] + got = got[0, mid, mid, mid, 0] # center voxel - y_true_mean = tf.reduce_mean(y_true) - y_true_std = tf.math.reduce_std(y_true) + # target value + kernel_3d = ( + loss.kernel[:, None, None] + * loss.kernel[None, :, None] + * loss.kernel[None, None, :] + ) + kernel_3d = kernel_3d[None, :, :, :, None] - y_pred_mean = tf.reduce_mean(y_pred) - y_pred_std = tf.math.reduce_std(y_pred) + y_true_mean = tf.reduce_sum(y_true * kernel_3d) / loss.kernel_vol + y_true_normalized = y_true - y_true_mean + y_true_var = tf.reduce_sum(y_true_normalized ** 2 * kernel_3d) - num = tf.reduce_mean((y_true - y_true_mean) * (y_pred - y_pred_mean)) - denom = y_true_std * y_pred_std + y_pred_mean = tf.reduce_sum(y_pred * kernel_3d) / loss.kernel_vol + y_pred_normalized = y_pred - y_pred_mean + y_pred_var = tf.reduce_sum(y_pred_normalized ** 2 * kernel_3d) - expected = (num / denom) ** 2 + cross = tf.reduce_sum(y_true_normalized * y_pred_normalized * kernel_3d) + expected = (cross ** 2 + EPS) / (y_pred_var * y_true_var + EPS) + # check assert is_equal_tf(got, expected) - def test_error(self): - y = np.ones(shape=(3, 3, 3, 3)) + def test_kernel_error(self): + """Test the error message when using wrong kernel.""" with pytest.raises(ValueError) as err_info: - image.LocalNormalizedCrossCorrelation(kernel_type="constant").call(y, y) + image.LocalNormalizedCrossCorrelation(kernel_type="constant") assert "Wrong kernel_type constant for LNCC loss type." in str(err_info.value) def test_get_config(self): + """Test the config is saved correctly.""" got = image.LocalNormalizedCrossCorrelation().get_config() expected = dict( kernel_size=9, kernel_type="rectangular", reduction=tf.keras.losses.Reduction.SUM, name="LocalNormalizedCrossCorrelation", + smooth_nr=1e-5, + smooth_dr=1e-5, ) assert got == expected diff --git a/test/unit/util.py b/test/unit/util.py index c40128d69..d16c55610 100644 --- a/test/unit/util.py +++ b/test/unit/util.py @@ -10,7 +10,7 @@ def is_equal_np( x: Union[np.ndarray, List], y: Union[np.ndarray, List], atol: float = EPS ) -> bool: """ - Check if two numpy arrays are identical. + Check if two numpy arrays are identical within a tolerance. 
 
     :param x:
     :param y:
@@ -19,7 +19,20 @@ def is_equal_np(
     """
     x = np.asarray(x, dtype=np.float32)
     y = np.asarray(y, dtype=np.float32)
-    return x.shape == y.shape and np.all(np.isclose(x, y, atol=atol))
+
+    # check shape
+    if x.shape != y.shape:
+        return False
+
+    # check nan values
+    # handle the case where some values are nan
+    if np.any(np.isnan(x) != np.isnan(y)):
+        return False
+    x = np.nan_to_num(x)
+    y = np.nan_to_num(y)
+
+    # check values
+    return np.all(np.isclose(x, y, atol=atol))
 
 
 def is_equal_tf(
@@ -28,7 +41,7 @@ def is_equal_tf(
     atol: float = EPS,
 ) -> bool:
     """
-    Check if two tf tensors are identical.
+    Check if two tf tensors are identical within a tolerance.
 
     :param x:
     :param y:
@@ -37,4 +50,4 @@ def is_equal_tf(
     """
     x = tf.cast(x, dtype=tf.float32).numpy()
     y = tf.cast(y, dtype=tf.float32).numpy()
-    return x.shape == y.shape and np.all(np.isclose(x, y, atol=atol))
+    return is_equal_np(x=x, y=y, atol=atol)

From 531b68112ae04f621bfd828039c878a03dd37d97 Mon Sep 17 00:00:00 2001
From: Yunguan Fu
Date: Fri, 2 Apr 2021 01:28:57 +0100
Subject: [PATCH 4/4] Issue #690: added doc regarding reduction strategy in loss

---
 deepreg/loss/image.py                             |  6 ++++++
 deepreg/loss/label.py                             |  8 ++++++++
 examples/custom_image_label_loss.py               |  2 ++
 examples/custom_parameterized_image_label_loss.py |  2 ++
 4 files changed, 18 insertions(+)

diff --git a/deepreg/loss/image.py b/deepreg/loss/image.py
index c38dc7fdb..fed72b05f 100644
--- a/deepreg/loss/image.py
+++ b/deepreg/loss/image.py
@@ -29,6 +29,8 @@ def __init__(
         Init.
 
         :param reduction: using SUM reduction over batch axis,
+            this is for supporting multi-device training,
+            and the loss will be divided by global batch size,
             calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
         :param name: name of the loss
@@ -70,6 +72,8 @@ def __init__(
         :param num_bins: number of bins for intensity, the default value is empirical.
         :param sigma_ratio: a hyper param for gaussian function
         :param reduction: using SUM reduction over batch axis,
+            this is for supporting multi-device training,
+            and the loss will be divided by global batch size,
             calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
         :param name: name of the loss
@@ -201,6 +205,8 @@ def __init__(
         :param smooth_nr: small constant added to numerator in case of zero covariance.
         :param smooth_dr: small constant added to denominator in case of zero variance.
         :param reduction: using SUM reduction over batch axis,
+            this is for supporting multi-device training,
+            and the loss will be divided by global batch size,
             calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
         :param name: name of the loss
diff --git a/deepreg/loss/label.py b/deepreg/loss/label.py
index 37429f51a..c51621614 100644
--- a/deepreg/loss/label.py
+++ b/deepreg/loss/label.py
@@ -34,6 +34,8 @@ def __init__(
         :param scales: list of scalars or None, if None, do not apply any scaling.
         :param kernel: gaussian or cauchy.
         :param reduction: using SUM reduction over batch axis,
+            this is for supporting multi-device training,
+            and the loss will be divided by global batch size,
             calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
         :param name: str, name of the loss.
@@ -132,6 +134,8 @@ def __init__(
         :param scales: list of scalars or None, if None, do not apply any scaling.
         :param kernel: gaussian or cauchy.
:param reduction: using SUM reduction over batch axis, + this is for supporting multi-device training, + and the loss will be divided by global batch size, calling the loss like `loss(y_true, y_pred)` will return a scalar tensor. :param name: str, name of the loss. """ @@ -206,6 +210,8 @@ def __init__( :param scales: list of scalars or None, if None, do not apply any scaling. :param kernel: gaussian or cauchy. :param reduction: using SUM reduction over batch axis, + this is for supporting multi-device training, + and the loss will be divided by global batch size, calling the loss like `loss(y_true, y_pred)` will return a scalar tensor. :param name: str, name of the loss. """ @@ -272,6 +278,8 @@ def __init__( :param scales: list of scalars or None, if None, do not apply any scaling. :param kernel: gaussian or cauchy. :param reduction: using SUM reduction over batch axis, + this is for supporting multi-device training, + and the loss will be divided by global batch size, calling the loss like `loss(y_true, y_pred)` will return a scalar tensor. :param name: str, name of the loss. """ diff --git a/examples/custom_image_label_loss.py b/examples/custom_image_label_loss.py index 341a9e0a9..3dd293155 100644 --- a/examples/custom_image_label_loss.py +++ b/examples/custom_image_label_loss.py @@ -22,6 +22,8 @@ def __init__( Init. :param reduction: using SUM reduction over batch axis, + this is for supporting multi-device training, + and the loss will be divided by global batch size, calling the loss like `loss(y_true, y_pred)` will return a scalar tensor. :param name: name of the loss """ diff --git a/examples/custom_parameterized_image_label_loss.py b/examples/custom_parameterized_image_label_loss.py index 44010fac4..c0db057db 100644 --- a/examples/custom_parameterized_image_label_loss.py +++ b/examples/custom_parameterized_image_label_loss.py @@ -24,6 +24,8 @@ def __init__( :param p: order of the norm, 1 or 2. :param reduction: using SUM reduction over batch axis, + this is for supporting multi-device training, + and the loss will be divided by global batch size, calling the loss like `loss(y_true, y_pred)` will return a scalar tensor. :param name: name of the loss. """
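
For reference, below is a minimal usage sketch of the LNCC loss as it stands after this series. It is not part of the patches; the shapes and constant values are illustrative assumptions, and only names the patches above introduce (LocalNormalizedCrossCorrelation, kernel_type, smooth_nr, smooth_dr) are used.

    import tensorflow as tf

    from deepreg.loss.image import LocalNormalizedCrossCorrelation

    # call() accepts 4D input and expands a trailing channel axis of size 1;
    # a channel dimension larger than one raises ValueError (see PATCH 3 tests)
    y_true = tf.random.uniform(shape=(2, 8, 8, 8))
    y_pred = tf.random.uniform(shape=(2, 8, 8, 8, 1))

    loss = LocalNormalizedCrossCorrelation(
        kernel_size=9,
        kernel_type="rectangular",  # or "triangular" / "gaussian"
        smooth_nr=1.0e-5,  # numerator constant, guards zero covariance
        smooth_dr=1.0e-5,  # denominator constant, guards zero variance
    )

    per_sample = loss.call(y_true=y_true, y_pred=y_pred)  # shape = (batch,)
    scalar = loss(y_true, y_pred)  # SUM reduction over the batch axis

Setting smooth_nr and smooth_dr to the same small constant keeps the loss close to 1 on constant patches, mirroring the test_smooth cases added in PATCH 3.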