"""
This is taken from the IIC paper.
"""
import sys

import torch
from termcolor import colored
from torch import Tensor, nn

__all__ = ["IIDLoss"]


def simplex(t: torch.Tensor, axis: int = 1) -> bool:
    """
    Check whether each slice of ``t`` along ``axis`` is a probability
    distribution, i.e. sums to 1 (non-negativity is not checked here).

    :param t: input tensor
    :param axis: axis along which the entries should sum to 1
    :return: True if every slice sums to 1 within tolerance
    """
    _sum = t.sum(axis).type(torch.float32)
    _ones = torch.ones_like(_sum, dtype=torch.float32)
    return torch.allclose(_sum, _ones, rtol=1e-4, atol=1e-4)
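
# Quick sanity check (illustrative comment, not in the original file):
# rows produced by a softmax satisfy the simplex constraint, raw logits do not.
#   simplex(torch.softmax(torch.randn(4, 10), dim=1))  # True
#   simplex(torch.randn(4, 10))                        # almost surely False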


class IIDLoss(nn.Module):
    def __init__(self, lamb: float = 1.0, eps: float = sys.float_info.epsilon):
        """
        :param lamb: weight on the marginal-entropy terms (lamb = 1.0 recovers
            plain mutual information; larger values push the cluster marginals
            toward uniformity)
        :param eps: small constant added inside the logs for numerical stability
        """
        super().__init__()
        print(colored(f"Initialize {self.__class__.__name__}.", "green"))
        self.lamb = float(lamb)
        self.eps = float(eps)
        self.torch_version = torch.__version__
    def forward(self, x_out: Tensor, x_tf_out: Tensor):
        """
        Return the negative of the mutual information between the cluster
        assignments of the two views (so minimizing this loss maximizes
        I(z, z')); if x_out == x_tf_out, it returns the negative entropy.

        With joint p_ij and marginals p_i, p_j, the loss is
        -sum_ij p_ij * (log p_ij - lamb * log p_i - lamb * log p_j).

        :param x_out: softmax predictions for the original batch, shape (bn, k)
        :param x_tf_out: softmax predictions for the transformed batch, shape (bn, k)
        :return: (loss, joint distribution p_i_j)
        """
        assert simplex(x_out), "x_out is not normalized."
        assert simplex(x_tf_out), "x_tf_out is not normalized."
        _, k = x_out.size()
        p_i_j = compute_joint(x_out, x_tf_out)
        assert p_i_j.size() == (k, k)
        # Marginals, broadcast to (k, k). Because compute_joint symmetrizes
        # the joint, the two marginals coincide (p_i == p_j.t()).
        p_i = p_i_j.sum(dim=1).view(k, 1).expand(k, k)
        p_j = p_i_j.sum(dim=0).view(1, k).expand(k, k)
        # Equivalent alternative, straight from the batch means:
        # p_i = x_out.mean(0).view(k, 1).expand(k, k)
        # p_j = x_tf_out.mean(0).view(1, k).expand(k, k)

        # eps avoids NaN losses; its effect is cancelled out where p_i_j is tiny anyway.
        loss = -p_i_j * (
            torch.log(p_i_j + self.eps)
            - self.lamb * torch.log(p_j + self.eps)
            - self.lamb * torch.log(p_i + self.eps)
        )
        loss = loss.sum()
        return loss, p_i_j
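
# Typical usage (a sketch; variable names are illustrative): with softmax
# outputs x_out and x_tf_out of shape (batch_size, num_clusters) from two
# views of the same batch, `loss, p_i_j = IIDLoss()(x_out, x_tf_out)` gives
# a scalar to minimize and the (k, k) joint assignment table.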


def compute_joint(x_out: Tensor, x_tf_out: Tensor, symmetric: bool = True) -> Tensor:
    r"""
    Return the joint probability table over cluster assignments.

    :param x_out: p1, rows on the simplex, shape (bn, k)
    :param x_tf_out: p2, rows on the simplex, shape (bn, k)
    :param symmetric: if True, symmetrize the joint as (P + P^T) / 2
    :return: joint probability, shape (k, k)
    """
    # Produces a tensor that requires grad (since the arguments require grad).
    assert simplex(x_out), "x_out is not normalized."
    assert simplex(x_tf_out), "x_tf_out is not normalized."
    bn, k = x_out.shape
    assert x_tf_out.size(0) == bn and x_tf_out.size(1) == k
    p_i_j = x_out.unsqueeze(2) * x_tf_out.unsqueeze(1)  # (bn, k, k) outer products
    p_i_j = p_i_j.sum(dim=0)  # (k, k), aggregated over the batch
    if symmetric:
        p_i_j = (p_i_j + p_i_j.t()) / 2.0  # symmetrize
    p_i_j /= p_i_j.sum()  # normalize
    return p_i_j
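

# Minimal smoke test (a sketch appended for illustration, not part of the
# original file): run two softmax-normalized views through the loss and
# verify the joint is a valid, symmetric probability table.
if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(16, 5)
    x_out = torch.softmax(logits, dim=1)
    # A perturbed view standing in for an augmented batch (assumption for the demo).
    x_tf_out = torch.softmax(logits + 0.1 * torch.randn(16, 5), dim=1)

    criterion = IIDLoss(lamb=1.0)
    loss, p_i_j = criterion(x_out, x_tf_out)
    print(f"IID loss: {loss.item():.4f}")
    assert torch.allclose(p_i_j, p_i_j.t()), "joint should be symmetric"
    assert torch.isclose(p_i_j.sum(), torch.tensor(1.0)), "joint should sum to 1"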