diff --git a/docs/source/losses.rst b/docs/source/losses.rst
index e929e9d605..61dd959807 100644
--- a/docs/source/losses.rst
+++ b/docs/source/losses.rst
@@ -73,6 +73,11 @@ Segmentation Losses
 .. autoclass:: ContrastiveLoss
     :members:
 
+`BarlowTwinsLoss`
+~~~~~~~~~~~~~~~~~
+.. autoclass:: BarlowTwinsLoss
+    :members:
+
 `HausdorffDTLoss`
 ~~~~~~~~~~~~~~~~~
 .. autoclass:: HausdorffDTLoss
diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py
index 92898c81ca..4ebedb2084 100644
--- a/monai/losses/__init__.py
+++ b/monai/losses/__init__.py
@@ -12,6 +12,7 @@
 from __future__ import annotations
 
 from .adversarial_loss import PatchAdversarialLoss
+from .barlow_twins import BarlowTwinsLoss
 from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
 from .contrastive import ContrastiveLoss
 from .deform import BendingEnergyLoss, DiffusionLoss
diff --git a/monai/losses/barlow_twins.py b/monai/losses/barlow_twins.py
new file mode 100644
index 0000000000..a61acca66e
--- /dev/null
+++ b/monai/losses/barlow_twins.py
@@ -0,0 +1,92 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+from torch.nn.modules.loss import _Loss
+
+
+class BarlowTwinsLoss(_Loss):
+    """
+    The Barlow Twins cost function takes the representations extracted by a neural network from two
+    distorted views and seeks to make the cross-correlation matrix of the two representations tend
+    towards identity. This encourages the neural network to learn similar representations with the least
+    amount of redundancy. This cost function can be used in particular in multimodal learning to work on
+    representations from two modalities. The most common use case is for unsupervised learning, where data
+    augmentations are used to generate 2 distorted views of the same sample to force the encoder to
+    extract useful features for downstream tasks.
+
+    Zbontar, Jure, et al. "Barlow Twins: Self-Supervised Learning via Redundancy Reduction" International
+    conference on machine learning. PMLR, 2021. (http://proceedings.mlr.press/v139/zbontar21a/zbontar21a.pdf)
+
+    Adapted from:
+        https://github.com/facebookresearch/barlowtwins
+
+    """
+
+    def __init__(self, lambd: float = 5e-3) -> None:
+        """
+        Args:
+            lambd: weight applied to the squared off-diagonal entries of the cross-correlation matrix.
+                Can be any float to handle the informativeness and invariance trade-off. Ideally set to 5e-3.
+        """
+        super().__init__()
+        self.lambd = lambd
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Compute the Barlow Twins loss for a pair of embedding batches.
+
+        Args:
+            input: the shape should be B[F].
+            target: the shape should be B[F].
+
+        Returns:
+            a scalar tensor: the summed squared difference between the cross-correlation matrix of
+            ``input`` and ``target`` and the identity matrix, off-diagonal terms being scaled by ``lambd``.
+
+        Raises:
+            ValueError: When an input of dimension length > 2 is passed
+            ValueError: When input and target are of different shapes
+            ValueError: When batch size is less than or equal to 1
+
+        """
+        if len(target.shape) > 2 or len(input.shape) > 2:
+            raise ValueError(
+                f"Either target or input has dimensions greater than 2 where target "
+                f"shape is ({target.shape}) and input shape is ({input.shape})"
+            )
+
+        if target.shape != input.shape:
+            raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")
+
+        if target.size(0) <= 1:
+            raise ValueError(
+                f"Batch size must be greater than 1 to compute Barlow Twins Loss, but got {target.size(0)}"
+            )
+
+        lambd_tensor = torch.as_tensor(self.lambd).to(input.device)
+        batch_size = input.shape[0]
+
+        # normalize input and target along the batch dimension (1e-6 guards against a zero std)
+        input_norm = (input - input.mean(0)) / input.std(0).add(1e-6)
+        target_norm = (target - target.mean(0)) / target.std(0).add(1e-6)
+
+        # cross-correlation matrix
+        c = torch.mm(input_norm.t(), target_norm) / batch_size  # input_norm.t() is FxB, target_norm is BxF so c is FxF
+
+        # loss
+        identity = torch.eye(c.size(0), device=c.device)  # built once; used as target and as diagonal mask
+        c_diff = (c - identity).pow_(2)  # FxF
+        c_diff[~identity.bool()] *= lambd_tensor  # weight the off-diagonal (redundancy) terms
+
+        return c_diff.sum()
diff --git a/tests/test_barlow_twins_loss.py b/tests/test_barlow_twins_loss.py
new file mode 100644
index 0000000000..81f4032e0c
--- /dev/null
+++ b/tests/test_barlow_twins_loss.py
@@ -0,0 +1,109 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.losses import BarlowTwinsLoss
+
+TEST_CASES = [
+    [  # shape: (2, 4), (2, 4)
+        {"lambd": 5e-3},
+        {
+            "input": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
+            "target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
+        },
+        4.0,
+    ],
+    [  # shape: (2, 4), (2, 4)
+        {"lambd": 5e-3},
+        {
+            "input": torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]),
+            "target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
+        },
+        4.0,
+    ],
+    [  # shape: (2, 4), (2, 4)
+        {"lambd": 5e-3},
+        {
+            "input": torch.tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]]),
+            "target": torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 1.0]]),
+        },
+        5.2562,
+    ],
+    [  # shape: (2, 4), (2, 4)
+        {"lambd": 5e-4},
+        {
+            "input": torch.tensor([[2.0, 3.0, 1.0, 2.0], [0.0, 1.0, 2.0, 5.0]]),
+            "target": torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]),
+        },
+        5.0015,
+    ],
+    [  # shape: (4, 4), (4, 4)
+        {"lambd": 5e-3},
+        {
+            "input": torch.tensor(
+                [[1.0, 2.0, 1.0, 1.0], [3.0, 1.0, 1.0, 2.0], [1.0, 1.0, 1.0, 1.0], [2.0, 1.0, 1.0, 0.0]]
+            ),
+            "target": torch.tensor(
+                [
+                    [0.0, 1.0, -1.0, 0.0],
+                    [1 / 3, 0.0, -2 / 3, 1 / 3],
+                    [-2 / 3, -1.0, 7 / 3, 1 / 3],
+                    [1 / 3, 0.0, 1 / 3, -2 / 3],
+                ]
+            ),
+        },
+        1.4736,
+    ],
+]
+
+
+class TestBarlowTwinsLoss(unittest.TestCase):
+
+    @parameterized.expand(TEST_CASES)
+    def test_result(self, input_param, input_data, expected_val):  # loss value matches the reference
+        barlowtwinsloss = BarlowTwinsLoss(**input_param)
+        result = barlowtwinsloss(**input_data)
+        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
+
+    def test_ill_shape(self):  # tensors with more than 2 dimensions are rejected
+        loss = BarlowTwinsLoss(lambd=5e-3)
+        with self.assertRaises(ValueError):
+            loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+
+    def test_ill_batch_size(self):  # a batch of a single sample is rejected
+        loss = BarlowTwinsLoss(lambd=5e-3)
+        with self.assertRaises(ValueError):
+            loss(torch.ones((1, 2)), torch.ones((1, 2)))
+
+    def test_with_cuda(self):  # runs on GPU when available, otherwise on CPU
+        loss = BarlowTwinsLoss(lambd=5e-3)
+        i = torch.ones((2, 10))
+        j = torch.ones((2, 10))
+        if torch.cuda.is_available():
+            i = i.cuda()
+            j = j.cuda()
+        output = loss(i, j)
+        np.testing.assert_allclose(output.detach().cpu().numpy(), 10.0, atol=1e-4, rtol=1e-4)
+
+    def test_mismatched_shapes(self):  # 2D inputs whose shapes differ are rejected
+        with self.assertRaises(ValueError):
+            BarlowTwinsLoss(lambd=5e-3)(torch.ones((2, 3)), torch.ones((2, 4)))
+
+
+if __name__ == "__main__":
+    unittest.main()