-
Notifications
You must be signed in to change notification settings - Fork 60
/
usad.py
117 lines (99 loc) · 4.05 KB
/
usad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import torch.nn as nn
from utils import *
device = get_default_device()
class Encoder(nn.Module):
def __init__(self, in_size, latent_size):
super().__init__()
self.linear1 = nn.Linear(in_size, int(in_size/2))
self.linear2 = nn.Linear(int(in_size/2), int(in_size/4))
self.linear3 = nn.Linear(int(in_size/4), latent_size)
self.relu = nn.ReLU(True)
def forward(self, w):
out = self.linear1(w)
out = self.relu(out)
out = self.linear2(out)
out = self.relu(out)
out = self.linear3(out)
z = self.relu(out)
return z
class Decoder(nn.Module):
def __init__(self, latent_size, out_size):
super().__init__()
self.linear1 = nn.Linear(latent_size, int(out_size/4))
self.linear2 = nn.Linear(int(out_size/4), int(out_size/2))
self.linear3 = nn.Linear(int(out_size/2), out_size)
self.relu = nn.ReLU(True)
self.sigmoid = nn.Sigmoid()
def forward(self, z):
out = self.linear1(z)
out = self.relu(out)
out = self.linear2(out)
out = self.relu(out)
out = self.linear3(out)
w = self.sigmoid(out)
return w
class UsadModel(nn.Module):
def __init__(self, w_size, z_size):
super().__init__()
self.encoder = Encoder(w_size, z_size)
self.decoder1 = Decoder(z_size, w_size)
self.decoder2 = Decoder(z_size, w_size)
def training_step(self, batch, n):
z = self.encoder(batch)
w1 = self.decoder1(z)
w2 = self.decoder2(z)
w3 = self.decoder2(self.encoder(w1))
loss1 = 1/n*torch.mean((batch-w1)**2)+(1-1/n)*torch.mean((batch-w3)**2)
loss2 = 1/n*torch.mean((batch-w2)**2)-(1-1/n)*torch.mean((batch-w3)**2)
return loss1,loss2
def validation_step(self, batch, n):
with torch.no_grad():
z = self.encoder(batch)
w1 = self.decoder1(z)
w2 = self.decoder2(z)
w3 = self.decoder2(self.encoder(w1))
loss1 = 1/n*torch.mean((batch-w1)**2)+(1-1/n)*torch.mean((batch-w3)**2)
loss2 = 1/n*torch.mean((batch-w2)**2)-(1-1/n)*torch.mean((batch-w3)**2)
return {'val_loss1': loss1, 'val_loss2': loss2}
def validation_epoch_end(self, outputs):
batch_losses1 = [x['val_loss1'] for x in outputs]
epoch_loss1 = torch.stack(batch_losses1).mean()
batch_losses2 = [x['val_loss2'] for x in outputs]
epoch_loss2 = torch.stack(batch_losses2).mean()
return {'val_loss1': epoch_loss1.item(), 'val_loss2': epoch_loss2.item()}
def epoch_end(self, epoch, result):
print("Epoch [{}], val_loss1: {:.4f}, val_loss2: {:.4f}".format(epoch, result['val_loss1'], result['val_loss2']))
def evaluate(model, val_loader, n):
outputs = [model.validation_step(to_device(batch,device), n) for [batch] in val_loader]
return model.validation_epoch_end(outputs)
def training(epochs, model, train_loader, val_loader, opt_func=torch.optim.Adam):
history = []
optimizer1 = opt_func(list(model.encoder.parameters())+list(model.decoder1.parameters()))
optimizer2 = opt_func(list(model.encoder.parameters())+list(model.decoder2.parameters()))
for epoch in range(epochs):
for [batch] in train_loader:
batch=to_device(batch,device)
#Train AE1
loss1,loss2 = model.training_step(batch,epoch+1)
loss1.backward()
optimizer1.step()
optimizer1.zero_grad()
#Train AE2
loss1,loss2 = model.training_step(batch,epoch+1)
loss2.backward()
optimizer2.step()
optimizer2.zero_grad()
result = evaluate(model, val_loader, epoch+1)
model.epoch_end(epoch, result)
history.append(result)
return history
def testing(model, test_loader, alpha=.5, beta=.5):
results=[]
with torch.no_grad():
for [batch] in test_loader:
batch=to_device(batch,device)
w1=model.decoder1(model.encoder(batch))
w2=model.decoder2(model.encoder(w1))
results.append(alpha*torch.mean((batch-w1)**2,axis=1)+beta*torch.mean((batch-w2)**2,axis=1))
return results