-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Barlow Twins loss for representation learning
- Loading branch information
1 parent
95f69de
commit 1371b0d
Showing
4 changed files
with
198 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from warnings import warn | ||
|
||
import torch | ||
from torch.nn.modules.loss import _Loss | ||
|
||
|
||
class BarlowTwinsLoss(_Loss):
    """
    Compute the Barlow Twins loss defined in:

        Zbontar, Jure, et al. "Barlow Twins: Self-Supervised Learning via Redundancy Reduction" International
        conference on machine learning. PMLR, 2021. (http://proceedings.mlr.press/v139/zbontar21a/zbontar21a.pdf)

    Adapted from:
        https://github.com/facebookresearch/barlowtwins
    """

    def __init__(self, lambd: float = 5e-3, batch_size: int = -1) -> None:
        """
        Args:
            lambd: Can be any float to handle the informativeness and invariance trade-off. It is the weight
                applied to the off-diagonal terms of the cross-correlation matrix. Ideally set to 5e-3.
            batch_size: no longer used. The batch size is read from the tensors passed to ``forward``; passing
                any value other than -1 only emits a warning. Kept for backward compatibility.
        """
        super().__init__()
        self.lambd = lambd

        if batch_size != -1:
            warn("batch_size is no longer required to be set. It will be estimated dynamically in the forward call")

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be (B, F), i.e. batch size by number of features.
            target: the shape should be (B, F), identical to ``input``.

        Returns:
            A scalar tensor: the sum over the FxF matrix of squared differences between the cross-correlation
            matrix and the identity, with the off-diagonal entries scaled by ``lambd``.

        Raises:
            ValueError: When an input of dimension length > 2 is passed
            ValueError: When input and target are of different shapes
            ValueError: When input and target are not 2-dimensional
            ValueError: When batch size is less than or equal to 1
        """
        if len(target.shape) > 2 or len(input.shape) > 2:
            raise ValueError(
                f"Either target or input has dimensions greater than 2 where target "
                f"shape is ({target.shape}) and input shape is ({input.shape})"
            )

        if target.shape != input.shape:
            raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

        # torch.mm below only accepts matrices, so reject 0-D/1-D tensors here with a clear error
        # instead of letting them fail later with an opaque IndexError/RuntimeError.
        if input.dim() != 2:
            raise ValueError(
                f"input and target must be 2-dimensional with shape (B, F), but got shape ({input.shape})"
            )

        if target.size(0) <= 1:
            raise ValueError(
                f"Batch size must be greater than 1 to compute Barlow Twins Loss, but got {target.size(0)}"
            )

        lambd_tensor = torch.as_tensor(self.lambd).to(input.device)
        batch_size = input.shape[0]

        # normalize input and target along the batch dimension; the small constant added to the
        # standard deviation keeps the division finite for features that are constant over the batch
        input_norm = (input - input.mean(0)) / input.std(0).add(1e-6)
        target_norm = (target - target.mean(0)) / target.std(0).add(1e-6)

        # cross-correlation matrix
        c = torch.mm(input_norm.t(), target_norm) / batch_size  # input_norm.t() is FxB, target_norm is BxF so c is FxF

        # loss: build the identity once and reuse it both as the target and as the diagonal mask
        identity = torch.eye(c.size(0), device=c.device)
        c_diff = (c - identity).pow_(2)  # FxF
        c_diff[~identity.bool()] *= lambd_tensor  # only the off-diagonal terms are down-weighted

        return c_diff.sum()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import unittest | ||
|
||
import numpy as np | ||
import torch | ||
from parameterized import parameterized | ||
|
||
from monai.losses import BarlowTwinsLoss | ||
|
||
def _make_case(lambd, input_rows, target_rows, expected):
    """Assemble one ``[constructor kwargs, forward kwargs, expected loss]`` entry."""
    return [
        {"lambd": lambd},
        {"input": torch.tensor(input_rows), "target": torch.tensor(target_rows)},
        expected,
    ]


# Each entry is [kwargs for BarlowTwinsLoss, kwargs for the forward call, expected loss value].
TEST_CASES = [
    # shape: (2, 4), (2, 4) -- identical input and target
    _make_case(
        5e-3,
        [[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]],
        [[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]],
        4.0,
    ),
    # shape: (2, 4), (2, 4)
    _make_case(
        5e-3,
        [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]],
        [[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]],
        4.0,
    ),
    # shape: (2, 4), (2, 4)
    _make_case(
        5e-3,
        [[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]],
        [[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 1.0]],
        5.2562,
    ),
    # shape: (2, 4), (2, 4) -- smaller lambd
    _make_case(
        5e-4,
        [[2.0, 3.0, 1.0, 2.0], [0.0, 1.0, 2.0, 5.0]],
        [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
        5.0015,
    ),
    # shape: (4, 4), (4, 4)
    _make_case(
        5e-3,
        [
            [1.0, 2.0, 1.0, 1.0],
            [3.0, 1.0, 1.0, 2.0],
            [1.0, 1.0, 1.0, 1.0],
            [2.0, 1.0, 1.0, 0.0],
        ],
        [
            [0.0, 1.0, -1.0, 0.0],
            [1 / 3, 0.0, -2 / 3, 1 / 3],
            [-2 / 3, -1.0, 7 / 3, 1 / 3],
            [1 / 3, 0.0, 1 / 3, -2 / 3],
        ],
        1.4736,
    ),
]
|
||
|
||
class TestBarlowTwinsLoss(unittest.TestCase):
    """Unit tests for ``BarlowTwinsLoss``: reference values, input validation and the batch_size warning."""

    @parameterized.expand(TEST_CASES)
    def test_result(self, input_param, input_data, expected_val):
        """The loss matches the precomputed reference value for each parameterized case."""
        barlowtwinsloss = BarlowTwinsLoss(**input_param)
        result = barlowtwinsloss(**input_data)
        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)

    def test_ill_shape(self):
        """Tensors with more than two dimensions are rejected."""
        loss = BarlowTwinsLoss(lambd=5e-3)
        # match on the actual message: an empty pattern would accept any ValueError
        with self.assertRaisesRegex(ValueError, "dimensions greater than 2"):
            loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

    def test_ill_batch_size(self):
        """A batch containing a single sample is rejected."""
        loss = BarlowTwinsLoss(lambd=5e-3, batch_size=1)
        with self.assertRaisesRegex(ValueError, "Batch size must be greater than 1"):
            loss(torch.ones((1, 2)), torch.ones((1, 2)))

    def test_with_cuda(self):
        """The loss runs on the GPU when one is available (and on the CPU otherwise)."""
        loss = BarlowTwinsLoss(lambd=5e-3)
        i = torch.ones((2, 10))
        j = torch.ones((2, 10))
        if torch.cuda.is_available():
            i = i.cuda()
            j = j.cuda()
        output = loss(i, j)
        np.testing.assert_allclose(output.detach().cpu().numpy(), 10.0, atol=1e-4, rtol=1e-4)

    def test_warning_raised(self):
        """Passing ``batch_size`` to the constructor emits a warning.

        The method name must start with ``test`` for unittest to collect it.
        """
        with self.assertWarns(Warning):
            BarlowTwinsLoss(lambd=5e-3, batch_size=1)
|
||
|
||
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()