2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,7 @@ compatible with the updates.

### Changed

- Increased all EPS constants to 1e-5.
- Clarified the suggestion in the docs to use all-zero masks for missing labels.
- Moved contributor list to a separate page.
- Changed `no-test` flag to `full` for demo scripts.
@@ -37,6 +38,7 @@ compatible with the updates.

### Fixed

- Fixed INF values in the LNCC loss.
- Removed loss weight checks to make configuration more robust.
- Fixed import error under python 3.6.
- Fixed the residual module in local net architecture, compatible for previous
3 changes: 3 additions & 0 deletions deepreg/constant.py
@@ -0,0 +1,3 @@
"""Module defining global constants."""

EPS = 1.0e-5
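
For scale, a minimal sketch (assuming a default Keras configuration) contrasting the new shared constant with the `tf.keras.backend.epsilon()` value it replaces throughout the losses:

```python
import tensorflow as tf

from deepreg.constant import EPS

# Keras' default fuzz factor is 1e-7, two orders of magnitude smaller
# than the shared constant; the larger value keeps the smoothed
# divisions further from the limits of float32 precision.
print(tf.keras.backend.epsilon())  # 1e-07 under the default config
print(EPS)                         # 1e-05
```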
6 changes: 3 additions & 3 deletions deepreg/dataset/loader/interface.py
@@ -379,10 +379,10 @@ def validate_images_and_labels(
for arr, name in zip(
[moving_image, fixed_image], ["moving_image", "fixed_image"]
):
if len(arr.shape) != 3:
if len(arr.shape) != 3 or min(arr.shape) <= 0:
raise ValueError(
f"Sample {image_indices}'s {name}' shape should be 3D. "
f"Got {arr.shape}."
f"Sample {image_indices}'s {name}' shape should be 3D"
f" and non-empty, got {arr.shape}."
)
# when data are labeled
if moving_label is not None and fixed_label is not None:
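
A hedged sketch of the tightened check, where `generator` stands in for any concrete `DataLoader` instance (the variable name is hypothetical, not from the diff): a zero-length dimension now fails even though the array is 3D, because `min(arr.shape) <= 0`.

```python
import numpy as np
import pytest

moving = np.zeros((4, 0, 4))  # 3D, but empty along dim2
fixed = np.zeros((4, 4, 4))

# `generator` is a hypothetical loader instance for illustration.
with pytest.raises(ValueError, match="should be 3D and non-empty"):
    generator.validate_images_and_labels(
        moving_image=moving,
        fixed_image=fixed,
        moving_label=None,
        fixed_label=None,
        image_indices=[0],
    )
```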
87 changes: 68 additions & 19 deletions deepreg/loss/image.py
@@ -1,6 +1,7 @@
"""Provide different loss or metrics classes for images."""
import tensorflow as tf

from deepreg.constant import EPS
from deepreg.loss.util import NegativeLossMixin
from deepreg.loss.util import gaussian_kernel1d_size as gaussian_kernel1d
from deepreg.loss.util import (
@@ -10,8 +11,6 @@
)
from deepreg.registry import REGISTRY

EPS = tf.keras.backend.epsilon()


@REGISTRY.register_loss(name="ssd")
class SumSquaredDifference(tf.keras.losses.Loss):
@@ -30,6 +29,8 @@ def __init__(
Init.

:param reduction: using SUM reduction over the batch axis,
which supports multi-device training; the loss is divided by the
global batch size, so calling the loss like `loss(y_true, y_pred)`
returns a scalar tensor.
:param name: name of the loss
"""
@@ -71,6 +72,8 @@ def __init__(
:param num_bins: number of bins for intensity, the default value is empirical.
:param sigma_ratio: a hyperparameter for the Gaussian function
:param reduction: using SUM reduction over the batch axis,
which supports multi-device training; the loss is divided by the
global batch size, so calling the loss like `loss(y_true, y_pred)`
returns a scalar tensor.
:param name: name of the loss
"""
@@ -189,6 +192,8 @@ def __init__(
self,
kernel_size: int = 9,
kernel_type: str = "rectangular",
smooth_nr: float = EPS,
smooth_dr: float = EPS,
reduction: str = tf.keras.losses.Reduction.SUM,
name: str = "LocalNormalizedCrossCorrelation",
):
@@ -197,7 +202,11 @@

:param kernel_size: int. Kernel size, or kernel sigma for kernel_type='gaussian'.
:param kernel_type: str, one of rectangular, triangular, or gaussian.
:param smooth_nr: small constant added to numerator in case of zero covariance.
:param smooth_dr: small constant added to denominator in case of zero variance.
:param reduction: using SUM reduction over the batch axis,
which supports multi-device training; the loss is divided by the
global batch size, so calling the loss like `loss(y_true, y_pred)`
returns a scalar tensor.
:param name: name of the loss
"""
@@ -210,6 +219,8 @@ def __init__(
self.kernel_fn = self.kernel_fn_dict[kernel_type]
self.kernel_type = kernel_type
self.kernel_size = kernel_size
self.smooth_nr = smooth_nr
self.smooth_dr = smooth_dr

# (kernel_size, )
self.kernel = self.kernel_fn(kernel_size=self.kernel_size)
@@ -220,24 +231,25 @@ def __init__(
* self.kernel[None, None, :]
)

def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
def calc_ncc(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
"""
Return loss for a batch.

:param y_true: shape = (batch, dim1, dim2, dim3)
or (batch, dim1, dim2, dim3, ch)
:param y_pred: shape = (batch, dim1, dim2, dim3)
or (batch, dim1, dim2, dim3, ch)
:return: shape = (batch,)
Return NCC for a batch.

The kernel should not be normalized: normalizing it forces the computation
through small values and reduces precision. Here both the numerator and the
denominator are effectively multiplied by the kernel volume, which also
helps precision.
However, when the true variance is zero, the computed variance can be
slightly negative due to floating-point error, so it is clipped at zero and
small smoothing constants guard against division by zero.

:param y_true: shape = (batch, dim1, dim2, dim3, 1)
:param y_pred: shape = (batch, dim1, dim2, dim3, 1)
:return: shape = (batch, dim1, dim2, dim3, 1)
"""
# adjust
if len(y_true.shape) == 4:
y_true = tf.expand_dims(y_true, axis=4)
y_pred = tf.expand_dims(y_pred, axis=4)
assert len(y_true.shape) == len(y_pred.shape) == 5

# t = y_true, p = y_pred
# (batch, dim1, dim2, dim3, ch)
# (batch, dim1, dim2, dim3, 1)
t2 = y_true * y_true
p2 = y_pred * y_pred
tp = y_true * y_pred
@@ -260,16 +272,53 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
t_var = t2_sum - t_avg * t_sum # V[t] * E[1]
p_var = p2_sum - p_avg * p_sum # V[p] * E[1]

# ensure variance >= 0
t_var = tf.maximum(t_var, 0)
p_var = tf.maximum(p_var, 0)

# (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
ncc = (cross * cross + EPS) / (t_var * p_var + EPS)
ncc = (cross * cross + self.smooth_nr) / (t_var * p_var + self.smooth_dr)

return ncc

def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
"""
Return loss for a batch.

TODO: support channel axis dimension > 1.

:param y_true: shape = (batch, dim1, dim2, dim3)
or (batch, dim1, dim2, dim3, 1)
:param y_pred: shape = (batch, dim1, dim2, dim3)
or (batch, dim1, dim2, dim3, 1)
:return: shape = (batch,)
"""
# sanity checks
if len(y_true.shape) == 4:
y_true = tf.expand_dims(y_true, axis=4)
if y_true.shape[4] != 1:
raise ValueError(
"Last dimension of y_true is not one. " f"y_true.shape = {y_true.shape}"
)
if len(y_pred.shape) == 4:
y_pred = tf.expand_dims(y_pred, axis=4)
if y_pred.shape[4] != 1:
raise ValueError(
"Last dimension of y_pred is not one. " f"y_pred.shape = {y_pred.shape}"
)

ncc = self.calc_ncc(y_true=y_true, y_pred=y_pred)
return tf.reduce_mean(ncc, axis=[1, 2, 3, 4])

def get_config(self) -> dict:
"""Return the config dictionary for recreating this class."""
config = super().get_config()
config["kernel_size"] = self.kernel_size
config["kernel_type"] = self.kernel_type
config.update(
kernel_size=self.kernel_size,
kernel_type=self.kernel_type,
smooth_nr=self.smooth_nr,
smooth_dr=self.smooth_dr,
)
return config


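
A minimal sketch of the INF fix in action (shapes and values chosen for illustration): constant volumes have zero local variance everywhere, the case that previously drove the ratio to INF; with `smooth_nr` and `smooth_dr` the per-voxel NCC stays finite.

```python
import tensorflow as tf

from deepreg.loss.image import LocalNormalizedCrossCorrelation

# Constant inputs: cross-covariance and both variances are zero, so each
# voxel evaluates to (0 + smooth_nr) / (0 + smooth_dr) = 1 instead of INF.
y_true = tf.ones((2, 8, 8, 8))  # 4D input; call() adds the channel axis
y_pred = tf.ones((2, 8, 8, 8))

lncc = LocalNormalizedCrossCorrelation(kernel_size=9, kernel_type="rectangular")
value = lncc(y_true, y_pred)  # SUM reduction over the batch axis
print(float(value))           # finite, no INF/NaN
```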
11 changes: 9 additions & 2 deletions deepreg/loss/label.py
@@ -4,13 +4,12 @@

import tensorflow as tf

from deepreg.constant import EPS
from deepreg.loss.util import NegativeLossMixin, cauchy_kernel1d
from deepreg.loss.util import gaussian_kernel1d_sigma as gaussian_kernel1d
from deepreg.loss.util import separable_filter
from deepreg.registry import REGISTRY

EPS = tf.keras.backend.epsilon()


class MultiScaleLoss(tf.keras.losses.Loss):
"""
@@ -35,6 +34,8 @@ def __init__(
:param scales: list of scalars or None, if None, do not apply any scaling.
:param kernel: gaussian or cauchy.
:param reduction: using SUM reduction over the batch axis,
which supports multi-device training; the loss is divided by the
global batch size, so calling the loss like `loss(y_true, y_pred)`
returns a scalar tensor.
:param name: str, name of the loss.
"""
@@ -133,6 +134,8 @@ def __init__(
:param scales: list of scalars or None, if None, do not apply any scaling.
:param kernel: gaussian or cauchy.
:param reduction: using SUM reduction over the batch axis,
which supports multi-device training; the loss is divided by the
global batch size, so calling the loss like `loss(y_true, y_pred)`
returns a scalar tensor.
:param name: str, name of the loss.
"""
@@ -207,6 +210,8 @@ def __init__(
:param scales: list of scalars or None, if None, do not apply any scaling.
:param kernel: gaussian or cauchy.
:param reduction: using SUM reduction over the batch axis,
which supports multi-device training; the loss is divided by the
global batch size, so calling the loss like `loss(y_true, y_pred)`
returns a scalar tensor.
:param name: str, name of the loss.
"""
@@ -273,6 +278,8 @@ def __init__(
:param scales: list of scalars or None, if None, do not apply any scaling.
:param kernel: gaussian or cauchy.
:param reduction: using SUM reduction over the batch axis,
which supports multi-device training; the loss is divided by the
global batch size, so calling the loss like `loss(y_true, y_pred)`
returns a scalar tensor.
:param name: str, name of the loss.
"""
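
The reduction docstring repeated above compresses a distributed-training convention into one sentence; a hedged sketch of the arithmetic it refers to (the loop context is illustrative, not DeepReg's trainer):

```python
import tensorflow as tf

# With Reduction.SUM each replica returns the sum of its per-sample
# losses; dividing by the GLOBAL batch size makes the aggregated
# gradient match what a single device would compute.
per_example_loss = tf.constant([0.3, 0.5, 0.2, 0.4])  # one replica's samples
global_batch_size = 8  # samples across all replicas

loss = tf.nn.compute_average_loss(
    per_example_loss, global_batch_size=global_batch_size
)
print(float(loss))  # sum(per_example_loss) / 8
```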
13 changes: 5 additions & 8 deletions deepreg/loss/util.py
@@ -27,13 +27,9 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
return -super().call(y_true=y_true, y_pred=y_pred)


EPS = tf.keras.backend.epsilon()


def rectangular_kernel1d(kernel_size: int) -> tf.Tensor:
"""
Return a the 1D filter for separable convolution equivalent to a 3-D rectangular
kernel for LocalNormalizedCrossCorrelation.
Return the 1D rectangular kernel for LocalNormalizedCrossCorrelation.

:param kernel_size: scalar, size of the 1-D kernel
:return: kernel_weights, of shape (kernel_size, )
@@ -45,7 +41,7 @@ def rectangular_kernel1d(kernel_size: int) -> tf.Tensor:

def triangular_kernel1d(kernel_size: int) -> tf.Tensor:
"""
1D triangular kernel.
Return the 1D triangular kernel for LocalNormalizedCrossCorrelation.

Assuming kernel_size is odd, the kernel is obtained by smoothing
a kernel whose center part is zero.
@@ -73,13 +69,14 @@ def triangular_kernel1d(kernel_size: int) -> tf.Tensor:
kernel = tf.nn.conv1d(
kernel[None, :, None], filters=filters, stride=[1, 1, 1], padding="SAME"
)

return kernel[0, :, 0]


def gaussian_kernel1d_size(kernel_size: int) -> tf.Tensor:
"""
Return a the 1D filter for separable convolution equivalent to a 3-D Gaussian
kernel for LocalNormalizedCrossCorrelation.
Return the 1D Gaussian kernel for LocalNormalizedCrossCorrelation.

:param kernel_size: scalar, size of the 1-D kernel
:return: filters, of shape (kernel_size, )
"""
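
Conceptually, the triangular kernel described above is a box kernel convolved with itself; a small NumPy sketch of that idea (illustrative, not lifted from the helper's implementation):

```python
import numpy as np

# Convolving two rectangular (all-ones) kernels yields a triangle,
# mirroring what triangular_kernel1d builds with tf.nn.conv1d.
rect = np.ones(4)
tri = np.convolve(rect, rect)
print(tri)  # [1. 2. 3. 4. 3. 2. 1.] -- unnormalized, like the LNCC kernels
```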
4 changes: 2 additions & 2 deletions deepreg/model/network.py
@@ -30,8 +30,8 @@ class RegistrationModel(tf.keras.Model):

def __init__(
self,
moving_image_size: tuple,
fixed_image_size: tuple,
moving_image_size: Tuple,
fixed_image_size: Tuple,
index_size: int,
labeled: bool,
batch_size: int,
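
On the `tuple` → `Tuple` change: as a bare annotation the two are interchangeable, but `typing.Tuple` can also be parameterized on every Python version this project supports, whereas subscripting the builtin (`tuple[int, ...]`) only became legal in Python 3.9. A tiny illustration with hypothetical sizes:

```python
from typing import Tuple

# typing.Tuple accepts parameters on Python 3.6+;
# tuple[int, int, int] would raise TypeError before Python 3.9.
moving_image_size: Tuple[int, int, int] = (16, 16, 16)
fixed_image_size: Tuple[int, int, int] = (16, 16, 16)
```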
2 changes: 2 additions & 0 deletions examples/custom_image_label_loss.py
@@ -22,6 +22,8 @@ def __init__(
Init.

:param reduction: using SUM reduction over the batch axis,
which supports multi-device training; the loss is divided by the
global batch size, so calling the loss like `loss(y_true, y_pred)`
returns a scalar tensor.
:param name: name of the loss
"""
2 changes: 2 additions & 0 deletions examples/custom_parameterized_image_label_loss.py
@@ -24,6 +24,8 @@ def __init__(

:param p: order of the norm, 1 or 2.
:param reduction: using SUM reduction over the batch axis,
which supports multi-device training; the loss is divided by the
global batch size, so calling the loss like `loss(y_true, y_pred)`
returns a scalar tensor.
:param name: name of the loss.
"""
2 changes: 1 addition & 1 deletion test/unit/test_interface.py
@@ -356,7 +356,7 @@ def mock_sample_index_generator():
fixed_label=None,
image_indices=[1],
)
assert "Sample [1]'s moving_image' shape should be 3D. " in str(err_info.value)
assert "Sample [1]'s moving_image' shape should be 3D" in str(err_info.value)
with pytest.raises(ValueError) as err_info:
generator.validate_images_and_labels(
fixed_image=dummy_array,