diff --git a/CHANGELOG.md b/CHANGELOG.md
index 720ca54f1..d0089f17f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,7 +24,6 @@ compatible with the updates.
 
 ### Changed
 
-- Increased all EPS to 1e-5.
 - Clarify the suggestion in doc to use all-zero masks for missing labels.
 - Moved contributor list to a separate page.
 - Changed `no-test` flag to `full` for demo scripts.
@@ -38,7 +37,6 @@ compatible with the updates.
 
 ### Fixed
 
-- Fixed LNCC loss regarding INF values.
 - Removed loss weight checks to be more robust.
 - Fixed import error under python 3.6.
 - Fixed the residual module in local net architecture, compatible for previous
diff --git a/deepreg/constant.py b/deepreg/constant.py
deleted file mode 100644
index 72746ee3b..000000000
--- a/deepreg/constant.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""Module defining global constants."""
-
-EPS = 1.0e-5
diff --git a/deepreg/dataset/loader/interface.py b/deepreg/dataset/loader/interface.py
index 997448195..1c89d5e53 100644
--- a/deepreg/dataset/loader/interface.py
+++ b/deepreg/dataset/loader/interface.py
@@ -379,10 +379,10 @@ def validate_images_and_labels(
         for arr, name in zip(
             [moving_image, fixed_image], ["moving_image", "fixed_image"]
         ):
-            if len(arr.shape) != 3 or min(arr.shape) <= 0:
+            if len(arr.shape) != 3:
                 raise ValueError(
-                    f"Sample {image_indices}'s {name}' shape should be 3D"
-                    f" and non-empty, got {arr.shape}."
+                    f"Sample {image_indices}'s {name} shape should be 3D. "
+                    f"Got {arr.shape}."
                 )
         # when data are labeled
         if moving_label is not None and fixed_label is not None:
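With deepreg/constant.py deleted, each loss module defines EPS from the Keras backend instead of the hand-picked 1e-5, as the deepreg/loss/image.py diff below shows. A minimal sketch of the new convention, assuming the stock Keras default of 1e-7:

    import tensorflow as tf

    # EPS is now the Keras backend epsilon rather than a fixed 1e-5 constant.
    EPS = tf.keras.backend.epsilon()
    print(EPS)  # 1e-07 by default; tunable via tf.keras.backend.set_epsilon()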
diff --git a/deepreg/loss/image.py b/deepreg/loss/image.py
index 5300a33e4..44510bdde 100644
--- a/deepreg/loss/image.py
+++ b/deepreg/loss/image.py
@@ -1,7 +1,6 @@
 """Provide different loss or metrics classes for images."""
 import tensorflow as tf
 
-from deepreg.constant import EPS
 from deepreg.loss.util import NegativeLossMixin
 from deepreg.loss.util import gaussian_kernel1d_size as gaussian_kernel1d
 from deepreg.loss.util import (
@@ -11,6 +10,8 @@
 )
 from deepreg.registry import REGISTRY
 
+EPS = tf.keras.backend.epsilon()
+
 
 @REGISTRY.register_loss(name="ssd")
 class SumSquaredDifference(tf.keras.losses.Loss):
@@ -155,22 +156,19 @@ class LocalNormalizedCrossCorrelation(tf.keras.losses.Loss):
 
     E[t] = sum_i(w_i * t_i) / sum_i(w_i)
 
-    Here, we assume sum_i(w_i) == 1, means the weights have been normalized.
-
     Similarly, the discrete variance in the window V[t] is
 
-    V[t] = E[(t - E[t])**2]
+    V[t] = E[t**2] - E[t] ** 2
 
     The local squared zero-normalized cross-correlation is therefore
 
     E[ (t-E[t]) * (p-E[p]) ] ** 2 / V[t] / V[p]
 
-    When calculating variance, we choose to subtract the mean first then calculte
-    variance instead of calculating E[t**2] - E[t] ** 2, the reason is that when
-    E[t**2] and E[t] ** 2 are both very large or very small, the subtraction may
-    have large rounding error and makes the result inaccurate. Also, it is not
-    guaranteed that the result >= 0. For more details, please read "Algorithms for
-    computing the sample variance: Analysis and recommendations." page 1.
+    where the expectation in the numerator is
+
+    E[ (t-E[t]) * (p-E[p]) ] = E[t * p] - E[t] * E[p]
+
+    Different kernels correspond to different weights.
 
     For now, y_true and y_pred have to be at least 4d tensor, including batch axis.
 
@@ -178,10 +176,7 @@ class LocalNormalizedCrossCorrelation(tf.keras.losses.Loss):
 
         - Zero-normalized cross-correlation (ZNCC):
             https://en.wikipedia.org/wiki/Cross-correlation
-        - https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights
-        - Chan, Tony F., Gene H. Golub, and Randall J. LeVeque.
-          "Algorithms for computing the sample variance: Analysis and recommendations."
-          The American Statistician 37.3 (1983): 242-247.
+        - Code: https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
     """
 
     kernel_fn_dict = dict(
@@ -217,8 +212,13 @@ def __init__(
         self.kernel_size = kernel_size
 
         # (kernel_size, )
-        # sum of the kernel weights would be one
         self.kernel = self.kernel_fn(kernel_size=self.kernel_size)
+        # E[1] = sum_i(w_i), ()
+        self.kernel_vol = tf.reduce_sum(
+            self.kernel[:, None, None]
+            * self.kernel[None, :, None]
+            * self.kernel[None, None, :]
+        )
 
     def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
         """
@@ -230,29 +230,38 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
             or (batch, dim1, dim2, dim3, ch)
         :return: shape = (batch,)
         """
         # adjust shape to be (batch, dim1, dim2, dim3, ch)
         if len(y_true.shape) == 4:
             y_true = tf.expand_dims(y_true, axis=4)
             y_pred = tf.expand_dims(y_pred, axis=4)
         assert len(y_true.shape) == len(y_pred.shape) == 5
 
         # t = y_true, p = y_pred
-        t_mean = separable_filter(y_true, kernel=self.kernel)
-        p_mean = separable_filter(y_pred, kernel=self.kernel)
-
-        t = y_true - t_mean
-        p = y_pred - p_mean
-
-        # the variance can be biased but as both num and denom are biased
-        # it got cancelled
-        # https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights
-        cross = separable_filter(t * p, kernel=self.kernel)
-        t_var = separable_filter(t * t, kernel=self.kernel)
-        p_var = separable_filter(p * p, kernel=self.kernel)
-
-        num = cross * cross
-        denom = t_var * p_var
-        ncc = (num + EPS) / (denom + EPS)
+        # (batch, dim1, dim2, dim3, ch)
+        t2 = y_true * y_true
+        p2 = y_pred * y_pred
+        tp = y_true * y_pred
+
+        # sum over kernel
+        # (batch, dim1, dim2, dim3, 1)
+        t_sum = separable_filter(y_true, kernel=self.kernel)  # E[t] * E[1]
+        p_sum = separable_filter(y_pred, kernel=self.kernel)  # E[p] * E[1]
+        t2_sum = separable_filter(t2, kernel=self.kernel)  # E[tt] * E[1]
+        p2_sum = separable_filter(p2, kernel=self.kernel)  # E[pp] * E[1]
+        tp_sum = separable_filter(tp, kernel=self.kernel)  # E[tp] * E[1]
+
+        # average over kernel
+        # (batch, dim1, dim2, dim3, 1)
+        t_avg = t_sum / self.kernel_vol  # E[t]
+        p_avg = p_sum / self.kernel_vol  # E[p]
+
+        # shape = (batch, dim1, dim2, dim3, 1)
+        cross = tp_sum - p_avg * t_sum  # E[tp] * E[1] - E[p] * E[t] * E[1]
+        t_var = t2_sum - t_avg * t_sum  # V[t] * E[1]
+        p_var = p2_sum - p_avg * p_sum  # V[p] * E[1]
+
+        # (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
+        ncc = (cross * cross + EPS) / (t_var * p_var + EPS)
 
         return tf.reduce_mean(ncc, axis=[1, 2, 3, 4])
 
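The rewritten call() relies on the identities V[t] = E[t**2] - E[t]**2 and E[(t-E[t])*(p-E[p])] = E[t*p] - E[t]*E[p], carrying each expectation around as an unnormalized weighted sum. A minimal NumPy sketch on a single made-up window (the values and rectangular weights are illustrative, not taken from the diff) checking that the sum-based quantities agree with the old mean-subtracted ones:

    import numpy as np

    t = np.array([1.0, 2.0, 3.0])  # y_true values in one window
    p = np.array([2.0, 0.0, 4.0])  # y_pred values in the same window
    w = np.ones_like(t)            # unnormalized rectangular weights

    vol = w.sum()                  # E[1], the kernel volume
    t_avg = (w * t).sum() / vol    # E[t]
    p_avg = (w * p).sum() / vol    # E[p]

    # Sum-based form used by the new call(), everything scaled by E[1].
    cross = (w * t * p).sum() - p_avg * (w * t).sum()  # (E[tp] - E[t]E[p]) * E[1]
    t_var = (w * t * t).sum() - t_avg * (w * t).sum()  # V[t] * E[1]

    # Mean-subtracted form used by the old call().
    cross_old = (w * (t - t_avg) * (p - p_avg)).sum()
    t_var_old = (w * (t - t_avg) ** 2).sum()

    assert np.isclose(cross, cross_old) and np.isclose(t_var, t_var_old)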
diff --git a/deepreg/loss/label.py b/deepreg/loss/label.py
index 37429f51a..0e0d94bad 100644
--- a/deepreg/loss/label.py
+++ b/deepreg/loss/label.py
@@ -4,12 +4,13 @@
 
 import tensorflow as tf
 
-from deepreg.constant import EPS
 from deepreg.loss.util import NegativeLossMixin, cauchy_kernel1d
 from deepreg.loss.util import gaussian_kernel1d_sigma as gaussian_kernel1d
 from deepreg.loss.util import separable_filter
 from deepreg.registry import REGISTRY
 
+EPS = tf.keras.backend.epsilon()
+
 
 class MultiScaleLoss(tf.keras.losses.Loss):
     """
diff --git a/deepreg/loss/util.py b/deepreg/loss/util.py
index 908d54b7e..6f453d73c 100644
--- a/deepreg/loss/util.py
+++ b/deepreg/loss/util.py
@@ -27,25 +27,25 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
         return -super().call(y_true=y_true, y_pred=y_pred)
 
 
+EPS = tf.keras.backend.epsilon()
+
+
 def rectangular_kernel1d(kernel_size: int) -> tf.Tensor:
     """
-    Return a the 1D rectangular kernel for LocalNormalizedCrossCorrelation.
-
-    Sum of the weights is 1.
+    Return the 1D filter for separable convolution equivalent to a 3-D rectangular
+    kernel for LocalNormalizedCrossCorrelation.
 
     :param kernel_size: scalar, size of the 1-D kernel
    :return: kernel_weights, of shape (kernel_size, )
     """
-    kernel = tf.ones(shape=(kernel_size,), dtype=tf.float32) / float(kernel_size)
+    kernel = tf.ones(shape=(kernel_size,), dtype=tf.float32)
     return kernel
 
 
 def triangular_kernel1d(kernel_size: int) -> tf.Tensor:
     """
-    Return a the 1D triangular kernel for LocalNormalizedCrossCorrelation.
-
-    Sum of the weights is 1.
+    Return the 1D triangular kernel for LocalNormalizedCrossCorrelation.
 
     Assume kernel_size is odd, it will be a smoothed from a kernel which center
     part is zero.
 
@@ -73,17 +73,13 @@ def triangular_kernel1d(kernel_size: int) -> tf.Tensor:
     kernel = tf.nn.conv1d(
         kernel[None, :, None], filters=filters, stride=[1, 1, 1], padding="SAME"
     )
-    kernel = kernel / tf.reduce_sum(kernel)
-
     return kernel[0, :, 0]
 
 
 def gaussian_kernel1d_size(kernel_size: int) -> tf.Tensor:
     """
-    Return a the 1D Gaussian kernel for LocalNormalizedCrossCorrelation.
-
-    Sum of the weights is 1.
-
+    Return the 1D filter for separable convolution equivalent to a 3-D Gaussian
+    kernel for LocalNormalizedCrossCorrelation.
     :param kernel_size: scalar, size of the 1-D kernel
     :return: filters, of shape (kernel_size, )
@@ -92,7 +88,6 @@ def gaussian_kernel1d_size(kernel_size: int) -> tf.Tensor:
 
     grid = tf.range(0, kernel_size, dtype=tf.float32)
     filters = tf.exp(-tf.square(grid - mean) / (2 * sigma ** 2))
-    filters = filters / tf.reduce_sum(filters)
 
     return filters
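Since the 1-D kernels above no longer sum to one, the normalization moves into kernel_vol in LocalNormalizedCrossCorrelation.__init__ (see the image.py diff earlier). A short sketch with an arbitrary kernel size of 9, showing the separable product reduces to sum_i(w_i) over the full 3-D window:

    import tensorflow as tf

    kernel = tf.ones(shape=(9,), dtype=tf.float32)  # rectangular_kernel1d(9)
    # Same outer-product-and-sum as in LocalNormalizedCrossCorrelation.__init__.
    kernel_vol = tf.reduce_sum(
        kernel[:, None, None] * kernel[None, :, None] * kernel[None, None, :]
    )
    print(kernel_vol.numpy())  # 729.0 == 9 ** 3, i.e. sum_i(w_i) over the window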
" in str(err_info.value) with pytest.raises(ValueError) as err_info: generator.validate_images_and_labels( fixed_image=dummy_array, diff --git a/test/unit/test_loss_label.py b/test/unit/test_loss_label.py index 60007d685..ba43f2eca 100644 --- a/test/unit/test_loss_label.py +++ b/test/unit/test_loss_label.py @@ -12,7 +12,6 @@ import tensorflow as tf import deepreg.loss.label as label -from deepreg.constant import EPS class TestMultiScaleLoss: @@ -92,11 +91,11 @@ def y_pred(self): @pytest.mark.parametrize( "binary,background_weight,scales,expected", [ - (True, 0.0, None, -np.log(EPS)), - (False, 0.0, None, -0.6 * np.log(0.3 + EPS)), - (False, 0.2, None, -0.48 * np.log(0.3 + EPS) - 0.08 * np.log(0.7 + EPS)), - (False, 0.2, [0, 0], -0.48 * np.log(0.3 + EPS) - 0.08 * np.log(0.7 + EPS)), - (False, 0.2, [0, 1], 0.5239465), + (True, 0.0, None, -np.log(1.0e-7)), + (False, 0.0, None, -0.6 * np.log(0.3)), + (False, 0.2, None, -0.48 * np.log(0.3) - 0.08 * np.log(0.7)), + (False, 0.2, [0, 0], -0.48 * np.log(0.3) - 0.08 * np.log(0.7)), + (False, 0.2, [0, 1], 0.5239637), ], ) def test_call(self, y_true, y_pred, binary, background_weight, scales, expected): @@ -136,7 +135,7 @@ def y_pred(self): (True, None, 0), (False, None, 0.25), (False, [0, 0], 0.25), - (False, [0, 1], 0.17485845), + (False, [0, 1], 0.17484076), ], ) def test_call(self, y_true, y_pred, binary, scales, expected): diff --git a/test/unit/test_loss_util.py b/test/unit/test_loss_util.py index 20e456ee2..187760555 100644 --- a/test/unit/test_loss_util.py +++ b/test/unit/test_loss_util.py @@ -62,7 +62,6 @@ def test_gaussian_kernel1d_size(kernel_size): grid = tf.range(0, kernel_size, dtype=tf.float32) expected = tf.exp(-tf.square(grid - mean) / (2 * sigma ** 2)) - expected = expected / tf.reduce_sum(expected) got = gaussian_kernel1d_size(kernel_size) assert is_equal_tf(got, expected) @@ -76,7 +75,6 @@ def test_rectangular_kernel1d(kernel_size): :return: """ expected = tf.ones(shape=(kernel_size,), dtype=tf.float32) - expected = expected / tf.reduce_sum(expected) got = rectangular_kernel1d(kernel_size) assert is_equal_tf(got, expected) @@ -93,7 +91,6 @@ def test_triangular_kernel1d(kernel_size): for it_k in range(kernel_size // 2): expected[it_k] = it_k + 1 expected[-it_k - 1] = it_k + 1 - expected = expected / tf.reduce_sum(expected) got = triangular_kernel1d(kernel_size) assert is_equal_tf(got, expected) diff --git a/test/unit/util.py b/test/unit/util.py index c40128d69..02866779a 100644 --- a/test/unit/util.py +++ b/test/unit/util.py @@ -3,11 +3,9 @@ import numpy as np import tensorflow as tf -from deepreg.constant import EPS - def is_equal_np( - x: Union[np.ndarray, List], y: Union[np.ndarray, List], atol: float = EPS + x: Union[np.ndarray, List], y: Union[np.ndarray, List], atol: float = 1.0e-7 ) -> bool: """ Check if two numpy arrays are identical. @@ -25,7 +23,7 @@ def is_equal_np( def is_equal_tf( x: Union[tf.Tensor, np.ndarray, List], y: Union[tf.Tensor, np.ndarray, List], - atol: float = EPS, + atol: float = 1.0e-7, ) -> bool: """ Check if two tf tensors are identical.