diff --git a/examples/build_dict.py b/examples/build_dict.py index 2213ab07752..3e62d074b94 100644 --- a/examples/build_dict.py +++ b/examples/build_dict.py @@ -14,7 +14,9 @@ import os def build_dict(opt): - if 'dict_file' not in opt: + if not opt.get('dict_file'): + print('Tried to build dictionary but `--dict-file` is not set. Set ' + + 'this param so the dictionary can be saved.') return print('[ setting up dictionary. ]') if os.path.isfile(opt['dict_file']): diff --git a/examples/train_model.py b/examples/train_model.py index 315f08e12e3..08b66e37d61 100644 --- a/examples/train_model.py +++ b/examples/train_model.py @@ -3,25 +3,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. An additional grant # of patent rights can be found in the PATENTS file in the same directory. -"""Train a model. +'''Train a model. After training, computes validation and test error. Run with, e.g.: -python examples/train_model.py -m ir_baseline -t dialog_babi:Task:1 -mf "/tmp/model" +python examples/train_model.py -m ir_baseline -t dialog_babi:Task:1 -mf '/tmp/model' ..or.. -python examples/train_model.py -m rnn_baselines/seq2seq -t babi:Task10k:1 -mf "/tmp/model" -bs 32 -lr 0.5 -hs 128 +python examples/train_model.py -m rnn_baselines/seq2seq -t babi:Task10k:1 -mf '/tmp/model' -bs 32 -lr 0.5 -hs 128 ..or.. -python examples/train_model.py -m drqa -t babi:Task10k:1 -mf "/tmp/model" -bs 10 +python examples/train_model.py -m drqa -t babi:Task10k:1 -mf '/tmp/model' -bs 10 TODO List: - More logging (e.g. to files), make things prettier. -""" +''' from parlai.core.agents import create_agent from parlai.core.worlds import create_task @@ -36,32 +36,37 @@ def run_eval(agent, opt, datatype, still_training=False): ''' Eval on validation/test data. ''' - print("[ running eval: " + datatype + " ]") + print('[ running eval: ' + datatype + ' ]') opt['datatype'] = datatype + if opt.get('evaltask'): + opt['task'] = opt['evaltask'] valid_world = create_task(opt, agent) for i in range(len(valid_world)): valid_world.parley() if i == 1 and opt['display_examples']: - print(valid_world.display() + "\n~~") + print(valid_world.display() + '\n~~') print(valid_world.report()) if valid_world.epoch_done(): break valid_world.shutdown() valid_report = valid_world.report() - metrics = datatype + ":" + str(valid_report) + metrics = datatype + ':' + str(valid_report) print(metrics) if still_training: return valid_report else: if opt['model_file']: # Write out metrics - f = open(opt['model_file'] + '.' + datatype, "a+") + f = open(opt['model_file'] + '.' + datatype, 'a+') f.write(metrics + '\n') f.close() def main(): # Get command line arguments parser = ParlaiParser(True, True) + parser.add_argument('-et', '--evaltask', + help=('task to use for valid/test (defaults to the ' + + 'one used for training if not set)')) parser.add_argument('-d', '--display-examples', type='bool', default=False) parser.add_argument('-e', '--num-epochs', type=int, default=1) @@ -70,18 +75,18 @@ def main(): parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=1) parser.add_argument('-vtim', '--validation-every-n-secs', - type=float, default=False) + type=float, default=0) parser.add_argument('-vp', '--validation-patience', type=int, default=5, help=('number of iterations of validation where result ' + 'does not improve before we stop training')) - parser.add_argument('-dbf', '--dict_build_first', + parser.add_argument('-dbf', '--dict-build-first', type='bool', default=True, help='build dictionary first before training agent') opt = parser.parse_args() # Possibly build a dictionary (not all models do this). - if opt['dict_build_first']: - if 'dict_file' not in opt and 'model_file' in opt: + if opt['dict_build_first'] and 'dict_file' in opt: + if opt['dict_file'] is None and opt.get('model_file'): opt['dict_file'] = opt['model_file'] + '.dict' build_dict.build_dict(opt) # Create model and assign it to the specified task @@ -91,7 +96,7 @@ def main(): train_time = Timer() validate_time = Timer() log_time = Timer() - print("[ training... ]") + print('[ training... ]') parleys = 0 num_parleys = opt['num_epochs'] * int(len(world) / opt['batchsize']) best_accuracy = 0 @@ -101,17 +106,17 @@ def main(): world.parley() parleys = parleys + 1 if train_time.time() > opt['max_train_time']: - print("[ max_train_time elapsed: " + str(train_time.time()) + " ]") + print('[ max_train_time elapsed: ' + str(train_time.time()) + ' ]') break if log_time.time() > opt['log_every_n_secs']: if opt['display_examples']: - print(world.display() + "\n~~") + print(world.display() + '\n~~') parleys_per_sec = train_time.time() / parleys time_left = (num_parleys - parleys) * parleys_per_sec - log = ("[ time:" + str(math.floor(train_time.time())) - + "s parleys:" + str(parleys) - + " time_left:" - + str(math.floor(time_left)) + "s ]") + log = ('[ time:' + str(math.floor(train_time.time())) + + 's parleys:' + str(parleys) + + ' time_left:' + + str(math.floor(time_left)) + 's ]') if hasattr(agent, 'report'): log = log + str(agent.report()) else: @@ -125,7 +130,7 @@ def main(): if valid_report['accuracy'] > best_accuracy: best_accuracy = valid_report['accuracy'] impatience = 0 - print("[ new best accuracy: " + str(best_accuracy) + " ]") + print('[ new best accuracy: ' + str(best_accuracy) + ' ]') if opt['model_file']: agent.save(opt['model_file']) saved = True @@ -134,8 +139,8 @@ def main(): break else: impatience += 1 - print("[ did not beat best accuracy: " + str(best_accuracy) + - " impatience: " + str(impatience) + " ]") + print('[ did not beat best accuracy: ' + str(best_accuracy) + + ' impatience: ' + str(impatience) + ' ]') validate_time.reset() if impatience >= opt['validation_patience']: print('[ ran out of patience! stopping. ]') diff --git a/parlai/agents/drqa/config.py b/parlai/agents/drqa/config.py index fba471d6d13..cace87e2faf 100644 --- a/parlai/agents/drqa/config.py +++ b/parlai/agents/drqa/config.py @@ -82,7 +82,7 @@ def add_cmdline_args(parser): def set_defaults(opt): # Embeddings options - if 'embedding_file' in opt: + if opt.get('embedding_file'): if not os.path.isfile(opt['embedding_file']): raise IOError('No such file: %s' % args.embedding_file) with open(opt['embedding_file']) as f: diff --git a/parlai/agents/drqa/drqa.py b/parlai/agents/drqa/drqa.py index f24d8b64b5b..b078a32b8f5 100644 --- a/parlai/agents/drqa/drqa.py +++ b/parlai/agents/drqa/drqa.py @@ -59,7 +59,7 @@ def __init__(self, *args, **kwargs): super(SimpleDictionaryAgent, self).__init__(*args, **kwargs) # Index words in embedding file - if self.opt['pretrained_words'] and 'embedding_file' in self.opt: + if self.opt['pretrained_words'] and self.opt.get('embedding_file'): print('[ Indexing words with embeddings... ]') self.embedding_words = set() with open(self.opt['embedding_file']) as f: diff --git a/parlai/agents/rnn_baselines/__init__.py b/parlai/agents/seq2seq/__init__.py similarity index 100% rename from parlai/agents/rnn_baselines/__init__.py rename to parlai/agents/seq2seq/__init__.py diff --git a/parlai/agents/rnn_baselines/seq2seq.py b/parlai/agents/seq2seq/seq2seq.py similarity index 99% rename from parlai/agents/rnn_baselines/seq2seq.py rename to parlai/agents/seq2seq/seq2seq.py index ac60bd51cb9..e920b10be42 100644 --- a/parlai/agents/rnn_baselines/seq2seq.py +++ b/parlai/agents/seq2seq/seq2seq.py @@ -12,7 +12,7 @@ import torch.nn as nn import torch import copy -import os +import os import random @@ -74,7 +74,7 @@ def __init__(self, opt, shared=None): } if self.use_cuda: self.cuda() - if 'model_file' in opt and os.path.isfile(opt['model_file']): + if opt.get('model_file') and os.path.isfile(opt['model_file']): print('Loading existing model parameters from ' + opt['model_file']) self.load(opt['model_file']) diff --git a/parlai/core/dict.py b/parlai/core/dict.py index 0e99ca1f5d1..99a7cc60aef 100644 --- a/parlai/core/dict.py +++ b/parlai/core/dict.py @@ -126,12 +126,14 @@ def __init__(self, opt, shared=None): index = len(self.tok2ind) self.tok2ind[self.unk_token] = index self.ind2tok[index] = self.unk_token - - if 'dict_file' in opt and os.path.isfile(opt['dict_file']): - self.load(opt['dict_file']) - elif 'dict_initpath' in opt: + + if opt.get('dict_file') and os.path.isfile(opt['dict_file']): # load pre-existing dictionary + self.load(opt['dict_file']) + elif opt.get('dict_initpath'): + # load seed dictionary self.load(opt['dict_initpath']) + # initialize tokenizers st_path = 'tokenizers/punkt/{0}.pickle'.format(opt['dict_language']) @@ -157,7 +159,7 @@ def __init__(self, opt, shared=None): # fix count for unknown token to one billion self.freq[self.unk_token] = 1000000000 - if 'dict_file' in opt: + if opt.get('dict_file'): self.save_path = opt['dict_file'] def __contains__(self, key): diff --git a/parlai/core/params.py b/parlai/core/params.py index dfabd63fb27..04cbd5e0714 100644 --- a/parlai/core/params.py +++ b/parlai/core/params.py @@ -110,7 +110,7 @@ def add_model_args(self, args=None): '-m', '--model', default='repeat_label', help='the model class name, should match parlai/agents/') self.add_argument( - '-mf', '--model-file', default='', + '-mf', '--model-file', default=None, help='model file name for loading and saving models') # Find which model specified, and add its specific arguments. if args is None: @@ -133,8 +133,7 @@ def parse_args(self, args=None, namespace=None, print_args=True): We specifically remove items with ``None`` as values in order to support the style ``opt.get(key, default)``, which would otherwise return ``None``. """ - self.args = super().parse_args(args=args) - self.opt = {k: v for k, v in vars(self.args).items() if v is not None} + self.opt = vars(super().parse_args(args=args)) self.opt['parlai_home'] = self.parlai_home if 'download_path' in self.opt: self.opt['download_path'] = self.opt['download_path']