-
Notifications
You must be signed in to change notification settings - Fork 4
/
config_data.py
60 lines (56 loc) · 1.47 KB
/
config_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import tensorflow as tf
flags = tf.flags
flags.DEFINE_string("dataset", "nba", "The data config.")

# Pick the on-disk data directory from the --dataset flag defined above.
# The flag value must be read through flags.FLAGS; the tf.flags module itself
# has no attribute named after the flag.
# NOTE(review): under TF 1.x, reading tf.flags.FLAGS before flags are parsed
# triggers parsing of sys.argv -- confirm for the TF version in use.
_dataset = flags.FLAGS.dataset
if 'nba' in _dataset:
    dataset_dir = 'nba_data'
elif 'e2e' in _dataset:
    dataset_dir = 'e2e_data'
else:
    # Fail here with a clear message; otherwise `dataset_dir` stays unbound
    # and the module dies later with an opaque NameError while building paths.
    raise ValueError(
        "Unknown dataset {!r}: expected a name containing 'nba' or 'e2e'."
        .format(_dataset))
# Data splits for which a data config is built.
modes = ['train', 'val', 'test']

# File-name tag used on disk for each split (validation files say "valid").
mode_to_filemode = dict(
    train='train',
    val='valid',
    test='test',
)

# Every data field read per split, mapped to the stem of its vocabulary file;
# the "ref" fields and y_aux reuse the vocabularies of their base fields.
field_to_vocabname = dict(
    x_value='x_value',
    x_type='x_type',
    x_associated='x_associated',
    y_aux='y',
    x_ref_value='x_value',
    x_ref_type='x_type',
    x_ref_associated='x_associated',
    y_ref='y',
)

# Field names, in the order declared above.
fields = list(field_to_vocabname)
# Batch size for the training split and for both evaluation splits.
train_batch_size = 32
eval_batch_size = 32

# Batch size looked up per split.
batch_sizes = dict(
    train=train_batch_size,
    val=eval_batch_size,
    test=eval_batch_size,
)
# Data hyperparameters for each split. Each split gets:
# - a single epoch;
# - shuffling only for 'train';
# - a smaller final batch allowed only outside 'train';
# - one dataset entry per field, pairing the field's text file for that split
#   with the vocabulary file named by field_to_vocabname.
datas = dict(
    (
        mode,
        {
            'num_epochs': 1,
            'shuffle': mode == 'train',
            'batch_size': batch_sizes[mode],
            'allow_smaller_final_batch': mode != 'train',
            'datasets': list(
                {
                    'files': [
                        os.path.join(
                            dataset_dir,
                            mode,
                            '%s.%s.txt' % (field, mode_to_filemode[mode]),
                        ),
                    ],
                    'vocab_file': os.path.join(
                        dataset_dir,
                        field_to_vocabname[field] + '.vocab.txt',
                    ),
                    'data_name': field,
                }
                for field in fields
            ),
        },
    )
    for mode in modes
)