# based on https://github.com/LynnHo/DCGAN-LSGAN-WGAN-GP-DRAGAN-Pytorch
import torch


def get_gan_losses_fn():
    # Standard GAN loss: binary cross-entropy on the discriminator logits.
    bce = torch.nn.BCEWithLogitsLoss()

    def real_loss_fn(r_logit):
        return bce(r_logit, torch.ones_like(r_logit))

    def fake_loss_fn(f_logit):
        return bce(f_logit, torch.zeros_like(f_logit))

    return real_loss_fn, fake_loss_fn


def get_hinge_v1_losses_fn():
    # Hinge loss: penalize real logits below +1 and fake logits above -1.
    def real_loss_fn(r_logit):
        return torch.max(1 - r_logit, torch.zeros_like(r_logit)).mean()

    def fake_loss_fn(f_logit):
        return torch.max(1 + f_logit, torch.zeros_like(f_logit)).mean()

    return real_loss_fn, fake_loss_fn


def get_lsgan_losses_fn():
    # Least-squares GAN loss: regress real logits to 1 and fake logits to 0.
    mse = torch.nn.MSELoss()

    def real_loss_fn(r_logit):
        return mse(r_logit, torch.ones_like(r_logit))

    def fake_loss_fn(f_logit):
        return mse(f_logit, torch.zeros_like(f_logit))

    return real_loss_fn, fake_loss_fn


def get_wgan_losses_fn():
    # Wasserstein critic loss: maximize the score on real samples, minimize it on fakes.
    def real_loss_fn(r_logit):
        return -r_logit.mean()

    def fake_loss_fn(f_logit):
        return f_logit.mean()

    return real_loss_fn, fake_loss_fn


def get_adversarial_losses_fn(mode):
    # Return a (real_loss_fn, fake_loss_fn) pair for the requested loss mode.
    if mode == 'gan':
        return get_gan_losses_fn()
    elif mode == 'hinge_v1':
        return get_hinge_v1_losses_fn()
    elif mode == 'lsgan':
        return get_lsgan_losses_fn()
    elif mode == 'wgan':
        return get_wgan_losses_fn()
    else:
        raise ValueError('Unknown adversarial loss mode: %s' % mode)
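

# Usage note (a hedged sketch, not part of the original training code): each factory
# above returns a (real_loss_fn, fake_loss_fn) pair, so a discriminator step would
# typically look like
#
#     real_loss_fn, fake_loss_fn = get_adversarial_losses_fn('hinge_v1')
#     d_loss = real_loss_fn(D(real_batch)) + fake_loss_fn(D(G(z).detach()))
#
# where D, G, real_batch, and z are hypothetical names for the discriminator,
# generator, a data batch, and latent noise.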


# gp ===================================================================================

# ======================================
# =            sample method           =
# ======================================

def _sample_line(real, fake):
    # Sample uniformly along the line segment between paired real and fake samples.
    shape = [real.size(0)] + [1] * (real.dim() - 1)
    alpha = torch.rand(shape, device=real.device)
    sample = real + alpha * (fake - real)
    return sample


def _sample_DRAGAN(real, fake):  # fake is useless
    # DRAGAN: interpolate between real samples and noisy perturbations of them.
    beta = torch.rand_like(real)
    fake = real + 0.5 * real.std() * beta
    sample = _sample_line(real, fake)
    return sample


# ======================================
# =       gradient penalty method      =
# ======================================

def _norm(x):
    # Per-sample L2 norm over all non-batch dimensions.
    norm = x.view(x.size(0), -1).norm(p=2, dim=1)
    return norm


def _one_mean_gp(grad):
    # WGAN-GP style penalty: drive the gradient norm towards 1.
    norm = _norm(grad)
    gp = ((norm - 1)**2).mean()
    return gp


def _zero_mean_gp(grad):
    # Zero-centered penalty: drive the gradient norm towards 0.
    norm = _norm(grad)
    gp = (norm**2).mean()
    return gp


def _lipschitz_penalty(grad):
    # One-sided (Lipschitz) penalty: only penalize gradient norms above 1.
    norm = _norm(grad)
    gp = (torch.max(torch.zeros_like(norm), norm - 1)**2).mean()
    return gp


def gradient_penalty(f, real, fake, gp_mode, sample_mode):
    # Evaluate discriminator/critic f at points chosen by sample_mode and
    # penalize its input gradients according to gp_mode.
    sample_fns = {
        'line': _sample_line,
        'real': lambda real, fake: real,
        'fake': lambda real, fake: fake,
        'dragan': _sample_DRAGAN,
    }

    gp_fns = {
        '1-gp': _one_mean_gp,
        '0-gp': _zero_mean_gp,
        'lp': _lipschitz_penalty,
    }

    if gp_mode == 'none':
        gp = torch.tensor(0, dtype=real.dtype, device=real.device)
    else:
        x = sample_fns[sample_mode](real, fake).detach()
        x.requires_grad = True
        pred = f(x)
        grad = torch.autograd.grad(pred, x, grad_outputs=torch.ones_like(pred), create_graph=True)[0]
        gp = gp_fns[gp_mode](grad)

    return gp
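

# Hedged usage sketch (illustration only, not part of the original module): a tiny smoke
# test that exercises the loss factory and the gradient penalty with a throwaway linear
# "discriminator". All names below (disc, real, fake, ...) are hypothetical placeholders
# chosen for this example.
if __name__ == '__main__':
    torch.manual_seed(0)

    disc = torch.nn.Linear(8, 1)   # stand-in discriminator/critic
    real = torch.randn(4, 8)       # placeholder "real" batch
    fake = torch.randn(4, 8)       # placeholder "generated" batch

    real_loss_fn, fake_loss_fn = get_adversarial_losses_fn('wgan')
    d_loss = real_loss_fn(disc(real)) + fake_loss_fn(disc(fake))

    gp = gradient_penalty(disc, real, fake, gp_mode='1-gp', sample_mode='line')
    print('d_loss:', d_loss.item(), 'gp:', gp.item())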