forked from DAMO-DI-ML/KDD2023-DCdetector
-
Notifications
You must be signed in to change notification settings - Fork 0
/
solver.py
346 lines (292 loc) · 16.3 KB
/
solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import time
from utils.utils import *
from model.DCdetector import DCdetector
from data_factory.data_loader import get_loader_segment
from einops import rearrange
from metrics.metrics import *
import warnings
# Suppress ALL Python warnings for the whole process (no category or module
# filter is given). This also hides NumPy/PyTorch runtime and deprecation
# warnings; narrow this filter if those are needed while debugging.
warnings.filterwarnings('ignore')
def my_kl_loss(p, q):
    """Smoothed KL divergence KL(p || q), reduced to one value per (batch, row).

    Both inputs get a 1e-4 offset inside the logs so zero entries do not give
    -inf. The elementwise terms are summed over the last axis and then
    averaged over axis 1.

    Assumes p and q have the same shape with at least 3 dims (e.g. batch,
    heads, rows, cols) -- TODO confirm against DCdetector outputs.
    """
    eps = 0.0001
    log_ratio = torch.log(p + eps) - torch.log(q + eps)
    per_row = (p * log_ratio).sum(dim=-1)
    return per_row.mean(dim=1)
def adjust_learning_rate(optimizer, epoch, lr_, decay=0.5):
    """Set every param group's learning rate to ``lr_ * decay ** (epoch - 1)``.

    Args:
        optimizer: a torch optimizer whose ``param_groups`` are updated in place.
        epoch: 1-based epoch number; epoch 1 uses ``lr_`` unchanged.
        lr_: the base learning rate.
        decay: multiplicative factor applied once per epoch (default 0.5,
            i.e. halve the learning rate every epoch).
    """
    # Floor division keeps the step-wise schedule even if a non-integer
    # epoch is ever passed; for integer epochs it is just ``epoch - 1``.
    lr = lr_ * (decay ** ((epoch - 1) // 1))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
class EarlyStopping:
    """Track two validation losses, checkpoint the best model, and flag a stop.

    Each call negates both losses into scores. A call counts as an improvement
    only when neither score falls below ``best + delta``. Because ``x < nan``
    is always False, a NaN ``val_loss2`` never blocks an improvement, so in
    that case only ``val_loss`` matters.
    - On improvement (and on the very first call), the model's ``state_dict``
      is saved to ``<path>/<dataset_name>_checkpoint.pth`` and the patience
      counter resets.
    - Otherwise the counter grows, and ``early_stop`` is set once it reaches
      ``patience``.

    Args:
        patience: consecutive non-improving calls tolerated before stopping.
        verbose: stored on the instance; not otherwise used here.
        dataset_name: prefix for the checkpoint file name.
        delta: margin added to the best score in the non-improvement test.
    """

    def __init__(self, patience=7, verbose=False, dataset_name='', delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.best_score2 = None
        self.early_stop = False
        # Lower-case ``np.inf``: the ``np.Inf`` alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.val_loss2_min = np.inf
        self.delta = delta
        self.dataset = dataset_name

    def __call__(self, val_loss, val_loss2, model, path):
        """Update the stopping state with this epoch's losses; may save ``model``."""
        score = -val_loss
        score2 = -val_loss2
        if self.best_score is None:
            self.best_score = score
            self.best_score2 = score2
            self.save_checkpoint(val_loss, val_loss2, model, path)
        elif score < self.best_score + self.delta or score2 < self.best_score2 + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_score2 = score2
            self.save_checkpoint(val_loss, val_loss2, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, val_loss2, model, path):
        """Write ``model.state_dict()`` under ``path`` and record the losses."""
        torch.save(model.state_dict(), os.path.join(path, str(self.dataset) + '_checkpoint.pth'))
        self.val_loss_min = val_loss
        self.val_loss2_min = val_loss2
class Solver(object):
    """Train, validate and evaluate a DCdetector anomaly detector.

    ``config`` (merged over ``DEFAULTS``) is copied onto the instance, so every
    key becomes an attribute. Keys read by this class: ``index``,
    ``data_path``, ``dataset``, ``batch_size``, ``win_size``, ``input_c``,
    ``output_c``, ``n_heads``, ``d_model``, ``e_layers``, ``patch_size``,
    ``lr``, ``num_epochs``, ``loss_fuc``, ``model_save_path`` and
    ``anormly_ratio``.
    """

    DEFAULTS = {}

    def __init__(self, config):
        self.__dict__.update(Solver.DEFAULTS, **config)

        self.train_loader = self._make_loader('train')
        self.vali_loader = self._make_loader('val')
        self.test_loader = self._make_loader('test')
        self.thre_loader = self._make_loader('thre')

        self.build_model()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # ``criterion`` is not used by train/vali/test in this class; it is
        # kept in case other callers read it. Any other ``loss_fuc`` value
        # leaves the attribute unset.
        if self.loss_fuc == 'MAE':
            self.criterion = nn.L1Loss()
        elif self.loss_fuc == 'MSE':
            self.criterion = nn.MSELoss()

    def _make_loader(self, mode):
        """Build the segment loader for one split ('train', 'val', 'test' or 'thre')."""
        return get_loader_segment(self.index, 'dataset/' + self.data_path, batch_size=self.batch_size,
                                  win_size=self.win_size, mode=mode, dataset=self.dataset)

    def build_model(self):
        """Create the DCdetector model (on GPU when available) and its Adam optimizer."""
        self.model = DCdetector(win_size=self.win_size, enc_in=self.input_c, c_out=self.output_c, n_heads=self.n_heads,
                                d_model=self.d_model, e_layers=self.e_layers, patch_size=self.patch_size,
                                channel=self.input_c)
        if torch.cuda.is_available():
            self.model.cuda()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def _normalize_prior(self, prior_u):
        """Rescale ``prior_u`` so it sums to 1 along the last axis.

        ``repeat(1, 1, 1, win_size)`` requires a 4-D tensor; presumably
        (batch, heads, win_size, win_size) -- TODO confirm against DCdetector.
        """
        return prior_u / torch.unsqueeze(torch.sum(prior_u, dim=-1), dim=-1).repeat(1, 1, 1, self.win_size)

    def _discrepancy_losses(self, series, prior):
        """Return layer-averaged symmetric-KL (series_loss, prior_loss) scalars.

        ``series_loss`` detaches the normalized prior, so only ``series``
        receives gradient through it. ``prior_loss`` detaches ``series``, so
        only ``prior`` does.
        """
        series_loss = 0.0
        prior_loss = 0.0
        for series_u, prior_u in zip(series, prior):
            prior_norm = self._normalize_prior(prior_u)
            series_loss += (torch.mean(my_kl_loss(series_u, prior_norm.detach()))
                            + torch.mean(my_kl_loss(prior_norm.detach(), series_u)))
            prior_loss += (torch.mean(my_kl_loss(prior_norm, series_u.detach()))
                           + torch.mean(my_kl_loss(series_u.detach(), prior_norm)))
        n_layers = len(prior)
        return series_loss / n_layers, prior_loss / n_layers

    def _anomaly_energy(self, series, prior, temperature):
        """Per-window anomaly scores as a numpy array.

        The temperature-scaled KL terms are summed over layers; the result is
        ``softmax(-series_loss - prior_loss)`` over the last axis.
        """
        series_loss = 0.0
        prior_loss = 0.0
        for series_u, prior_u in zip(series, prior):
            prior_norm = self._normalize_prior(prior_u)
            series_loss = series_loss + my_kl_loss(series_u, prior_norm.detach()) * temperature
            prior_loss = prior_loss + my_kl_loss(prior_norm, series_u.detach()) * temperature
        metric = torch.softmax((-series_loss - prior_loss), dim=-1)
        return metric.detach().cpu().numpy()

    def _collect_energy(self, loader, temperature, collect_labels=False):
        """Run the model over ``loader`` without gradients and return flattened energies.

        When ``collect_labels`` is True, return ``(energy, labels)``, both
        flattened in the same order.
        """
        energies = []
        labels_all = []
        with torch.no_grad():
            for input_data, labels in loader:
                series, prior = self.model(input_data.float().to(self.device))
                energies.append(self._anomaly_energy(series, prior, temperature))
                if collect_labels:
                    labels_all.append(labels)
        energy = np.concatenate(energies, axis=0).reshape(-1)
        if not collect_labels:
            return energy
        return energy, np.concatenate(labels_all, axis=0).reshape(-1)

    @staticmethod
    def _point_adjust(gt, pred):
        """Point-adjust predictions: one hit in a ground-truth segment marks all of it.

        Returns a new array; ``pred`` itself is not modified.
        """
        pred = np.array(pred)
        anomaly_state = False
        for i in range(len(gt)):
            if gt[i] == 1 and pred[i] == 1 and not anomaly_state:
                anomaly_state = True
                # Walk back to the segment start; the stop value -1 makes
                # index 0 reachable.
                for j in range(i, -1, -1):
                    if gt[j] == 0:
                        break
                    pred[j] = 1
                for j in range(i, len(gt)):
                    if gt[j] == 0:
                        break
                    pred[j] = 1
            elif gt[i] == 0:
                anomaly_state = False
            if anomaly_state:
                pred[i] = 1
        return pred

    def vali(self, vali_loader):
        """Return ``(mean(prior_loss - series_loss), nan)`` over ``vali_loader``.

        No second validation loss has ever been produced here, so the second
        value is NaN. EarlyStopping's second comparison is therefore never
        true.
        """
        self.model.eval()
        loss_1 = []
        with torch.no_grad():
            for input_data, _ in vali_loader:
                series, prior = self.model(input_data.float().to(self.device))
                series_loss, prior_loss = self._discrepancy_losses(series, prior)
                loss_1.append((prior_loss - series_loss).item())
        return np.average(loss_1), np.nan

    def train(self):
        """Minimax-train the model, checkpointing via early stopping on the validation split."""
        time_now = time.time()
        path = self.model_save_path
        os.makedirs(path, exist_ok=True)
        early_stopping = EarlyStopping(patience=5, verbose=True, dataset_name=self.data_path)
        train_steps = len(self.train_loader)

        for epoch in range(self.num_epochs):
            iter_count = 0
            epoch_time = time.time()
            self.model.train()
            for i, (input_data, _labels) in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                iter_count += 1
                series, prior = self.model(input_data.float().to(self.device))
                series_loss, prior_loss = self._discrepancy_losses(series, prior)
                loss = prior_loss - series_loss

                if (i + 1) % 100 == 0:
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.num_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                self.optimizer.step()

            # Select checkpoints on the validation split, not the test split,
            # so model selection does not see the evaluation data.
            vali_loss1, vali_loss2 = self.vali(self.vali_loader)
            print(
                "Epoch: {0}, Cost time: {1:.3f}s ".format(
                    epoch + 1, time.time() - epoch_time))
            early_stopping(vali_loss1, vali_loss2, self.model, path)
            if early_stopping.early_stop:
                break
            adjust_learning_rate(self.optimizer, epoch + 1, self.lr)

    def test(self):
        """Load the best checkpoint, pick a threshold, and report detection metrics.

        Returns ``(accuracy, precision, recall, f_score)`` computed after point
        adjustment. Scores are also appended to ``result/<data_path>.csv`` for
        the UCR datasets.
        """
        checkpoint = os.path.join(str(self.model_save_path), str(self.data_path) + '_checkpoint.pth')
        # map_location lets a GPU-trained checkpoint load on a CPU-only host.
        self.model.load_state_dict(torch.load(checkpoint, map_location=self.device))
        self.model.eval()
        temperature = 50

        # (1) statistics on the train set
        train_energy = self._collect_energy(self.train_loader, temperature)

        # (2) find the threshold. One pass over thre_loader provides both the
        # energies for the threshold and the scored test set for step (3).
        test_energy, test_labels = self._collect_energy(self.thre_loader, temperature, collect_labels=True)
        combined_energy = np.concatenate([train_energy, test_energy], axis=0)
        thresh = np.percentile(combined_energy, 100 - self.anormly_ratio)
        print("Threshold :", thresh)

        # (3) evaluation on the test set
        pred = (test_energy > thresh).astype(int)
        gt = test_labels.astype(int)

        matrix = [self.index]
        scores_simple = combine_all_evaluation_scores(pred, gt, test_energy)
        for key, value in scores_simple.items():
            matrix.append(value)
            print('{0:21} : {1:0.4f}'.format(key, value))

        pred = self._point_adjust(gt, pred)

        from sklearn.metrics import precision_recall_fscore_support
        from sklearn.metrics import accuracy_score
        accuracy = accuracy_score(gt, pred)
        precision, recall, f_score, support = precision_recall_fscore_support(gt, pred, average='binary')
        print("Accuracy : {:0.4f}, Precision : {:0.4f}, Recall : {:0.4f}, F-score : {:0.4f} ".format(accuracy, precision, recall, f_score))

        # The original `== 'UCR' or 'UCR_AUG'` was always truthy; only the
        # UCR variants should append to the results CSV.
        if self.data_path in ('UCR', 'UCR_AUG'):
            import csv
            os.makedirs('result', exist_ok=True)
            with open('result/' + self.data_path + '.csv', 'a+', newline='') as f:
                csv.writer(f).writerow(matrix)

        return accuracy, precision, recall, f_score