-
Notifications
You must be signed in to change notification settings - Fork 59
/
train.py
93 lines (80 loc) · 3.72 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# -*- coding: utf-8 -*-
import os
import pickle
import random
import argparse
import torch as t
import numpy as np
from tqdm import tqdm
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from model import Word2Vec, SGNS
def parse_args():
    """Build the CLI parser for SGNS training and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()
    # Valued options as (flag, type, default, help) rows — one table,
    # one registration loop, instead of repeated add_argument calls.
    valued = [
        ('--name', str, 'sgns', "model name"),
        ('--data_dir', str, './data/', "data directory path"),
        ('--save_dir', str, './pts/', "model directory path"),
        ('--e_dim', int, 300, "embedding dimension"),
        ('--n_negs', int, 20, "number of negative samples"),
        ('--epoch', int, 100, "number of epochs"),
        ('--mb', int, 4096, "mini-batch size"),
        ('--ss_t', float, 1e-5, "subsample threshold"),
    ]
    for flag, flag_type, default, doc in valued:
        parser.add_argument(flag, type=flag_type, default=default, help=doc)
    # Boolean switches: absent -> False, present -> True.
    switches = [
        ('--conti', "continue learning"),
        ('--weights', "use weights for negative sampling"),
        ('--cuda', "use CUDA"),
    ]
    for flag, doc in switches:
        parser.add_argument(flag, action='store_true', help=doc)
    return parser.parse_args()
class PermutedSubsampledCorpus(Dataset):
    """Dataset of ``(iword, owords)`` training pairs with optional subsampling.

    The pickled file at ``datapath`` holds a list of tuples
    ``(iword, owords)``: a center-word index and its context-word indices.
    When ``ws`` is given, each pair is kept with probability
    ``1 - ws[iword]``, so words with a high discard probability are
    dropped more often (frequent-word subsampling).
    """

    def __init__(self, datapath, ws=None):
        # Close the file deterministically; the original passed an open()
        # straight into pickle.load and leaked the handle to the GC.
        with open(datapath, 'rb') as f:
            data = pickle.load(f)
        if ws is not None:
            # Keep a pair only when a uniform draw beats the center
            # word's discard probability.
            self.data = [(iword, owords) for iword, owords in data
                         if random.random() > ws[iword]]
        else:
            self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        iword, owords = self.data[idx]
        # Contexts are returned as an ndarray so DataLoader can collate them.
        return iword, np.array(owords)
def train(args):
    """Train an SGNS word2vec model and persist embeddings and checkpoints.

    Reads ``idx2word.dat``, ``wc.dat`` and ``train.dat`` from
    ``args.data_dir``; writes ``idx2vec.dat`` back to ``args.data_dir``
    and the model/optimizer state dicts to ``args.save_dir``.
    Supports resuming from saved state when ``args.conti`` is set.
    """
    # Vocabulary and word counts produced by the preprocessing step.
    with open(os.path.join(args.data_dir, 'idx2word.dat'), 'rb') as f:
        idx2word = pickle.load(f)
    with open(os.path.join(args.data_dir, 'wc.dat'), 'rb') as f:
        wc = pickle.load(f)
    # Relative word frequencies, in idx order.
    wf = np.array([wc[word] for word in idx2word])
    wf = wf / wf.sum()
    # Per-word discard probability (Mikolov et al. subsampling):
    # 1 - sqrt(t / f(w)), clipped to [0, 1].
    ws = 1 - np.sqrt(args.ss_t / wf)
    ws = np.clip(ws, 0, 1)
    vocab_size = len(idx2word)
    weights = wf if args.weights else None
    # makedirs with exist_ok also creates missing parents; the original
    # os.mkdir raised FileNotFoundError when the parent did not exist.
    os.makedirs(args.save_dir, exist_ok=True)
    model = Word2Vec(vocab_size=vocab_size, embedding_size=args.e_dim)
    modelpath = os.path.join(args.save_dir, '{}.pt'.format(args.name))
    sgns = SGNS(embedding=model, vocab_size=vocab_size, n_negs=args.n_negs, weights=weights)
    if os.path.isfile(modelpath) and args.conti:
        sgns.load_state_dict(t.load(modelpath))
    if args.cuda:
        sgns = sgns.cuda()
    optim = Adam(sgns.parameters())
    optimpath = os.path.join(args.save_dir, '{}.optim.pt'.format(args.name))
    if os.path.isfile(optimpath) and args.conti:
        optim.load_state_dict(t.load(optimpath))
    for epoch in range(1, args.epoch + 1):
        # FIX: pass ws so frequent-word subsampling is actually applied —
        # the original computed ws above but never used it. Rebuilding the
        # dataset each epoch draws a fresh subsampled corpus per epoch.
        dataset = PermutedSubsampledCorpus(
            os.path.join(args.data_dir, 'train.dat'), ws=ws)
        dataloader = DataLoader(dataset, batch_size=args.mb, shuffle=True)
        pbar = tqdm(dataloader)
        pbar.set_description("[Epoch {}]".format(epoch))
        for iword, owords in pbar:
            loss = sgns(iword, owords)
            optim.zero_grad()
            loss.backward()
            optim.step()
            pbar.set_postfix(loss=loss.item())
    # Persist the input-embedding matrix plus model/optimizer state.
    idx2vec = model.ivectors.weight.data.cpu().numpy()
    with open(os.path.join(args.data_dir, 'idx2vec.dat'), 'wb') as f:
        pickle.dump(idx2vec, f)
    t.save(sgns.state_dict(), os.path.join(args.save_dir, '{}.pt'.format(args.name)))
    t.save(optim.state_dict(), os.path.join(args.save_dir, '{}.optim.pt'.format(args.name)))
if __name__ == '__main__':
    # Script entry point: parse CLI flags, then run training.
    cli_args = parse_args()
    train(cli_args)