-
Notifications
You must be signed in to change notification settings - Fork 0
/
domain_trainer.py
86 lines (75 loc) · 3.3 KB
/
domain_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import defaultdict, namedtuple
class DomainTrainer:
    """Trains a domain classifier (``model_d``) on top of features produced by a
    frozen feature extractor (``model_f``), and evaluates domain-prediction
    accuracy.

    Parameters
    ----------
    models : object with attributes ``model_f`` (feature extractor) and
        ``model_d`` (domain classifier head).
    optims : object with attribute ``optim_d`` (optimizer for ``model_d``).
    criterions : object with attribute ``criterion_domain`` (loss for the
        domain-prediction task).
    device : torch device (or device string) all tensors are moved to.
    **kwargs : optional flags —
        ``history`` (bool, default True): record weight-matrix statistics
        every batch during training;
        ``log_interval`` (int, default 100): batches between progress prints;
        ``print_logs`` (bool, default True): enable console logging.
    """

    def __init__(self, models, optims, criterions, device, **kwargs):
        self.models = models
        self.optims = optims
        self.criterions = criterions
        self.device = device
        # Behaviour flags; taken from **kwargs to keep the signature stable.
        self.history = kwargs.get('history', True)
        self.log_interval = kwargs.get('log_interval', 100)
        self.print_logs = kwargs.get('print_logs', True)

    def _train_domain(self, loaders, gr_models, epoch, train_history):
        """Run one training epoch of the domain classifier.

        ``model_f`` is kept in eval mode (frozen feature extractor); only
        ``model_d`` is trained.  When ``gr_models`` is given and history is
        enabled, per-batch statistics of the classifier weight matrices are
        appended to ``train_history``.
        """
        model_d = self.models.model_d.train()
        model_f = self.models.model_f.eval()  # features come from a frozen extractor
        # NOTE(review): training reads loaders.merged_test_loader — presumably
        # the merged loader is the only one carrying domain labels; confirm
        # this is intentional and not a train/test mix-up.
        train_loader = loaders.merged_test_loader
        optimizer = self.optims.optim_d
        criterion_domain = self.criterions.criterion_domain
        if gr_models is not None:
            model_c = gr_models.model_c
            model_gr = gr_models.model_d
        for batch_idx, (data, domains) in enumerate(train_loader):
            # When the dataset also yields class labels, the target is a
            # (labels, domains) pair — keep only the domain ids.
            if train_loader.dataset.get_labels:
                _, domains = domains
            data, domains = data.to(self.device), domains.to(self.device)
            optimizer.zero_grad()
            output = model_d(model_f(data))
            loss = criterion_domain(output, domains)
            loss.backward()
            optimizer.step()
            if self.history and gr_models:
                # Track geometry of the final weight matrices: row norms of
                # model_d, and its dot products with model_c / model_gr rows.
                model_c_mtx = model_c.get_mtx().weight.cpu().detach().numpy()
                model_d_mtx = model_d.get_mtx().weight.cpu().detach().numpy()
                model_gr_mtx = model_gr.get_mtx().weight.cpu().detach().numpy()
                train_history['avg_len'].append(np.mean(np.diag(model_d_mtx.dot(model_d_mtx.T))))
                train_history['avg_dot'].append(np.mean(model_d_mtx.dot(model_c_mtx.T)))
                train_history['avg_dot_gr'].append(np.mean(model_d_mtx.dot(model_gr_mtx.T)))
            if batch_idx % self.log_interval == 0 and self.print_logs:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))

    @staticmethod
    def test_domain_pred(model, device, merged_test_loader, print_logs=True, test_history=None):
        """Evaluate domain-prediction accuracy of ``model`` on the loader.

        Returns the accuracy in percent; when ``test_history`` is given, the
        accuracy is also appended to ``test_history['acc']``.
        """
        model.eval()
        domain_correct = 0
        # Fix: removed dead `domain_test_loss` — it was initialized and divided
        # but never accumulated or reported (always 0).
        with torch.no_grad():
            for data, target in merged_test_loader:
                data = data.to(device)
                if merged_test_loader.dataset.get_labels:
                    _, domains = target  # target is (labels, domains)
                else:
                    domains = target
                domains = domains.to(device)
                domain_out = model(data)
                domain_pred = domain_out.max(1, keepdim=True)[1]  # argmax over class dim
                domain_correct += domain_pred.eq(domains.view_as(domain_pred)).sum().item()
        accuracy = 100. * domain_correct / len(merged_test_loader.dataset)
        if print_logs:
            print('\nDomains predictor: Accuracy: {}/{} ({:.0f}%)\n'.format(
                domain_correct, len(merged_test_loader.dataset), accuracy))
        if test_history is not None:
            test_history['acc'].append(accuracy)
        return accuracy

    def train(self, epochs, loaders, gr_models=None, train_history=None):
        """Train the domain classifier for ``epochs`` epochs, evaluating the
        full ``model_f -> model_d`` pipeline after every epoch.

        Returns the ``train_history`` dict (created here when not supplied),
        containing per-batch weight statistics and per-epoch accuracy.
        """
        self.epochs = epochs
        if train_history is None:
            train_history = defaultdict(list)
        for epoch in range(1, self.epochs + 1):
            self._train_domain(loaders, gr_models, epoch, train_history)
            domain_model = nn.Sequential(self.models.model_f, self.models.model_d)
            self.test_domain_pred(domain_model, self.device,
                                  loaders.merged_test_loader,
                                  print_logs=self.print_logs,
                                  test_history=train_history)
        return train_history