# trainer.py
import time
import json
import random
import torch
import torch.nn as nn
import torch.optim as optim
from model.net import BERT
from model.optim import ScheduledAdam
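
# Fix random seeds for reproducibility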
random.seed(32)
torch.manual_seed(32)
torch.backends.cudnn.deterministic = True


class Trainer:
    def __init__(self, params, train_iter=None, valid_iter=None, test_iter=None):
        self.params = params

        if params.mode == 'train':
            self.train_iter = train_iter
            self.valid_iter = valid_iter
        else:
            self.test_iter = test_iter

        self.model = BERT(self.params)
        self.model.to(params.device)
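
        # ScheduledAdam (model/optim.py) presumably applies the Transformer warmup
        # schedule, lr = hidden_dim**-0.5 * min(step**-0.5, step * warm_steps**-1.5);
        # an assumption based on its arguments rather than a verified signature.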
        self.optimizer = ScheduledAdam(
            optimizer=optim.Adam(self.model.parameters(), betas=(0.9, 0.999)),
            hidden_dim=params.hidden_dim,
            warm_steps=params.warm_steps
        )

        # Masked LM loss; index 0 (presumably the padding token) is ignored
        self.lm_criterion = nn.CrossEntropyLoss(ignore_index=0)
        self.lm_criterion.to(params.device)

        # Next Sentence Prediction loss
        self.cls_criterion = nn.CrossEntropyLoss()
        self.cls_criterion.to(params.device)
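
    # A minimal usage sketch (hypothetical): `params` must expose mode, device,
    # hidden_dim, warm_steps and num_epoch, and train_iter is a single pre-built
    # batch tuple (see train() below).
    #
    #   trainer = Trainer(params, train_iter=batch)
    #   trainer.train()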

    def train(self):
        # print(self.model)
        print(f'The model has {self.model.count_params():,} trainable parameters')
        best_valid_loss = float('inf')

        # For presentation: map token indices back to words
        with open('vocab.json') as f_vocab:
            vocab = json.load(f_vocab)
        idx_vocab = {i: w for i, w in enumerate(vocab)}

        for epoch in range(self.params.num_epoch):
            self.model.train()
            epoch_loss = 0
            start_time = time.time()

            # train_iter holds a single pre-built batch rather than an iterator
            # for batch in self.train_iter:  # batchify
            input_ids, segment_ids, masked_tokens, masked_pos, num_masked, isNext = self.train_iter
            self.optimizer.zero_grad()

            lm_logits, cls_logits = self.model(input_ids, segment_ids, masked_pos)

            # Masked LM loss: CrossEntropyLoss expects [batch size, vocab size, max pred],
            # so lm_logits is transposed from [batch size, max pred, vocab size];
            # masked_tokens = [batch size, max pred]
            loss_lm = self.lm_criterion(lm_logits.transpose(1, 2), masked_tokens)
            # Weight the (already-averaged) LM loss by the per-example masked-token counts
            loss_lm = (loss_lm * num_masked.float()).mean()

            # Next Sentence Prediction loss
            loss_cls = self.cls_criterion(cls_logits, isNext)
            loss = loss_lm + loss_cls

            # For presentation: decode one random example from the batch
            rand_idx = random.randrange(len(input_ids))
            masked_sent = [idx_vocab[idx] for idx in input_ids[rand_idx].cpu().numpy()]
            mask_toks = [idx_vocab[idx] for idx in masked_tokens[rand_idx].cpu().numpy()]
            pred_toks = [idx_vocab[idx] for idx in torch.argmax(lm_logits[rand_idx], 1).cpu().numpy()]

            print(f'Masked sentence: \n{" ".join(masked_sent)}')
            print(f'Masked tokens: {mask_toks}')
            print(f'Masked positions: {masked_pos[rand_idx].cpu().numpy()}')
            print(f'Is next?: {bool(isNext[rand_idx])}')
            print('---------------------------------------')
            print(f'[Predict] Masked tokens: {pred_toks}')
            print(f'[Predict] Is next: {bool(torch.argmax(cls_logits[rand_idx]))}')
            print('---------------------------------------')

            loss.backward()
            self.optimizer.step()
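
            # Hypothetical epoch summary: epoch_loss and start_time are otherwise
            # unused, so per-epoch logging is assumed to have been the intent
            epoch_loss += loss.item()
            print(f'Epoch: {epoch + 1:02} | Loss: {epoch_loss:.3f} | '
                  f'Time: {time.time() - start_time:.1f}s')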

    def valid(self):
        self.model.eval()
        epoch_loss = 0
        with torch.no_grad():
            pass
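            # A possible validation step (a sketch, assuming valid_iter mirrors the
            # training batch layout):
            #
            #   input_ids, segment_ids, masked_tokens, masked_pos, num_masked, isNext = self.valid_iter
            #   lm_logits, cls_logits = self.model(input_ids, segment_ids, masked_pos)
            #   loss = self.lm_criterion(lm_logits.transpose(1, 2), masked_tokens) \
            #          + self.cls_criterion(cls_logits, isNext)
            #   epoch_loss += loss.item()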

    def test(self):
        self.model.eval()
        epoch_loss = 0
        with torch.no_grad():
            pass