import glob
import os

import six
import torch
import yaml
from ignite.engine import Engine


def load_yaml(config_path):
    """Load a YAML configuration file into a dict."""
    if not isinstance(config_path, six.string_types):
        raise ValueError(
            "Got {}, expected string".format(type(config_path)))
    with open(config_path, "r") as yaml_file:
        config = yaml.safe_load(yaml_file)
    return config
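
# Usage sketch (the path and the config key below are illustrative, not files
# or fields shipped with this module):
#
#     config = load_yaml("config/train.yaml")
#     batch_size = config["batch_size"]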


def create_supervised_evaluator(model, inference_fn, metrics=None, cuda=False):
    """Factory function for creating an evaluator for supervised models.

    Extended version of ignite's ``create_supervised_evaluator``.

    Args:
        model (torch.nn.Module): the model to evaluate
        inference_fn (callable): inference function run on each batch
        metrics (dict of str: Metric): a map of metric names to Metrics
        cuda (bool, optional): whether or not to transfer the batch to GPU
            (default: False)

    Returns:
        Engine: an evaluator engine with supervised inference function
    """
    metrics = metrics or {}
    engine = Engine(inference_fn)
    for name, metric in metrics.items():
        metric.attach(engine, name)
    return engine


def create_training_function(tagger, opt):
    """Create an ignite process function that runs one training step."""

    def training_function(engine, batch):
        tagger.train()
        opt.zero_grad()
        # Unpack the batch: padded word ids, sentence lengths, and the
        # character-level representation of each sentence.
        sentence = batch.sentence[0]
        sent_len = batch.sentence[1].numpy()
        char_rep = batch.char_sentence[0]
        tags = batch.tags
        # Forward pass returns the loss; backpropagate and update weights.
        result = tagger(char_rep, sentence, sent_len, tags)
        result.backward()
        opt.step()
        return result.detach()

    return training_function
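
# Usage sketch, assuming a `tagger` model, an optimizer `opt`, and a training
# iterator `train_iter` (none of which are defined in this module):
#
#     trainer = Engine(create_training_function(tagger, opt))
#     trainer.run(train_iter, max_epochs=10)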


def create_evaluation_function(tagger):
    """Create an ignite process function that decodes one evaluation batch."""

    def evaluation_function(engine, batch):
        tagger.eval()
        sentence = batch.sentence[0]
        sent_len = batch.sentence[1].numpy()
        char_rep = batch.char_sentence[0]
        tags = batch.tags
        # Decode the predicted tag sequences and transpose them so the
        # predictions line up with the layout of the gold tags.
        result = torch.tensor(
            tagger.decode(char_rep, sentence, sent_len), dtype=torch.int32)
        result = result.transpose(1, 0)
        return result, tags.detach()

    return evaluation_function
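
# Usage sketch: pair the evaluation function with the evaluator factory above.
# `tagger`, `val_iter`, and `tag_accuracy` are assumptions; any ignite Metric
# that consumes an ``(y_pred, y)`` output pair could be attached here.
#
#     eval_fn = create_evaluation_function(tagger)
#     evaluator = create_supervised_evaluator(
#         tagger, eval_fn, metrics={"accuracy": tag_accuracy})
#     evaluator.run(val_iter)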
def restore_model(path, restore="latest"):
"""Restore saved model
Args:
path (str): The path where the model is saved
restore (int or str): Amongst the saved model, which last
saved model would like to be restored. 1, 2, 3, ... or latest
Returns:
model: nn.Module
"""
models = glob.glob(path)
if len(models) == 0:
print("No models are found, it's either you put the wrong"
" path or the model is not even existed yet!")
return None
# Sort the models based on the time date
models.sort(key=os.path.getmtime)
if restore == "latest":
restored_model = models[-1]
else:
if isinstance(restore, int) and \
restore < len(models):
restored_model = models[restore]
else:
raise ValueError("Value of restore must be either latest"
" or should be an integer with value less than %d"
% len(models))
try:
model = torch.load(restored_model)
print("Successfully restored model!")
return model
except Exception as e:
print("Something wrong while restoring the model: %s" % str(e))