loss.py
import numpy as np
import torch

import train_func as tf
import utils

from itertools import combinations

class MaximalCodingRateReduction(torch.nn.Module):

    def __init__(self, gam1=1.0, gam2=1.0, eps=0.01):
        super(MaximalCodingRateReduction, self).__init__()
        self.gam1 = gam1  # scales the empirical discriminative (expansion) term
        self.gam2 = gam2  # weights the discriminative term in the total loss
        self.eps = eps    # distortion (precision) parameter of the coding rate
    def compute_discrimn_loss_empirical(self, W):
        """Empirical Discriminative Loss."""
        p, m = W.shape
        I = torch.eye(p).cuda()
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + self.gam1 * scalar * W.matmul(W.T))
        return logdet / 2.
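    # In code terms, this is the coding rate of all m samples taken as one
    # ensemble:
    #     R(W) = 1/2 * logdet(I + gam1 * p / (m * eps) * W @ W.T)
    # Larger values mean the features span a larger volume, so the total loss
    # in forward() rewards expanding this term.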
    def compute_compress_loss_empirical(self, W, Pi):
        """Empirical Compressive Loss."""
        p, m = W.shape
        k, _, _ = Pi.shape
        I = torch.eye(p).cuda()
        compress_loss = 0.
        for j in range(k):
            trPi = torch.trace(Pi[j]) + 1e-8
            scalar = p / (trPi * self.eps)
            log_det = torch.logdet(I + scalar * W.matmul(Pi[j]).matmul(W.T))
            compress_loss += log_det * trPi / m
        return compress_loss / 2.
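    # The compressive term is a sum of per-class coding rates:
    #     Rc(W, Pi) = sum_j tr(Pi_j) / (2m) * logdet(I + p / (tr(Pi_j) * eps) * W @ Pi_j @ W.T)
    # where Pi_j encodes which samples belong to class j; the 1e-8 added to
    # tr(Pi_j) guards against classes that are empty in the batch.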
    def compute_discrimn_loss_theoretical(self, W):
        """Theoretical Discriminative Loss."""
        p, m = W.shape
        I = torch.eye(p).cuda()
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + scalar * W.matmul(W.T))
        return logdet / 2.
    def compute_compress_loss_theoretical(self, W, Pi):
        """Theoretical Compressive Loss."""
        p, m = W.shape
        k, _, _ = Pi.shape
        I = torch.eye(p).cuda()
        compress_loss = 0.
        for j in range(k):
            trPi = torch.trace(Pi[j]) + 1e-8
            scalar = p / (trPi * self.eps)
            log_det = torch.logdet(I + scalar * W.matmul(Pi[j]).matmul(W.T))
            compress_loss += trPi / (2 * m) * log_det
        return compress_loss
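    # Note: the theoretical variants compute the same quantities with the gam1
    # scaling dropped from the discriminative term; the two compressive
    # computations are algebraically identical. forward() returns the
    # theoretical pair for logging only, so gradients flow through the
    # empirical terms alone.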
    def forward(self, X, Y, num_classes=None):
        if num_classes is None:
            num_classes = Y.max() + 1
        W = X.T  # X is an (m, p) batch of features; the losses expect (p, m)
        Pi = tf.label_to_membership(Y.numpy(), num_classes)  # (k, m, m) membership matrices
        Pi = torch.tensor(Pi, dtype=torch.float32).cuda()

        discrimn_loss_empi = self.compute_discrimn_loss_empirical(W)
        compress_loss_empi = self.compute_compress_loss_empirical(W, Pi)
        discrimn_loss_theo = self.compute_discrimn_loss_theoretical(W)
        compress_loss_theo = self.compute_compress_loss_theoretical(W, Pi)

        # Minimizing this total maximizes the rate reduction R - Rc.
        total_loss_empi = self.gam2 * -discrimn_loss_empi + compress_loss_empi
        return (total_loss_empi,
                [discrimn_loss_empi.item(), compress_loss_empi.item()],
                [discrimn_loss_theo.item(), compress_loss_theo.item()])
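
# ---------------------------------------------------------------------------
# Usage sketch: a minimal, hypothetical driver for this loss. The batch size,
# feature dimension, class count, and eps below are illustrative values, not
# settings from this repo, and a CUDA device is assumed because the identity
# matrices above are created with .cuda(). Features are L2-normalized first,
# since the coding rate is only meaningful on bounded features.
if __name__ == "__main__":
    import torch.nn.functional as F

    m, p, k = 256, 128, 10  # samples, feature dimension, classes (made up)
    X = F.normalize(torch.randn(m, p)).cuda().requires_grad_()
    Y = torch.randint(0, k, (m,))  # integer labels, kept on CPU for .numpy()

    criterion = MaximalCodingRateReduction(gam1=1.0, gam2=1.0, eps=0.5)
    loss, (R_empi, Rc_empi), (R_theo, Rc_theo) = criterion(X, Y, num_classes=k)
    loss.backward()  # in training, X would come from a network being optimized
    print(loss.item(), R_empi, Rc_empi, R_theo, Rc_theo)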