main.py
# -*- coding: utf-8 -*-
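# Entry point: loads the CrisisLex tweet dataset, runs the preprocessing
# pipeline, builds the word-embedding matrix, and trains/evaluates the model
# once per configured random seed.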
import tensorflow as tf
import numpy as np
np.random.seed(54322)
import random
import os
from tfdeterminism import patch
patch()
from configs.config import CFG
from model.model_ import Model
from loaders import dataloader
from utils.config import Config
from preprocessing import prep
from evaluation.evaluator import Eval
from utils import helpers
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


def set_seeds(seed):
    """Seed all RNGs so that runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    tf.set_random_seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)


def op(config, embedding_matrix, evaluator, seeds, experiment_name,
       xtrain, ytrain, xtest, ytest, xval, yval):
    """Run one training/evaluation experiment per seed and collect the results."""
    for i, seed in enumerate(seeds):
        with tf.Session() as sess:
            model_cls = Model(config=config, emb_mat=embedding_matrix)
            print("experiment " + str(i + 1) + " with seed " + str(seed))
            set_seeds(seed)
            # training and evaluation
            model_cls.build_model()
            model_cls.vis_model()
            model_cls.compile_model()
            history = model_cls.training(xtrain, ytrain, xval, yval)
            evaluator.evaluation(xtest, ytest, history.history, seed=seed, model=model_cls.model)
            helpers.graph(history.history, log_dir=experiment_name, fig_name="/plots/losses" + str(seed) + ".png")
        tf.reset_default_graph()
    if evaluator.exp_repeats != 1:
        evaluator.save_results()


def main_pipeline():
    """Builds model, loads data, trains and evaluates."""
    # model configuration
    config = Config.from_json(CFG)
    # auto configuration
    config.create_folders()
    # n_classes = config.get_output_size()
    # exp_repeats = config.get_number_of_experiments()

    # loading dataset
    d_l = dataloader.DataLoader(config)
    data = d_l.load_data()

    # preprocessing pipeline
    preprocessor = prep.Preprocessor(config=config,
                                     data=data)
    tweets, labels = preprocessor.transform_crisis_lex()
    tweets_preped = preprocessor.text_preprocessing(tweets)
    xtrain, ytrain, xtest, ytest, xval, yval = preprocessor.splitting(tweets_preped, labels)
    if config.data["balancing"] != False:
        xtrain, ytrain = preprocessor.balancing(xtrain, ytrain)
    elif config.data["augmentation"]:
        # f_name = "multiclass" if config.data["setting"] == "info_type" else "binary"
        xtrain, ytrain = d_l.augment(xtrain, ytrain, encoder=preprocessor.oh_enc)
    tokenizer, xtrain, xtest, xval = preprocessor.tokens(train_data=xtrain, test_data=xtest, val_data=xval)

    # update config
    config.set_sequence_len(xtrain=xtrain)
    config.set_vocabulary_size(size=len(tokenizer.word_index))

    # loading embedding vectors and building the embedding matrix
    w2v_model = d_l.load_embeddings()
    embedding_matrix = d_l.build_embedding_matrix(w2v_model=w2v_model,
                                                  tokenizer=tokenizer)

    # training
    evaluator = Eval(exp_repeats=config.get_number_of_experiments(),
                     n_class=config.get_output_size(),
                     log_dir=config.data['experiment_name'])
    seeds = config.train['seeds']
    op(config, embedding_matrix,
       evaluator, seeds,
       config.data["experiment_name"],
       xtrain, ytrain,
       xtest, ytest,
       xval, yval)


if __name__ == '__main__':
    main_pipeline()