Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
weak ranking system for seq2seq (#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexholdenmiller authored Aug 1, 2017
1 parent 143ee97 commit 9800a92
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 67 deletions.
44 changes: 26 additions & 18 deletions examples/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
'''Train a model.
"""Train a model.
After training, computes validation and test error.
Expand All @@ -21,51 +21,58 @@
TODO List:
- More logging (e.g. to files), make things prettier.
'''
"""

from parlai.core.agents import create_agent
from parlai.core.worlds import create_task
from parlai.core.params import ParlaiParser
from parlai.core.utils import Timer
import build_dict
import copy
import importlib
import math
import os

def run_eval(agent, opt, datatype, still_training=False, max_exs=-1):
''' Eval on validation/test data. '''
def run_eval(agent, opt, datatype, max_exs=-1, write_log=False, valid_world=None):
    """Evaluate on validation/test data and return (report, world).

    - agent is the agent to use for the evaluation (only consulted when a
      new world has to be created)
    - opt is the options dict that specifies the task, evaltask, etc
    - datatype is the datatype to use, such as "valid" or "test"
    - max_exs limits the (approximate) number of examples if max_exs > 0
    - write_log specifies to write metrics to file if the model_file is set
    - valid_world can be an existing world which will be reset instead of
      reinitialized; the world used is returned so callers can reuse it
    """
    print('[ running eval: ' + datatype + ' ]')
    opt['datatype'] = datatype
    if opt.get('evaltask'):
        opt['task'] = opt['evaltask']

    # Reuse the caller-provided world when possible; building a task world
    # from scratch is expensive, a reset is cheap.
    if valid_world is None:
        valid_world = create_task(opt, agent)
    else:
        valid_world.reset()
    cnt = 0
    for _ in valid_world:
        valid_world.parley()
        if cnt == 0 and opt['display_examples']:
            # Display one example exchange at the start of evaluation.
            print(valid_world.display() + '\n~~')
            print(valid_world.report())
        cnt += opt['batchsize']
        if valid_world.epoch_done() or (max_exs > 0 and cnt > max_exs):
            # note this max_exs is approximate--some batches won't always be
            # full depending on the structure of the data
            break
    valid_report = valid_world.report()

    metrics = datatype + ':' + str(valid_report)
    print(metrics)
    if write_log and opt['model_file']:
        # Append metrics next to the model file; the context manager
        # guarantees the handle is closed even if the write fails
        # (the original open/close pair leaked on error).
        with open(opt['model_file'] + '.' + datatype, 'a+') as f:
            f.write(metrics + '\n')

    return valid_report, valid_world


def main():
# Get command line arguments
parser = ParlaiParser(True, True)
Expand Down Expand Up @@ -115,6 +122,7 @@ def main():
best_accuracy = 0
impatience = 0
saved = False
valid_world = None
while True:
world.parley()
parleys += 1
Expand Down Expand Up @@ -149,7 +157,7 @@ def main():
# check if we should log amount of time remaining
time_left = None
if opt['num_epochs'] > 0:
exs_per_sec = train_time.time() / total_exs
exs_per_sec = train_time.time() / total_exs
time_left = (max_exs - total_exs) * exs_per_sec
if opt['max_train_time'] > 0:
other_time_left = opt['max_train_time'] - train_time.time()
Expand All @@ -168,11 +176,11 @@ def main():

if (opt['validation_every_n_secs'] > 0 and
validate_time.time() > opt['validation_every_n_secs']):
valid_report = run_eval(agent, opt, 'valid', True, opt['validation_max_exs'])
valid_report, valid_world = run_eval(agent, opt, 'valid', opt['validation_max_exs'], valid_world=valid_world)
if valid_report['accuracy'] > best_accuracy:
best_accuracy = valid_report['accuracy']
impatience = 0
print('[ new best accuracy: ' + str(best_accuracy) + ' ]')
print('[ new best accuracy: ' + str(best_accuracy) + ' ]')
world.save_agents()
saved = True
if best_accuracy == 1:
Expand All @@ -193,8 +201,8 @@ def main():
# reload best validation model
agent = create_agent(opt)

run_eval(agent, opt, 'valid')
run_eval(agent, opt, 'test')
run_eval(agent, opt, 'valid', write_log=True)
run_eval(agent, opt, 'test', write_log=True)


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit 9800a92

Please sign in to comment.