Skip to content

Commit

Permalink
Add Barlow Twins loss for representation learning
Browse files Browse the repository at this point in the history
  • Loading branch information
Lucas-rbnt committed Mar 11, 2024
1 parent 95f69de commit 1371b0d
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ Segmentation Losses
.. autoclass:: ContrastiveLoss
:members:

`BarlowTwinsLoss`
~~~~~~~~~~~~~~~~~
.. autoclass:: BarlowTwinsLoss
:members:

`HausdorffDTLoss`
~~~~~~~~~~~~~~~~~
.. autoclass:: HausdorffDTLoss
Expand Down
1 change: 1 addition & 0 deletions monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

from .adversarial_loss import PatchAdversarialLoss
from .barlow_twins import BarlowTwinsLoss
from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
from .contrastive import ContrastiveLoss
from .deform import BendingEnergyLoss, DiffusionLoss
Expand Down
83 changes: 83 additions & 0 deletions monai/losses/barlow_twins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from warnings import warn

import torch
from torch.nn.modules.loss import _Loss


class BarlowTwinsLoss(_Loss):
"""
Compute the Barlow Twins loss defined in:
Zbontar, Jure, et al. "Barlow Twins: Self-Supervised Learning via Redundancy Reduction" International
conference on machine learning. PMLR, 2020. (http://proceedings.mlr.press/v139/zbontar21a/zbontar21a.pdf)
Adapted from:
https://github.com/facebookresearch/barlowtwins
"""

def __init__(self, lambd: float = 5e-3, batch_size: int = -1) -> None:
"""
Args:
lamb: Can be any float to handle the informativeness and invariance trade-off. Ideally set to 5e-3.
Raises:
ValueError: When an input of dimension length > 2 is passed
ValueError: When input and target are of different shapes
ValueError: When batch size is less than or equal to 1
"""
super().__init__()
self.lambd = lambd

if batch_size != -1:
warn("batch_size is no longer required to be set. It will be estimated dynamically in the forward call")

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
input: the shape should be B[F].
target: the shape should be B[F].
"""
if len(target.shape) > 2 or len(input.shape) > 2:
raise ValueError(
f"Either target or input has dimensions greater than 2 where target "
f"shape is ({target.shape}) and input shape is ({input.shape})"
)

if target.shape != input.shape:
raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

if target.size(0) <= 1:
raise ValueError(
f"Batch size must be greater than 1 to compute Barlow Twins Loss, but got {target.size(0)}"
)

lambd_tensor = torch.as_tensor(self.lambd).to(input.device)
batch_size = input.shape[0]

# normalize input and target
input_norm = (input - input.mean(0)) / input.std(0).add(1e-6)
target_norm = (target - target.mean(0)) / target.std(0).add(1e-6)

# cross-correlation matrix
c = torch.mm(input_norm.t(), target_norm) / batch_size # input_norm.t() is FxB, target_norm is BxF so c is FxF

# loss
c_diff = (c - torch.eye(c.size(0), device=c.device)).pow_(2) # FxF
c_diff[~torch.eye(c.size(0), device=c.device).bool()] *= lambd_tensor

return c_diff.sum()
109 changes: 109 additions & 0 deletions tests/test_barlow_twins_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses import BarlowTwinsLoss

TEST_CASES = [
[ # shape: (2, 4), (2, 4)
{"lambd": 5e-3},
{
"input": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
"target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
},
4.0,
],
[ # shape: (2, 4), (2, 4)
{"lambd": 5e-3},
{
"input": torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]),
"target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
},
4.0,
],
[ # shape: (2, 4), (2, 4)
{"lambd": 5e-3},
{
"input": torch.tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]]),
"target": torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 1.0]]),
},
5.2562,
],
[ # shape: (2, 4), (2, 4)
{"lambd": 5e-4},
{
"input": torch.tensor([[2.0, 3.0, 1.0, 2.0], [0.0, 1.0, 2.0, 5.0]]),
"target": torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]),
},
5.0015,
],
[ # shape: (4, 4), (4, 4)
{"lambd": 5e-3},
{
"input": torch.tensor(
[[1.0, 2.0, 1.0, 1.0], [3.0, 1.0, 1.0, 2.0], [1.0, 1.0, 1.0, 1.0], [2.0, 1.0, 1.0, 0.0]]
),
"target": torch.tensor(
[
[0.0, 1.0, -1.0, 0.0],
[1 / 3, 0.0, -2 / 3, 1 / 3],
[-2 / 3, -1.0, 7 / 3, 1 / 3],
[1 / 3, 0.0, 1 / 3, -2 / 3],
]
),
},
1.4736,
],
]


class TestBarlowTwinsLoss(unittest.TestCase):

@parameterized.expand(TEST_CASES)
def test_result(self, input_param, input_data, expected_val):
barlowtwinsloss = BarlowTwinsLoss(**input_param)
result = barlowtwinsloss(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)

def test_ill_shape(self):
loss = BarlowTwinsLoss(lambd=5e-3)
with self.assertRaisesRegex(ValueError, ""):
loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

def test_ill_batch_size(self):
loss = BarlowTwinsLoss(lambd=5e-3, batch_size=1)
with self.assertRaisesRegex(ValueError, ""):
loss(torch.ones((1, 2)), torch.ones((1, 2)))

def test_with_cuda(self):
loss = BarlowTwinsLoss(lambd=5e-3)
i = torch.ones((2, 10))
j = torch.ones((2, 10))
if torch.cuda.is_available():
i = i.cuda()
j = j.cuda()
output = loss(i, j)
np.testing.assert_allclose(output.detach().cpu().numpy(), 10.0, atol=1e-4, rtol=1e-4)

def check_warning_rasied(self):
with self.assertWarns(Warning):
BarlowTwinsLoss(lambd=5e-3, batch_size=1)


if __name__ == "__main__":
unittest.main()

0 comments on commit 1371b0d

Please sign in to comment.