Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

train/dict updates and move seq2seq #153

Merged
merged 4 commits into from
Jun 21, 2017
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/build_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
import os

def build_dict(opt):
if 'dict_file' not in opt:
if not opt.get('dict_file'):
print('Tried to build dictionary but `--dict-file` is not set. Set ' +
'this param so the dictionary can be saved.')
return
print('[ setting up dictionary. ]')
if os.path.isfile(opt['dict_file']):
Expand Down
51 changes: 28 additions & 23 deletions examples/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
"""Train a model.
'''Train a model.

After training, computes validation and test error.

Run with, e.g.:

python examples/train_model.py -m ir_baseline -t dialog_babi:Task:1 -mf "/tmp/model"
python examples/train_model.py -m ir_baseline -t dialog_babi:Task:1 -mf '/tmp/model'

..or..

python examples/train_model.py -m rnn_baselines/seq2seq -t babi:Task10k:1 -mf "/tmp/model" -bs 32 -lr 0.5 -hs 128
python examples/train_model.py -m rnn_baselines/seq2seq -t babi:Task10k:1 -mf '/tmp/model' -bs 32 -lr 0.5 -hs 128

..or..

python examples/train_model.py -m drqa -t babi:Task10k:1 -mf "/tmp/model" -bs 10
python examples/train_model.py -m drqa -t babi:Task10k:1 -mf '/tmp/model' -bs 10

TODO List:
- More logging (e.g. to files), make things prettier.
"""
'''

from parlai.core.agents import create_agent
from parlai.core.worlds import create_task
Expand All @@ -36,32 +36,37 @@

def run_eval(agent, opt, datatype, still_training=False):
''' Eval on validation/test data. '''
print("[ running eval: " + datatype + " ]")
print('[ running eval: ' + datatype + ' ]')
opt['datatype'] = datatype
if opt.get('evaltask'):
opt['task'] = opt['evaltask']
valid_world = create_task(opt, agent)
for i in range(len(valid_world)):
valid_world.parley()
if i == 1 and opt['display_examples']:
print(valid_world.display() + "\n~~")
print(valid_world.display() + '\n~~')
print(valid_world.report())
if valid_world.epoch_done():
break
valid_world.shutdown()
valid_report = valid_world.report()
metrics = datatype + ":" + str(valid_report)
metrics = datatype + ':' + str(valid_report)
print(metrics)
if still_training:
return valid_report
else:
if opt['model_file']:
# Write out metrics
f = open(opt['model_file'] + '.' + datatype, "a+")
f = open(opt['model_file'] + '.' + datatype, 'a+')
f.write(metrics + '\n')
f.close()

def main():
# Get command line arguments
parser = ParlaiParser(True, True)
parser.add_argument('-et', '--evaltask',
help=('task to use for valid/test (defaults to the ' +
'one used for training if not set)'))
parser.add_argument('-d', '--display-examples',
type='bool', default=False)
parser.add_argument('-e', '--num-epochs', type=int, default=1)
Expand All @@ -70,18 +75,18 @@ def main():
parser.add_argument('-ltim', '--log-every-n-secs',
type=float, default=1)
parser.add_argument('-vtim', '--validation-every-n-secs',
type=float, default=False)
type=float, default=0)
parser.add_argument('-vp', '--validation-patience',
type=int, default=5,
help=('number of iterations of validation where result '
+ 'does not improve before we stop training'))
parser.add_argument('-dbf', '--dict_build_first',
parser.add_argument('-dbf', '--dict-build-first',
type='bool', default=True,
help='build dictionary first before training agent')
opt = parser.parse_args()
# Possibly build a dictionary (not all models do this).
if opt['dict_build_first']:
if 'dict_file' not in opt and 'model_file' in opt:
if opt['dict_build_first'] and 'dict_file' in opt:
if opt['dict_file'] is None and opt.get('model_file'):
opt['dict_file'] = opt['model_file'] + '.dict'
build_dict.build_dict(opt)
# Create model and assign it to the specified task
Expand All @@ -91,7 +96,7 @@ def main():
train_time = Timer()
validate_time = Timer()
log_time = Timer()
print("[ training... ]")
print('[ training... ]')
parleys = 0
num_parleys = opt['num_epochs'] * int(len(world) / opt['batchsize'])
best_accuracy = 0
Expand All @@ -101,17 +106,17 @@ def main():
world.parley()
parleys = parleys + 1
if train_time.time() > opt['max_train_time']:
print("[ max_train_time elapsed: " + str(train_time.time()) + " ]")
print('[ max_train_time elapsed: ' + str(train_time.time()) + ' ]')
break
if log_time.time() > opt['log_every_n_secs']:
if opt['display_examples']:
print(world.display() + "\n~~")
print(world.display() + '\n~~')
parleys_per_sec = train_time.time() / parleys
time_left = (num_parleys - parleys) * parleys_per_sec
log = ("[ time:" + str(math.floor(train_time.time()))
+ "s parleys:" + str(parleys)
+ " time_left:"
+ str(math.floor(time_left)) + "s ]")
log = ('[ time:' + str(math.floor(train_time.time()))
+ 's parleys:' + str(parleys)
+ ' time_left:'
+ str(math.floor(time_left)) + 's ]')
if hasattr(agent, 'report'):
log = log + str(agent.report())
else:
Expand All @@ -125,7 +130,7 @@ def main():
if valid_report['accuracy'] > best_accuracy:
best_accuracy = valid_report['accuracy']
impatience = 0
print("[ new best accuracy: " + str(best_accuracy) + " ]")
print('[ new best accuracy: ' + str(best_accuracy) + ' ]')
if opt['model_file']:
agent.save(opt['model_file'])
saved = True
Expand All @@ -134,8 +139,8 @@ def main():
break
else:
impatience += 1
print("[ did not beat best accuracy: " + str(best_accuracy) +
" impatience: " + str(impatience) + " ]")
print('[ did not beat best accuracy: ' + str(best_accuracy) +
' impatience: ' + str(impatience) + ' ]')
validate_time.reset()
if impatience >= opt['validation_patience']:
print('[ ran out of patience! stopping. ]')
Expand Down
2 changes: 1 addition & 1 deletion parlai/agents/drqa/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def add_cmdline_args(parser):

def set_defaults(opt):
# Embeddings options
if 'embedding_file' in opt:
if opt.get('embedding_file'):
if not os.path.isfile(opt['embedding_file']):
raise IOError('No such file: %s' % args.embedding_file)
with open(opt['embedding_file']) as f:
Expand Down
2 changes: 1 addition & 1 deletion parlai/agents/drqa/drqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(self, *args, **kwargs):
super(SimpleDictionaryAgent, self).__init__(*args, **kwargs)

# Index words in embedding file
if self.opt['pretrained_words'] and 'embedding_file' in self.opt:
if self.opt['pretrained_words'] and self.opt.get('embedding_file'):
print('[ Indexing words with embeddings... ]')
self.embedding_words = set()
with open(self.opt['embedding_file']) as f:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch.nn as nn
import torch
import copy
import os
import os
import random


Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(self, opt, shared=None):
}
if self.use_cuda:
self.cuda()
if 'model_file' in opt and os.path.isfile(opt['model_file']):
if opt.get('model_file') and os.path.isfile(opt['model_file']):
print('Loading existing model parameters from ' + opt['model_file'])
self.load(opt['model_file'])

Expand Down
11 changes: 6 additions & 5 deletions parlai/core/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,12 @@ def __init__(self, opt, shared=None):
self.tok2ind[self.unk_token] = index
self.ind2tok[index] = self.unk_token

if 'dict_file' in opt and os.path.isfile(opt['dict_file']):
self.load(opt['dict_file'])
elif 'dict_initpath' in opt:
# load pre-existing dictionary
if opt.get('dict_initpath'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is different behavior to what i wanted, let us discuss

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is, you should only load the init if you don't have the actual real dict (same thing with models usually)

# load seed dictionary
self.load(opt['dict_initpath'])
if opt.get('dict_file') and os.path.isfile(opt['dict_file']):
# load pre-existing dictionary
self.load(opt['dict_file'])

# initialize tokenizers
st_path = 'tokenizers/punkt/{0}.pickle'.format(opt['dict_language'])
Expand All @@ -157,7 +158,7 @@ def __init__(self, opt, shared=None):
# fix count for unknown token to one billion
self.freq[self.unk_token] = 1000000000

if 'dict_file' in opt:
if opt.get('dict_file'):
self.save_path = opt['dict_file']

def __contains__(self, key):
Expand Down
5 changes: 2 additions & 3 deletions parlai/core/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def add_model_args(self, args=None):
'-m', '--model', default='repeat_label',
help='the model class name, should match parlai/agents/<model>')
self.add_argument(
'-mf', '--model-file', default='',
'-mf', '--model-file', default=None,
help='model file name for loading and saving models')
# Find which model specified, and add its specific arguments.
if args is None:
Expand All @@ -133,8 +133,7 @@ def parse_args(self, args=None, namespace=None, print_args=True):
We specifically remove items with ``None`` as values in order to support
the style ``opt.get(key, default)``, which would otherwise return ``None``.
"""
self.args = super().parse_args(args=args)
self.opt = {k: v for k, v in vars(self.args).items() if v is not None}
self.opt = vars(super().parse_args(args=args))
self.opt['parlai_home'] = self.parlai_home
if 'download_path' in self.opt:
self.opt['download_path'] = self.opt['download_path']
Expand Down
1 change: 0 additions & 1 deletion parlai/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,3 @@ def time(self):
if self.running:
return self.total + time.time() - self.start
return self.total