Commit
train/dict updates and move seq2seq (#153)
alexholdenmiller authored Jun 21, 2017
1 parent 098db44 commit 7d1a2ea
Showing 8 changed files with 44 additions and 36 deletions.
4 changes: 3 additions & 1 deletion examples/build_dict.py
@@ -14,7 +14,9 @@
 import os

 def build_dict(opt):
-    if 'dict_file' not in opt:
+    if not opt.get('dict_file'):
+        print('Tried to build dictionary but `--dict-file` is not set. Set ' +
+              'this param so the dictionary can be saved.')
         return
     print('[ setting up dictionary. ]')
     if os.path.isfile(opt['dict_file']):
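Note: the new `opt.get('dict_file')` guard pairs with the parlai/core/params.py change further down, which stops stripping None-valued options from opt. A minimal standalone sketch (hypothetical opt values, not part of the commit) of why a membership test is no longer enough:

# With None-valued keys retained in opt, `in` no longer implies a usable value.
opt = {'dict_file': None}

print('dict_file' in opt)          # True  -> the old check would fall through and crash later
print(bool(opt.get('dict_file')))  # False -> the new check returns early with a message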
51 changes: 28 additions & 23 deletions examples/train_model.py
@@ -3,25 +3,25 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree. An additional grant
 # of patent rights can be found in the PATENTS file in the same directory.
-"""Train a model.
+'''Train a model.
 After training, computes validation and test error.
 Run with, e.g.:
-python examples/train_model.py -m ir_baseline -t dialog_babi:Task:1 -mf "/tmp/model"
+python examples/train_model.py -m ir_baseline -t dialog_babi:Task:1 -mf '/tmp/model'
 ..or..
-python examples/train_model.py -m rnn_baselines/seq2seq -t babi:Task10k:1 -mf "/tmp/model" -bs 32 -lr 0.5 -hs 128
+python examples/train_model.py -m rnn_baselines/seq2seq -t babi:Task10k:1 -mf '/tmp/model' -bs 32 -lr 0.5 -hs 128
 ..or..
-python examples/train_model.py -m drqa -t babi:Task10k:1 -mf "/tmp/model" -bs 10
+python examples/train_model.py -m drqa -t babi:Task10k:1 -mf '/tmp/model' -bs 10
 TODO List:
 - More logging (e.g. to files), make things prettier.
-"""
+'''

 from parlai.core.agents import create_agent
 from parlai.core.worlds import create_task
@@ -36,32 +36,37 @@

 def run_eval(agent, opt, datatype, still_training=False):
     ''' Eval on validation/test data. '''
-    print("[ running eval: " + datatype + " ]")
+    print('[ running eval: ' + datatype + ' ]')
     opt['datatype'] = datatype
+    if opt.get('evaltask'):
+        opt['task'] = opt['evaltask']
     valid_world = create_task(opt, agent)
     for i in range(len(valid_world)):
         valid_world.parley()
         if i == 1 and opt['display_examples']:
-            print(valid_world.display() + "\n~~")
+            print(valid_world.display() + '\n~~')
             print(valid_world.report())
         if valid_world.epoch_done():
             break
     valid_world.shutdown()
     valid_report = valid_world.report()
-    metrics = datatype + ":" + str(valid_report)
+    metrics = datatype + ':' + str(valid_report)
     print(metrics)
     if still_training:
         return valid_report
     else:
         if opt['model_file']:
             # Write out metrics
-            f = open(opt['model_file'] + '.' + datatype, "a+")
+            f = open(opt['model_file'] + '.' + datatype, 'a+')
             f.write(metrics + '\n')
             f.close()

 def main():
     # Get command line arguments
     parser = ParlaiParser(True, True)
+    parser.add_argument('-et', '--evaltask',
+                        help=('task to use for valid/test (defaults to the ' +
+                              'one used for training if not set)'))
     parser.add_argument('-d', '--display-examples',
                         type='bool', default=False)
     parser.add_argument('-e', '--num-epochs', type=int, default=1)
@@ -70,18 +75,18 @@ def main():
     parser.add_argument('-ltim', '--log-every-n-secs',
                         type=float, default=1)
     parser.add_argument('-vtim', '--validation-every-n-secs',
-                        type=float, default=False)
+                        type=float, default=0)
     parser.add_argument('-vp', '--validation-patience',
                         type=int, default=5,
                         help=('number of iterations of validation where result '
                               + 'does not improve before we stop training'))
-    parser.add_argument('-dbf', '--dict_build_first',
+    parser.add_argument('-dbf', '--dict-build-first',
                         type='bool', default=True,
                         help='build dictionary first before training agent')
     opt = parser.parse_args()
     # Possibly build a dictionary (not all models do this).
-    if opt['dict_build_first']:
-        if 'dict_file' not in opt and 'model_file' in opt:
+    if opt['dict_build_first'] and 'dict_file' in opt:
+        if opt['dict_file'] is None and opt.get('model_file'):
             opt['dict_file'] = opt['model_file'] + '.dict'
         build_dict.build_dict(opt)
     # Create model and assign it to the specified task
@@ -91,7 +96,7 @@
     train_time = Timer()
     validate_time = Timer()
     log_time = Timer()
-    print("[ training... ]")
+    print('[ training... ]')
     parleys = 0
     num_parleys = opt['num_epochs'] * int(len(world) / opt['batchsize'])
     best_accuracy = 0
@@ -101,17 +106,17 @@
         world.parley()
         parleys = parleys + 1
         if train_time.time() > opt['max_train_time']:
-            print("[ max_train_time elapsed: " + str(train_time.time()) + " ]")
+            print('[ max_train_time elapsed: ' + str(train_time.time()) + ' ]')
             break
         if log_time.time() > opt['log_every_n_secs']:
             if opt['display_examples']:
-                print(world.display() + "\n~~")
+                print(world.display() + '\n~~')
             parleys_per_sec = train_time.time() / parleys
             time_left = (num_parleys - parleys) * parleys_per_sec
-            log = ("[ time:" + str(math.floor(train_time.time()))
-                   + "s parleys:" + str(parleys)
-                   + " time_left:"
-                   + str(math.floor(time_left)) + "s ]")
+            log = ('[ time:' + str(math.floor(train_time.time()))
+                   + 's parleys:' + str(parleys)
+                   + ' time_left:'
+                   + str(math.floor(time_left)) + 's ]')
             if hasattr(agent, 'report'):
                 log = log + str(agent.report())
             else:
@@ -125,7 +130,7 @@ def main():
             if valid_report['accuracy'] > best_accuracy:
                 best_accuracy = valid_report['accuracy']
                 impatience = 0
-                print("[ new best accuracy: " + str(best_accuracy) + " ]")
+                print('[ new best accuracy: ' + str(best_accuracy) + ' ]')
                 if opt['model_file']:
                     agent.save(opt['model_file'])
                     saved = True
@@ -134,8 +139,8 @@ def main():
                     break
             else:
                 impatience += 1
-                print("[ did not beat best accuracy: " + str(best_accuracy) +
-                      " impatience: " + str(impatience) + " ]")
+                print('[ did not beat best accuracy: ' + str(best_accuracy) +
+                      ' impatience: ' + str(impatience) + ' ]')
             validate_time.reset()
             if impatience >= opt['validation_patience']:
                 print('[ ran out of patience! stopping. ]')
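Note: the training loop above stops early once validation accuracy fails to improve `--validation-patience` times in a row. A standalone sketch of that patience rule (not ParlAI code; the function and values are illustrative):

# Stop once accuracy has not improved for `patience` consecutive validations.
def should_stop(accuracies, patience=5):
    best, impatience = 0.0, 0
    for accuracy in accuracies:
        if accuracy > best:
            best, impatience = accuracy, 0
        else:
            impatience += 1
        if impatience >= patience:
            return True
    return False

print(should_stop([0.4, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], patience=5))  # True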
2 changes: 1 addition & 1 deletion parlai/agents/drqa/config.py
@@ -82,7 +82,7 @@ def add_cmdline_args(parser):

 def set_defaults(opt):
     # Embeddings options
-    if 'embedding_file' in opt:
+    if opt.get('embedding_file'):
         if not os.path.isfile(opt['embedding_file']):
             raise IOError('No such file: %s' % args.embedding_file)
         with open(opt['embedding_file']) as f:
2 changes: 1 addition & 1 deletion parlai/agents/drqa/drqa.py
@@ -59,7 +59,7 @@ def __init__(self, *args, **kwargs):
         super(SimpleDictionaryAgent, self).__init__(*args, **kwargs)

         # Index words in embedding file
-        if self.opt['pretrained_words'] and 'embedding_file' in self.opt:
+        if self.opt['pretrained_words'] and self.opt.get('embedding_file'):
             print('[ Indexing words with embeddings... ]')
             self.embedding_words = set()
             with open(self.opt['embedding_file']) as f:
File renamed without changes.
@@ -12,7 +12,7 @@
 import torch.nn as nn
 import torch
 import copy
-import os
+import os
 import random


@@ -74,7 +74,7 @@ def __init__(self, opt, shared=None):
         }
         if self.use_cuda:
             self.cuda()
-        if 'model_file' in opt and os.path.isfile(opt['model_file']):
+        if opt.get('model_file') and os.path.isfile(opt['model_file']):
             print('Loading existing model parameters from ' + opt['model_file'])
             self.load(opt['model_file'])
12 changes: 7 additions & 5 deletions parlai/core/dict.py
@@ -126,12 +126,14 @@ def __init__(self, opt, shared=None):
             index = len(self.tok2ind)
             self.tok2ind[self.unk_token] = index
             self.ind2tok[index] = self.unk_token
-
-        if 'dict_file' in opt and os.path.isfile(opt['dict_file']):
-            self.load(opt['dict_file'])
-        elif 'dict_initpath' in opt:
+
+        if opt.get('dict_file') and os.path.isfile(opt['dict_file']):
+            # load pre-existing dictionary
+            self.load(opt['dict_file'])
+        elif opt.get('dict_initpath'):
+            # load seed dictionary
             self.load(opt['dict_initpath'])


         # initialize tokenizers
         st_path = 'tokenizers/punkt/{0}.pickle'.format(opt['dict_language'])
@@ -157,7 +159,7 @@ def __init__(self, opt, shared=None):
         # fix count for unknown token to one billion
         self.freq[self.unk_token] = 1000000000

-        if 'dict_file' in opt:
+        if opt.get('dict_file'):
             self.save_path = opt['dict_file']

     def __contains__(self, key):
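Note: after this change the DictionaryAgent constructor prefers an existing --dict-file and only falls back to --dict-initpath as a seed. A minimal sketch of that load order (hypothetical helper and paths, not part of the commit):

import os

def pick_dict_path(opt):
    if opt.get('dict_file') and os.path.isfile(opt['dict_file']):
        return opt['dict_file']      # pre-existing dictionary wins
    if opt.get('dict_initpath'):
        return opt['dict_initpath']  # otherwise load the seed dictionary
    return None                      # otherwise start from an empty dictionary

print(pick_dict_path({'dict_file': None, 'dict_initpath': '/tmp/seed.dict'}))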
5 changes: 2 additions & 3 deletions parlai/core/params.py
@@ -110,7 +110,7 @@ def add_model_args(self, args=None):
             '-m', '--model', default='repeat_label',
             help='the model class name, should match parlai/agents/<model>')
         self.add_argument(
-            '-mf', '--model-file', default='',
+            '-mf', '--model-file', default=None,
             help='model file name for loading and saving models')
         # Find which model specified, and add its specific arguments.
         if args is None:
@@ -133,8 +133,7 @@ def parse_args(self, args=None, namespace=None, print_args=True):
         We specifically remove items with ``None`` as values in order to support
         the style ``opt.get(key, default)``, which would otherwise return ``None``.
         """
-        self.args = super().parse_args(args=args)
-        self.opt = {k: v for k, v in vars(self.args).items() if v is not None}
+        self.opt = vars(super().parse_args(args=args))
         self.opt['parlai_home'] = self.parlai_home
         if 'download_path' in self.opt:
             self.opt['download_path'] = self.opt['download_path']
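Note: with parse_args now returning vars(...) directly, None-valued options stay in opt, so `'key' in opt` checks always succeed; the `opt.get(key)` checks introduced throughout this commit test the value instead. A standalone argparse sketch of the difference (not ParlAI code):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-mf', '--model-file', default=None)
opt = vars(parser.parse_args([]))   # None values are no longer filtered out

print('model_file' in opt)    # True even though no model file was supplied
print(opt.get('model_file'))  # None -> falsy, so value-based guards still work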
