diff --git a/examples/train_model.py b/examples/train_model.py
index 90f5c297229..8223113d012 100644
--- a/examples/train_model.py
+++ b/examples/train_model.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree. An additional grant
 # of patent rights can be found in the PATENTS file in the same directory.
-'''Train a model.
+"""Train a model.
 
 After training, computes validation and test error.
 
@@ -21,31 +21,38 @@
 
 TODO List:
 - More logging (e.g. to files), make things prettier.
-'''
+"""
 
 from parlai.core.agents import create_agent
 from parlai.core.worlds import create_task
 from parlai.core.params import ParlaiParser
 from parlai.core.utils import Timer
 import build_dict
-import copy
-import importlib
 import math
-import os
 
 
-def run_eval(agent, opt, datatype, still_training=False, max_exs=-1):
-    ''' Eval on validation/test data. '''
+def run_eval(agent, opt, datatype, max_exs=-1, write_log=False, valid_world=None):
+    """Eval on validation/test data.
+    - agent is the agent to use for the evaluation.
+    - opt is the options that specify the task, eval_task, etc.
+    - datatype is the datatype to use, such as "valid" or "test".
+    - max_exs limits the number of examples if max_exs > 0.
+    - write_log specifies to write metrics to file if the model_file is set.
+    - valid_world can be an existing world which will be reset instead of reinitialized.
+    """
     print('[ running eval: ' + datatype + ' ]')
     opt['datatype'] = datatype
     if opt.get('evaltask'):
+        opt['task'] = opt['evaltask']
-    valid_world = create_task(opt, agent)
+    if valid_world is None:
+        valid_world = create_task(opt, agent)
+    else:
+        valid_world.reset()
     cnt = 0
     for _ in valid_world:
         valid_world.parley()
         if cnt == 0 and opt['display_examples']:
-            first_run = False
             print(valid_world.display() + '\n~~')
             print(valid_world.report())
         cnt += opt['batchsize']
@@ -53,19 +60,19 @@ def run_eval(agent, opt, datatype, still_training=False, max_exs=-1):
             # note this max_exs is approximate--some batches won't always be
             # full depending on the structure of the data
             break
-    valid_world.shutdown()
     valid_report = valid_world.report()
     metrics = datatype + ':' + str(valid_report)
     print(metrics)
-    if still_training:
-        return valid_report
-    elif opt['model_file']:
+    if write_log and opt['model_file']:
         # Write out metrics
         f = open(opt['model_file'] + '.'
                  + datatype, 'a+')
         f.write(metrics + '\n')
         f.close()
+    return valid_report, valid_world
+
+
 def main():
     # Get command line arguments
     parser = ParlaiParser(True, True)
@@ -115,6 +122,7 @@ def main():
     best_accuracy = 0
     impatience = 0
     saved = False
+    valid_world = None
     while True:
         world.parley()
         parleys += 1
@@ -149,7 +157,7 @@ def main():
         # check if we should log amount of time remaining
         time_left = None
         if opt['num_epochs'] > 0:
-            exs_per_sec = train_time.time() / total_exs
+            exs_per_sec = train_time.time() / total_exs
             time_left = (max_exs - total_exs) * exs_per_sec
         if opt['max_train_time'] > 0:
             other_time_left = opt['max_train_time'] - train_time.time()
@@ -168,11 +176,11 @@ def main():
 
         if (opt['validation_every_n_secs'] > 0 and
                 validate_time.time() > opt['validation_every_n_secs']):
-            valid_report = run_eval(agent, opt, 'valid', True, opt['validation_max_exs'])
+            valid_report, valid_world = run_eval(agent, opt, 'valid', opt['validation_max_exs'], valid_world=valid_world)
             if valid_report['accuracy'] > best_accuracy:
                 best_accuracy = valid_report['accuracy']
                 impatience = 0
-                print('[ new best accuracy: ' + str(best_accuracy) + ' ]')
+                print('[ new best accuracy: ' + str(best_accuracy) + ' ]')
                 world.save_agents()
                 saved = True
             if best_accuracy == 1:
@@ -193,8 +201,8 @@ def main():
 
     # reload best validation model
     agent = create_agent(opt)
-    run_eval(agent, opt, 'valid')
-    run_eval(agent, opt, 'test')
+    run_eval(agent, opt, 'valid', write_log=True)
+    run_eval(agent, opt, 'test', write_log=True)
 
 
 if __name__ == '__main__':
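
The run_eval rework above boils down to a reuse contract: the first call builds the validation world, later calls receive it back and only reset() it, and write_log controls whether final metrics are appended to `<model_file>.<datatype>`. A toy sketch of that contract (every name here is a stand-in, not a ParlAI API):

```python
class FakeWorld:
    """Stand-in for the task world that create_task would build."""
    def __init__(self):
        print('built world (expensive)')

    def reset(self):
        print('reset world (cheap)')


def fake_run_eval(valid_world=None):
    # mirrors the (report, world) return shape of run_eval above
    if valid_world is None:
        valid_world = FakeWorld()
    else:
        valid_world.reset()
    return {'accuracy': 0.0}, valid_world


valid_world = None
for _ in range(3):  # stands in for periodic validation during training
    report, valid_world = fake_run_eval(valid_world)
```
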
diff --git a/parlai/agents/seq2seq/seq2seq.py b/parlai/agents/seq2seq/seq2seq.py
index 8333b27c78c..bb8f86ab7f8 100644
--- a/parlai/agents/seq2seq/seq2seq.py
+++ b/parlai/agents/seq2seq/seq2seq.py
@@ -30,7 +30,7 @@ def add_cmdline_args(argparser):
         """Add command-line arguments specifically for this agent."""
         DictionaryAgent.add_cmdline_args(argparser)
         agent = argparser.add_argument_group('Seq2Seq Arguments')
-        agent.add_argument('-hs', '--hiddensize', type=int, default=64,
+        agent.add_argument('-hs', '--hiddensize', type=int, default=128,
                            help='size of the hidden layers and embeddings')
         agent.add_argument('-nl', '--numlayers', type=int, default=2,
                            help='number of hidden layers')
@@ -38,10 +38,16 @@ def add_cmdline_args(argparser):
                            help='learning rate')
         agent.add_argument('-dr', '--dropout', type=float, default=0.1,
                            help='dropout rate')
+        # agent.add_argument('-bi', '--bidirectional', type='bool', default=False,
+        #                    help='whether to encode the context with a bidirectional RNN')
         agent.add_argument('--no-cuda', action='store_true', default=False,
                            help='disable GPUs even if available')
         agent.add_argument('--gpu', type=int, default=-1,
                            help='which GPU device to use')
+        agent.add_argument('-r', '--rank-candidates', type='bool', default=False,
+                           help='rank candidates if available. this is done by computing the' +
+                                ' mean score per token for each candidate and selecting the ' +
+                                'highest scoring one.')
 
     def __init__(self, opt, shared=None):
         # initialize defaults first
@@ -49,6 +55,13 @@ def __init__(self, opt, shared=None):
         if not shared:
             # this is not a shared instance of this class, so do full
             # initialization. if shared is set, only set up shared members.
+
+            # check for cuda
+            self.use_cuda = not opt.get('no_cuda') and torch.cuda.is_available()
+            if self.use_cuda:
+                print('[ Using CUDA ]')
+                torch.cuda.set_device(opt['gpu'])
+
             if opt.get('model_file') and os.path.isfile(opt['model_file']):
                 # load model parameters if available
                 print('Loading existing model params from ' + opt['model_file'])
@@ -62,18 +75,22 @@ def __init__(self, opt, shared=None):
             self.END = self.dict.end_token
             self.observation = {'text': self.END, 'episode_done': True}
             self.END_TENSOR = torch.LongTensor(self.dict.parse(self.END))
+            # get index of null token from dictionary (probably 0)
+            self.NULL_IDX = self.dict.txt2vec(self.dict.null_token)[0]
 
             # store important params directly
             hsz = opt['hiddensize']
             self.hidden_size = hsz
             self.num_layers = opt['numlayers']
             self.learning_rate = opt['learningrate']
+            self.rank = opt['rank_candidates']
             self.longest_label = 1
 
             # set up modules
             self.criterion = nn.NLLLoss()
             # lookup table stores word embeddings
-            self.lt = nn.Embedding(len(self.dict), hsz, padding_idx=0,
+            self.lt = nn.Embedding(len(self.dict), hsz,
+                                   padding_idx=self.NULL_IDX,
                                    scale_grad_by_freq=True)
             # encoder captures the input text
             self.encoder = nn.GRU(hsz, hsz, opt['numlayers'])
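
A small check of what `padding_idx` buys here: rows looked up at the null index contribute no gradient, so the left-padding produced in batchify below never trains the embedding table. Sketch assuming a reasonably recent PyTorch (the `Variable` wrapper matches the style of the code above and is a no-op on current versions):

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

NULL_IDX = 0
lt = nn.Embedding(5, 3, padding_idx=NULL_IDX, scale_grad_by_freq=True)

xs = Variable(torch.LongTensor([[0, 0, 2, 3]]))  # left-padded example, 0 = null
lt(xs).sum().backward()
print(lt.weight.grad[NULL_IDX])  # all zeros: the padding row is never updated
```
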
@@ -99,11 +116,6 @@ def __init__(self, opt, shared=None):
             # set loaded states if applicable
             self.set_states(self.states)
 
-        # check for cuda
-        self.use_cuda = not opt.get('no_cuda') and torch.cuda.is_available()
-        if self.use_cuda:
-            print('[ Using CUDA ]')
-            torch.cuda.set_device(opt['gpu'])
 
         if self.use_cuda:
             self.cuda()
@@ -121,7 +133,7 @@ def override_opt(self, new_opt):
         return self.opt
 
     def parse(self, text):
-        return torch.LongTensor(self.dict.txt2vec(text))
+        return self.dict.txt2vec(text)
 
     def v2t(self, vec):
         return self.dict.vec2txt(vec)
@@ -180,11 +192,12 @@ def observe(self, observation):
             self.episode_done = observation['episode_done']
         return observation
 
-    def predict(self, xs, ys=None):
+    def predict(self, xs, ys=None, cands=None):
         """Produce a prediction from our model. Update the model using the
        targets if available.
        """
         batchsize = len(xs)
+        text_cand_inds = None
 
         # first encode context
         xes = self.lt(xs).t()
@@ -220,17 +233,66 @@ def predict(self, xs, ys=None):
             loss.backward()
             self.update_params()
 
+            if random.random() < 0.1:
+                # sometimes output a prediction for debugging
+                print('prediction:', ' '.join(output_lines[0]),
+                      '\nlabel:', self.dict.vec2txt(ys.data[0]))
         else:
             # just produce a prediction without training the model
             done = [False for _ in range(batchsize)]
             total_done = 0
             max_len = 0
 
+            if cands:
+                # score each candidate separately
+
+                # cands are exs_with_cands x cands_per_ex x words_per_cand
+                # cview is total_cands x words_per_cand
+                cview = cands.view(-1, cands.size(2))
+                cands_xes = xe.expand(xe.size(0), cview.size(0), xe.size(2))
+                sz = hn.size()
+                cands_hn = (
+                    hn.view(sz[0], sz[1], 1, sz[2])
+                    .expand(sz[0], sz[1], cands.size(1), sz[2])
+                    .contiguous()
+                    .view(sz[0], -1, sz[2])
+                )
+
+                cand_scores = torch.zeros(cview.size(0))
+                cand_lengths = torch.LongTensor(cview.size(0)).fill_(0)
+                if self.use_cuda:
+                    cand_scores = cand_scores.cuda(async=True)
+                    cand_lengths = cand_lengths.cuda(async=True)
+                cand_scores = Variable(cand_scores)
+                cand_lengths = Variable(cand_lengths)
+
+                for i in range(cview.size(1)):
+                    output, cands_hn = self.decoder(cands_xes, cands_hn)
+                    preds, scores = self.hidden_to_idx(output, dropout=False)
+                    cs = cview.select(1, i)
+                    non_nulls = cs.ne(self.NULL_IDX)
+                    cand_lengths += non_nulls.long()
+                    score_per_cand = torch.gather(scores, 1, cs.unsqueeze(1))
+                    cand_scores += score_per_cand.squeeze() * non_nulls.float()
+                    cands_xes = self.lt(cs).unsqueeze(0)
+
+                # set empty scores to -1, so when divided by 0 they become -inf
+                cand_scores -= cand_lengths.eq(0).float()
+                # average the scores per token
+                cand_scores /= cand_lengths.float()
+
+                cand_scores = cand_scores.view(cands.size(0), cands.size(1))
+                srtd_scores, text_cand_inds = cand_scores.sort(1, True)
+                text_cand_inds = text_cand_inds.data
+
+            # now, generate a response from scratch
             while(total_done < batchsize) and max_len < self.longest_label:
                 # keep producing tokens until we hit END or max length for each
                 # example in the batch
                 output, hn = self.decoder(xes, hn)
                 preds, scores = self.hidden_to_idx(output, dropout=False)
 
                 xes = self.lt(preds.t())
                 max_len += 1
                 for b in range(batchsize):
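
The ranking block above scores each candidate by the mean log-probability per non-null token under the decoder. A self-contained sketch of that scoring rule with made-up numbers (recent PyTorch; the real code accumulates the sums step by step inside the decoder loop instead of gathering in one shot):

```python
import torch

NULL_IDX = 0
# per-step log-probabilities over a 5-word vocab, one row of steps per candidate
scores = torch.log(torch.tensor([
    [[0.1, 0.2, 0.5, 0.1, 0.1]] * 3,   # candidate 0: its token has prob 0.5 each step
    [[0.1, 0.1, 0.1, 0.6, 0.1]] * 3,   # candidate 1: its token has prob 0.6 each step
]))
cands = torch.LongTensor([[2, 2, 0],   # candidate 0 is two tokens, then null padding
                          [3, 3, 3]])  # candidate 1 is three tokens
non_null = cands.ne(NULL_IDX).float()
tok_scores = scores.gather(2, cands.unsqueeze(2)).squeeze(2) * non_null
mean_scores = tok_scores.sum(1) / non_null.sum(1)  # mean log-prob per token
print(mean_scores)           # candidate 1 wins: log(0.6) > log(0.5)
print(mean_scores.argmax())  # -> tensor(1)
```
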
""" batchsize = len(xs) + text_cand_inds = None # first encode context xes = self.lt(xs).t() @@ -220,17 +233,66 @@ def predict(self, xs, ys=None): loss.backward() self.update_params() + + if random.random() < 0.1: + # sometimes output a prediction for debugging + print('prediction:', ' '.join(output_lines[0]), + '\nlabel:', self.dict.vec2txt(ys.data[0])) else: # just produce a prediction without training the model done = [False for _ in range(batchsize)] total_done = 0 max_len = 0 + if cands: + # score each candidate separately + + # cands are exs_with_cands x cands_per_ex x words_per_cand + # cview is total_cands x words_per_cand + cview = cands.view(-1, cands.size(2)) + cands_xes = xe.expand(xe.size(0), cview.size(0), xe.size(2)) + sz = hn.size() + cands_hn = ( + hn.view(sz[0], sz[1], 1, sz[2]) + .expand(sz[0], sz[1], cands.size(1), sz[2]) + .contiguous() + .view(sz[0], -1, sz[2]) + ) + + cand_scores = torch.zeros(cview.size(0)) + cand_lengths = torch.LongTensor(cview.size(0)).fill_(0) + if self.use_cuda: + cand_scores = cand_scores.cuda(async=True) + cand_lengths = cand_lengths.cuda(async=True) + cand_scores = Variable(cand_scores) + cand_lengths = Variable(cand_lengths) + + for i in range(cview.size(1)): + output, cands_hn = self.decoder(cands_xes, cands_hn) + preds, scores = self.hidden_to_idx(output, dropout=False) + cs = cview.select(1, i) + non_nulls = cs.ne(self.NULL_IDX) + cand_lengths += non_nulls.long() + score_per_cand = torch.gather(scores, 1, cs.unsqueeze(1)) + cand_scores += score_per_cand.squeeze() * non_nulls.float() + cands_xes = self.lt(cs).unsqueeze(0) + + # set empty scores to -1, so when divided by 0 they become -inf + cand_scores -= cand_lengths.eq(0).float() + # average the scores per token + cand_scores /= cand_lengths.float() + + cand_scores = cand_scores.view(cands.size(0), cands.size(1)) + srtd_scores, text_cand_inds = cand_scores.sort(1, True) + text_cand_inds = text_cand_inds.data + + # now, generate a response from scratch while(total_done < batchsize) and max_len < self.longest_label: # keep producing tokens until we hit END or max length for each # example in the batch output, hn = self.decoder(xes, hn) preds, scores = self.hidden_to_idx(output, dropout=False) + xes = self.lt(preds.t()) max_len += 1 for b in range(batchsize): @@ -243,11 +305,12 @@ def predict(self, xs, ys=None): total_done += 1 else: output_lines[b].append(token) + if random.random() < 0.1: + # sometimes output a prediction for debugging print('prediction:', ' '.join(output_lines[0])) - return output_lines - + return output_lines, text_cand_inds def batchify(self, observations): """Convert a list of observations into input & target tensors.""" @@ -259,24 +322,26 @@ def batchify(self, observations): # set up the input tensors batchsize = len(exs) # tokenize the text - parsed = [self.parse(ex['text']) for ex in exs] - max_x_len = max([len(x) for x in parsed]) - xs = torch.LongTensor(batchsize, max_x_len).fill_(0) - # pack the data to the right side of the tensor for this model - for i, x in enumerate(parsed): - offset = max_x_len - len(x) - for j, idx in enumerate(x): - xs[i][j + offset] = idx - if self.use_cuda: - xs = xs.cuda(async=True) - xs = Variable(xs) + xs = None + if batchsize > 0: + parsed = [self.parse(ex['text']) for ex in exs] + max_x_len = max([len(x) for x in parsed]) + xs = torch.LongTensor(batchsize, max_x_len).fill_(0) + # pack the data to the right side of the tensor for this model + for i, x in enumerate(parsed): + offset = max_x_len - len(x) + for j, idx in enumerate(x): 
@@ -286,7 +351,37 @@
                 if self.use_cuda:
                     ys = ys.cuda(async=True)
                 ys = Variable(ys)
-        return xs, ys, valid_inds
+
+        # set up candidates
+        cands = None
+        valid_cands = None
+        if ys is None and self.rank:
+            # only do ranking when no targets available and ranking flag set
+            parsed = []
+            valid_cands = []
+            for i in valid_inds:
+                if 'label_candidates' in observations[i]:
+                    # each candidate tuple is a pair of the parsed version and
+                    # the original full string
+                    cs = list(observations[i]['label_candidates'])
+                    parsed.append([self.parse(c) for c in cs])
+                    valid_cands.append((i, cs))
+            if len(parsed) > 0:
+                # TODO: store lengths of cands separately, so don't have zero
+                #       padding for varying number of cands per example
+                # found cands, pack them into tensor
+                max_c_len = max(max(len(c) for c in cs) for cs in parsed)
+                max_c_cnt = max(len(cs) for cs in parsed)
+                cands = torch.LongTensor(len(parsed), max_c_cnt, max_c_len).fill_(0)
+                for i, cs in enumerate(parsed):
+                    for j, c in enumerate(cs):
+                        for k, idx in enumerate(c):
+                            cands[i][j][k] = idx
+                if self.use_cuda:
+                    cands = cands.cuda(async=True)
+                cands = Variable(cands)
+
+        return xs, ys, valid_inds, cands, valid_cands
 
     def batch_act(self, observations):
         batchsize = len(observations)
@@ -297,20 +392,29 @@ def batch_act(self, observations):
         # valid_inds tells us the indices of all valid examples
         # e.g. for input [{}, {'text': 'hello'}, {}, {}], valid_inds is [1]
         # since the other three elements had no 'text' field
-        xs, ys, valid_inds = self.batchify(observations)
+        xs, ys, valid_inds, cands, valid_cands = self.batchify(observations)
 
-        if len(xs) == 0:
+        if xs is None:
             # no valid examples, just return the empty responses we set up
             return batch_reply
 
         # produce predictions either way, but use the targets if available
-        predictions = self.predict(xs, ys)
+        predictions, text_cand_inds = self.predict(xs, ys, cands)
 
         for i in range(len(predictions)):
             # map the predictions back to non-empty examples in the batch
             # we join with spaces since we produce tokens one at a time
-            batch_reply[valid_inds[i]]['text'] = ' '.join(
-                c for c in predictions[i] if c != self.END)
+            curr = batch_reply[valid_inds[i]]
+            curr['text'] = ' '.join(c for c in predictions[i] if c != self.END
+                                    and c != self.dict.null_token)
+
+        if text_cand_inds is not None:
+            for i in range(len(valid_cands)):
+                order = text_cand_inds[i]
+                batch_idx, curr_cands = valid_cands[i]
+                curr = batch_reply[batch_idx]
+                curr['text_candidates'] = [curr_cands[idx] for idx in order
+                                           if idx < len(curr_cands)]
 
         return batch_reply
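
The text_candidates block above maps the sorted indices back onto the original candidate strings; the `idx < len(curr_cands)` guard drops indices that only exist because shorter candidate lists were zero-padded up to the longest one in the batch. In isolation:

```python
curr_cands = ['milk', 'cookies', 'football']   # this example's real candidates
order = [2, 0, 1, 3]   # ranked indices; 3 points at a padded (non-existent) slot
text_candidates = [curr_cands[idx] for idx in order if idx < len(curr_cands)]
print(text_candidates)  # ['football', 'milk', 'cookies']
```
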
diff --git a/parlai/core/dialog_teacher.py b/parlai/core/dialog_teacher.py
index b636a6145e8..f0743230358 100644
--- a/parlai/core/dialog_teacher.py
+++ b/parlai/core/dialog_teacher.py
@@ -46,11 +46,11 @@ def __init__(self, opt, shared=None):
         # first initialize any shared objects
         self.random = self.datatype == 'train'
         if shared and shared.get('data'):
-            self.data = DialogData(opt, None, cands=self.label_candidates(),
-                                   shared=shared['data'].share())
+            self.data = DialogData(opt, shared=shared['data'])
         else:
-            self.data = DialogData(opt, self.setup_data(opt['datafile']),
-                                   cands=self.label_candidates())
+            self.data = DialogData(opt,
+                                   data_loader=self.setup_data(opt['datafile']),
+                                   cands=self.label_candidates())
 
         # for ordered data in batch mode (especially, for validation and
         # testing), each teacher in the batch gets a start index and a step
@@ -85,7 +85,7 @@ def __next__(self):
 
     def share(self):
         shared = super().share()
-        shared['data'] = self.data
+        shared['data'] = self.data.share()
         return shared
 
     def label_candidates(self):
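
The two changes above form one idiom: the teacher's share() now exports self.data.share() (a plain dict of references), and a shared teacher hands that dict straight to DialogData(opt, shared=...), so batched copies reuse the already-loaded episodes instead of re-running the data loader. A minimal sketch of the pattern, with made-up class and field names:

```python
class ToyData:
    def __init__(self, shared=None):
        if shared:
            self.episodes = shared['data']          # reuse loaded data
        else:
            self.episodes = ['ep1', 'ep2', 'ep3']   # pretend this load is expensive

    def share(self):
        return {'data': self.episodes}


original = ToyData()
clone = ToyData(shared=original.share())
assert clone.episodes is original.episodes  # same list object, loaded only once
```
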
""" - def __init__(self, opt, data_loader, cands=None, shared=None): + def __init__(self, opt, data_loader=None, cands=None, shared=None): # self.data is a list of episodes # each episode is a tuple of entries # each entry is a tuple of values for the action/observation table self.opt = opt if shared: + self.image_loader = shared.get('image_loader', None) self.data = shared.get('data', []) self.cands = shared.get('cands', None) - self.image_loader = shared.get('image_loader', None) else: self.image_loader = ImageLoader(opt) self.data = [] @@ -196,8 +196,11 @@ def __init__(self, opt, data_loader, cands=None, shared=None): self.copied_cands = False def share(self): - shared = {'data': self.data, 'cands': self.cands, - 'image_loader': self.image_loader} + shared = { + 'data': self.data, + 'cands': self.cands, + 'image_loader': self.image_loader + } return shared def __len__(self): diff --git a/parlai/core/metrics.py b/parlai/core/metrics.py index fae4b0afe46..e1aee4a11e8 100644 --- a/parlai/core/metrics.py +++ b/parlai/core/metrics.py @@ -13,19 +13,20 @@ from collections import Counter import re -import string - +re_art = re.compile(r'\b(a|an|the)\b') +re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~]') def _normalize_answer(s): """Lower text and remove punctuation, articles and extra whitespace.""" def remove_articles(text): - return re.sub(r'\b(a|an|the)\b', ' ', text) + return re_art.sub(' ', text) def white_space_fix(text): return ' '.join(text.split()) def remove_punc(text): - exclude = set(string.punctuation) + text = re_punc.sub(' ', text) # convert interword punctuation to spaces + exclude = set('_\'') # remove intraword punctuation completely return ''.join(ch for ch in text if ch not in exclude) def lower(text): @@ -110,12 +111,12 @@ def update_ranking_metrics(self, observation, labels): # Now loop through text candidates, assuming they are sorted. # If any of them is a label then score a point. # maintain hits@1, 5, 10, 50, 100, etc. - label_set = set(labels) if type(labels) != set else labels + label_set = set(_normalize_answer(l) for l in labels) cnts = {k: 0 for k in self.eval_pr} cnt = 0 for c in text_cands: cnt += 1 - if c in label_set: + if _normalize_answer(c) in label_set: for k in self.eval_pr: if cnt <= k: cnts[k] += 1 @@ -127,7 +128,6 @@ def update_ranking_metrics(self, observation, labels): if cnts[k] > 0: self.metrics['hits@' + str(k)] += 1 - def update(self, observation, labels): with self._lock(): self.metrics['cnt'] += 1 diff --git a/parlai/core/params.py b/parlai/core/params.py index 97971d9ba66..19975237fb9 100644 --- a/parlai/core/params.py +++ b/parlai/core/params.py @@ -233,4 +233,3 @@ def print_args(self): print('[ ' + group.title + ': ] ') count += 1 print('[ ' + key + ': ' + values[key] + ' ]') - diff --git a/parlai/tasks/babi/agents.py b/parlai/tasks/babi/agents.py index 62d1449d0d4..77d8409df20 100644 --- a/parlai/tasks/babi/agents.py +++ b/parlai/tasks/babi/agents.py @@ -11,6 +11,7 @@ import copy import os + def _path(exsz, task, opt, dt=''): # Build the data if it doesn't exist. 
diff --git a/parlai/tasks/babi/agents.py b/parlai/tasks/babi/agents.py
index 62d1449d0d4..77d8409df20 100644
--- a/parlai/tasks/babi/agents.py
+++ b/parlai/tasks/babi/agents.py
@@ -11,6 +11,7 @@
 import copy
 import os
 
+
 def _path(exsz, task, opt, dt=''):
     # Build the data if it doesn't exist.
     build(opt)
@@ -21,23 +22,58 @@ def _path(exsz, task, opt, dt=''):
                         'qa{task}_{type}.txt'.format(task=task, type=dt))
 
 
+def mod_labels(ys, task):
+    if ys is not None:
+        # replace comma-separated babi labels with spaces
+        # this is more friendly to our tokenizer which makes commas full tokens
+        # this way models won't be penalized for not generating a comma
+        if task == '8':
+            # holding: labels like 'milk,cookies,football'
+            # replace with spaces: 'milk cookies football'
+            ys = [y.replace(',', ' ') for y in ys]
+        elif task == '19':
+            # pathfinding: labels like 'n,e' or 's,w'
+            # replace with spaces: 'n e'
+            ys = [y.replace(',', ' ') for y in ys]
+
+    return ys
+
+
 # Single bAbI task (1k training).
 class Task1kTeacher(FbDialogTeacher):
     def __init__(self, opt, shared=None):
         task = opt.get('task', 'babi:Task1k:1')
-        opt['datafile'] = _path('', task.split(':')[2], opt)
+        self.task_num = task.split(':')[2]
+        opt['datafile'] = _path('', self.task_num, opt)
         opt['cands_datafile'] = _path('', task.split(':')[2], opt, 'train')
         super().__init__(opt, shared)
 
+    def setup_data(self, path):
+        for entry, new in super().setup_data(path):
+            entry[1] = mod_labels(entry[1], self.task_num)
+            yield entry, new
+
+    def load_cands(self, path):
+        return mod_labels(super().load_cands(path), self.task_num)
+
 
 # Single bAbI task (10k training).
 class Task10kTeacher(FbDialogTeacher):
     def __init__(self, opt, shared=None):
         task = opt.get('task', 'babi:Task10k:1')
-        opt['datafile'] = _path('-10k', task.split(':')[2], opt)
+        self.task_num = task.split(':')[2]
+        opt['datafile'] = _path('-10k', self.task_num, opt)
         opt['cands_datafile'] = _path('-10k', task.split(':')[2], opt, 'train')
         super().__init__(opt, shared)
 
+    def setup_data(self, path):
+        for entry, new in super().setup_data(path):
+            entry[1] = mod_labels(entry[1], self.task_num)
+            yield entry, new
+
+    def load_cands(self, path):
+        return mod_labels(super().load_cands(path), self.task_num)
+
 
 # By default train on all tasks at once.
 class All1kTeacher(MultiTaskTeacher):