This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

fixes to multiteacher, batching (#105)
alexholdenmiller authored May 31, 2017
1 parent 27cade2 commit d3beeff
Showing 8 changed files with 131 additions and 106 deletions.
5 changes: 2 additions & 3 deletions parlai/agents/ir_baseline/agents.py
@@ -153,13 +153,12 @@ def build_query_representation(self, query):
         rw = rep['words']
         used = {}
         for w in words:
-            if len(self.dictionary.freq) > 0:
-                rw[w] = 1.0 / (1.0 + math.log(1.0 + self.dictionary.freq[w]))
+            if len(self.dictionary.freqs()) > 0:
+                rw[w] = 1.0 / (1.0 + math.log(1.0 + self.dictionary.freqs()[w]))
             else:
                 if w not in stopwords:
                     rw[w] = 1
             used[w] = True
-        norm = len(used)
         rep['norm'] = math.sqrt(len(words))
         return rep
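The fix above routes frequency lookups through the new `freqs()` accessor this commit adds to DictionaryAgent (see the dict.py diff below). For intuition, here is a minimal standalone sketch of the same inverse-log weighting, using a hypothetical plain-dict frequency table in place of a real dictionary agent:

    import math

    # hypothetical frequency table standing in for DictionaryAgent.freqs()
    freqs = {'the': 1000, 'parley': 3}

    def term_weight(word):
        # frequent words are damped toward zero; unseen words get exactly 1.0
        return 1.0 / (1.0 + math.log(1.0 + freqs.get(word, 0)))

    print(term_weight('the'))     # ~0.126
    print(term_weight('parley'))  # ~0.419
    print(term_weight('zebra'))   # 1.0 (not in the table)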

53 changes: 34 additions & 19 deletions parlai/core/agents.py
@@ -258,14 +258,17 @@ def __init__(self, opt, shared=None):
         self.tasks = []
         self.opt = opt
         self.id = opt['task']
-        tasks = opt['task'].split(',')
-        for k in tasks:
-            k = k.strip()
-            if k:
-                opt_singletask = copy.deepcopy(opt)
-                opt_singletask['task'] = k
-                self.tasks.extend(create_task_agent_from_taskname(
-                    opt_singletask))
+        if shared and 'tasks' in shared:
+            self.tasks = [create_agent_from_shared(t) for t in shared['tasks']]
+        else:
+            tasks = opt['task'].split(',')
+            for k in tasks:
+                k = k.strip()
+                if k:
+                    opt_singletask = copy.deepcopy(opt)
+                    opt_singletask['task'] = k
+                    self.tasks.extend(create_task_agent_from_taskname(
+                        opt_singletask))
         self.task_idx = -1
         self.new_task = True
         self.random = opt.get('datatype') == 'train'
@@ -286,23 +289,24 @@ def __next__(self):
             raise StopIteration()
 
     def observe(self, observation):
-        self.tasks[self.task_idx].observe(observation)
+        return self.tasks[self.task_idx].observe(observation)
+
+    def act(self):
         if self.new_task:
             self.new_task = False
             if self.random:
                 # select random teacher
                 self.task_idx = random.randrange(len(self.tasks))
             else:
-                start_idx = self.task_idx
-                keep_looking = True
-                while keep_looking:
+                # do at most one full loop looking for unfinished task
+                for _ in range(len(self.tasks)):
                     self.task_idx = (self.task_idx + 1) % len(self.tasks)
-                    keep_looking = (self.tasks[self.task_idx].epoch_done() and
-                                    start_idx != self.task_idx)
-                    if start_idx == self.task_idx:
-                        return {'text': 'There are no more examples remaining.'}
-        return observation
-
-    def act(self):
+                    if not self.tasks[self.task_idx].epoch_done():
+                        # if this task has examples ready, break
+                        break
+                if self.tasks[self.task_idx].epoch_done():
+                    # all tasks are done, so return empty action table
+                    return {'episode_done': True}
         t = self.tasks[self.task_idx].act()
         if t['episode_done']:
             self.new_task = True
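The rewritten `act()` replaces the old `start_idx`/`keep_looking` scan with a bounded loop: scan at most once around the task list for a teacher whose epoch is not done, and fall through to `{'episode_done': True}` when every task is exhausted. A self-contained sketch of that selection step, with a hypothetical stub standing in for real ParlAI teachers:

    class StubTeacher:
        """Hypothetical stand-in for a ParlAI teacher."""

        def __init__(self, done):
            self._done = done

        def epoch_done(self):
            return self._done

    def next_task_idx(tasks, task_idx):
        # do at most one full loop looking for an unfinished task
        for _ in range(len(tasks)):
            task_idx = (task_idx + 1) % len(tasks)
            if not tasks[task_idx].epoch_done():
                break  # this task still has examples ready
        return task_idx

    tasks = [StubTeacher(True), StubTeacher(False), StubTeacher(True)]
    idx = next_task_idx(tasks, task_idx=-1)
    assert idx == 1 and not tasks[idx].epoch_done()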
@@ -333,3 +337,14 @@ def report(self):
         if num_tasks > 0:
             m['accuracy'] = sum_accuracy / num_tasks
         return m
+
+    def reset(self):
+        for t in self.tasks:
+            t.reset()
+
+    def share(self):
+        shared = {}
+        shared['class'] = type(self)
+        shared['opt'] = self.opt
+        shared['tasks'] = [t.share() for t in self.tasks]
+        return shared
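The new `share()`/`reset()` pair is what the batching fixes hinge on: `share()` packages each subtask's shared state so that batch copies of a MultiTaskTeacher can be rebuilt without reloading data, and the constructor's new `shared` branch performs that rebuild via `create_agent_from_shared`. A rough sketch of the round-trip, assuming a working ParlAI checkout at this commit and a fuller `opt` dict than shown here:

    from parlai.core.agents import MultiTaskTeacher, create_agent_from_shared

    # opt would normally come from ParlaiParser; only the relevant keys are shown
    opt = {'task': 'babi:Task1k:1,squad', 'datatype': 'train'}

    teacher = MultiTaskTeacher(opt)           # loads each task once
    shared = teacher.share()                  # {'class', 'opt', 'tasks': [...]}
    clone = create_agent_from_shared(shared)  # reuses the already-loaded tasks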
17 changes: 8 additions & 9 deletions parlai/core/dialog_teacher.py
@@ -66,8 +66,11 @@ def reset(self):
         self.metrics.clear()
         self.lastY = None
         self.episode_idx = self.data_offset - self.step_size
-        self.epochDone = False
         self.episode_done = True
+        self.epochDone = False
+        if not self.random and self.data_offset >= self.data.num_episodes():
+            # could have bigger batchsize then episodes... so nothing to do
+            self.epochDone = True
 
     def __len__(self):
         return len(self.data)
@@ -80,7 +83,6 @@ def __next__(self):
         if self.epochDone:
             raise StopIteration()
 
-    # share datatype, data, metrics, and a lock on the metrics
     def share(self):
         shared = super().share()
         shared['data'] = self.data
@@ -95,7 +97,7 @@ def label_candidates(self):
     def observe(self, observation):
         """Process observation for metrics. """
         if self.lastY is not None:
-            loss = self.metrics.update(observation, self.lastY)
+            self.metrics.update(observation, self.lastY)
             self.lastY = None
         return observation

@@ -126,7 +128,7 @@ def next_example(self):
     def act(self):
         """Send new dialog message."""
         if self.epochDone:
-            return { 'episode_done': True }
+            return {'episode_done': True}
         action, self.epochDone = self.next_example()
         self.episode_done = action['episode_done']
         action['id'] = self.getID()
@@ -186,10 +188,7 @@ def __len__(self):
         """Returns total number of entries available. Each episode has at least
         one entry, but might have many more.
         """
-        length = 0
-        for l in self.data:
-            length += len(l)
-        return length
+        return sum(len(episode) for episode in self.data)
 
     def _load(self, data_loader):
         """Loads up data from an iterator over tuples described in the class
@@ -271,7 +270,7 @@ def get(self, episode_idx, entry_idx=0):
 
         if (table.get('labels', None) is not None
-            and self.cands is not None):
+                and self.cands is not None):
             if self.addedCands:
                 # remove elements in addedCands
                 self.cands.difference_update(self.addedCands)
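The `reset()` change in this file is a batching guard: with sequential ordering and a batch index (`data_offset`) pointing past the last episode, the teacher has nothing to serve, so it must wake from `reset()` with `epochDone` already True and answer `act()` with a bare `{'episode_done': True}`. A sketch of the loop contract a caller can now rely on, with hypothetical teacher and agent objects:

    def run_epoch(teacher, agent):
        # hypothetical driving loop: pull examples until this teacher's epoch ends
        teacher.reset()
        while not teacher.epoch_done():
            query = teacher.act()        # {'episode_done': True} if nothing to serve
            if 'text' not in query:
                break
            agent.observe(query)
            teacher.observe(agent.act())  # observe() scores the reply and returns it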
85 changes: 51 additions & 34 deletions parlai/core/dict.py
@@ -6,11 +6,11 @@
 """Contains code for parsing and building a dictionary from text."""
 
 from .agents import Agent
-from .thread_utils import SharedTable
 from collections import defaultdict
 import copy
 import numpy as np
 import nltk
+import re
 
 
 def find_ngrams(token_dict, text, n):
@@ -55,8 +55,9 @@ class DictionaryAgent(Agent):
     default_lang = 'english'
     default_maxngram = -1
     default_minfreq = 0
-    default_null = '<NULL>'
-    default_unk = '<UNK>'
+    default_null = '__NULL__'
+    default_eos = '__EOS__'
+    default_unk = '__UNK__'
 
     @staticmethod
     def add_cmdline_args(argparser):
@@ -76,29 +77,31 @@ def add_cmdline_args(argparser):
             help='looks for ngrams of up to this size. this is ignored when ' +
                  'building the dictionary. note: this takes approximate ' +
                  'runtime of len(sentence)^max_ngram_size')
-        argparser.add_arg(
-            '--dict-nulltoken', default=DictionaryAgent.default_null,
-            help='empty token, can be used for padding or just empty values')
-        # TODO(ahm): minfreq isn't actually being used, add the functionality
         argparser.add_arg(
             '--dict-minfreq', default=DictionaryAgent.default_minfreq,
             help='minimum frequency of words to include them in the dictionary')
+        argparser.add_arg(
+            '--dict-nulltoken', default=DictionaryAgent.default_null,
+            help='empty token, can be used for padding or just empty values')
+        argparser.add_arg(
+            '--dict-eostoken', default=DictionaryAgent.default_eos,
+            help='token for end of sentence markers, if needed')
         argparser.add_arg(
             '--dict-unktoken', default=DictionaryAgent.default_unk,
             help='token to return for unavailable words')
 
     def __init__(self, opt, shared=None):
         # initialize fields
         self.opt = copy.deepcopy(opt)
-        self.null_token = opt.get('dict_nulltoken')
-        self.unk_token = opt.get('dict_unktoken')
-        self.max_ngram_size = opt.get('dict_max_ngram_size',
-                                      self.default_maxngram)
+        self.null_token = opt['dict_nulltoken']
+        self.eos_token = opt['dict_eostoken']
+        self.unk_token = opt['dict_unktoken']
+        self.max_ngram_size = opt['dict_max_ngram_size']
 
         if shared:
-            self.freq = shared.get('freq', SharedTable({}))
-            self.tok2ind = shared.get('tok2ind', SharedTable({}))
-            self.ind2tok = shared.get('ind2tok', SharedTable({}))
+            self.freq = shared.get('freq', {})
+            self.tok2ind = shared.get('tok2ind', {})
+            self.ind2tok = shared.get('ind2tok', {})
         else:
             self.freq = defaultdict(int)
             self.tok2ind = {}
@@ -108,18 +111,24 @@ def __init__(self, opt, shared=None):
             self.tok2ind[self.null_token] = 0
             self.ind2tok[0] = self.null_token
 
+        if self.eos_token:
+            # set special unknown word token
+            index = len(self.tok2ind)
+            self.tok2ind[self.eos_token] = index
+            self.ind2tok[index] = self.eos_token
+
         if self.unk_token:
             # set special unknown word token
             index = len(self.tok2ind)
             self.tok2ind[self.unk_token] = index
             self.ind2tok[index] = self.unk_token
 
-        if opt.get('dict_load_path'):
+        if 'dict_loadpath' in opt:
             # load existing dictionary
-            self.load(opt.get('dict_load_path'))
+            self.load(opt['dict_loadpath'])
 
         # initialize tokenizers
-        st_path = 'tokenizers/punkt/{0}.pickle'.format(opt.get('dict_language'))
+        st_path = 'tokenizers/punkt/{0}.pickle'.format(opt['dict_language'])
         try:
             self.sent_tok = nltk.data.load(st_path)
         except LookupError:
@@ -129,15 +138,20 @@ def __init__(self, opt, shared=None):
         self.word_tok = nltk.tokenize.treebank.TreebankWordTokenizer()
 
         if not shared:
 
             if self.null_token:
-                # fix count for null token to one billion and one
-                self.freq[self.null_token] = 1000000001
+                # fix count for null token to one billion and two
+                self.freq[self.null_token] = 1000000002
 
+            if self.eos_token:
+                # fix count for end of sentence token to one billion and one
+                self.freq[self.eos_token] = 1000000001
+
             if self.unk_token:
                 # fix count for unknown token to one billion
                 self.freq[self.unk_token] = 1000000000
 
-        if opt.get('dict_savepath'):
+        if 'dict_savepath' in opt:
             self.save_path = opt['dict_savepath']
 
     def __contains__(self, key):
@@ -177,6 +191,9 @@ def __setitem__(self, key, value):
             self.tok2ind[key] = index
             self.ind2tok[index] = key
 
+    def freqs(self):
+        return self.freq
+
     def _sent_tokenize(self, text, building=False):
         """Uses nltk-trained PunktTokenizer for sentence tokenization"""
         text = text.replace('|', ' ' if building else ' __pipe__ ')
@@ -197,8 +214,6 @@ def _word_tokenize(self, text, building=False):
 
     def tokenize(self, text, building=False):
         """Returns a sequence of tokens from the iterable."""
-        # TODO(ahm): this should be easy to parallelize, since we don't care
-        # about sentence order here
         return (token for sent in self._sent_tokenize(text, building)
                 for token in self._word_tokenize(sent, building))

@@ -235,9 +250,10 @@ def load(self, filename):
                 token = split[0]
                 cnt = int(split[1]) if len(split) > 1 else 0
                 self.freq[token] = cnt
-                index = len(self.tok2ind)
-                self.tok2ind[token] = index
-                self.ind2tok[index] = token
+                if token not in self.tok2ind:
+                    index = len(self.tok2ind)
+                    self.tok2ind[token] = index
+                    self.ind2tok[index] = token
 
     def save(self, filename, append=False, sort=True):
         """Save dictionary to file.
Expand All @@ -257,7 +273,7 @@ def save(self, filename, append=False, sort=True):
write.write('{tok}\t{cnt}\n'.format(tok=tok, cnt=cnt))

def sort(self):
"""Sorts the dictonary, so that the elements with the lowest index have
"""Sorts the dictionary, so that the elements with the lowest index have
the highest counts. This reindexes the dictionary according to the
sorted frequencies, breaking ties alphabetically by token.
"""
@@ -272,7 +288,7 @@ def sort(self):
         self.ind2tok = new_ind2tok
         return sorted_pairs
 
-    def parse(self, txt_or_vec, vec_type=np.ndarray):
+    def parse(self, txt_or_vec, vec_type=list):
         """Convenience function for parsing either text or vectors of indices.
         vec_type is the type of the returned vector if the input is a string.
@@ -284,7 +300,7 @@ def parse(self, txt_or_vec, vec_type=np.ndarray):
         else:
             return self.vec2txt(txt_or_vec)
 
-    def txt2vec(self, text, vec_type=np.ndarray):
+    def txt2vec(self, text, vec_type=list):
         """Converts a string to a vector (list of ints).
         First runs a sentence tokenizer, then a word tokenizer.
         vec_type is the type of the returned vector if the input is a string.
@@ -294,8 +310,10 @@ def txt2vec(self, text, vec_type=np.ndarray):
                 (self[token] for token in self.tokenize(str(text))),
                 np.int
             )
-        else:
+        elif vec_type == list or vec_type == tuple or vec_type == set:
             res = vec_type((self[token] for token in self.tokenize(str(text))))
+        else:
+            raise RuntimeError('Type {} not supported by dict'.format(vec_type))
         assert type(res) == vec_type
         return res

@@ -315,13 +333,9 @@ def act(self):
         for text in source:
             if text:
                 self.add_to_dict(self.tokenize(text))
-        return {}
+        return {'id': 'Dictionary'}
 
     def share(self):
-        """Creates shared-memory versions of the internal maps."""
-        self.freq = SharedTable(self.freq)
-        self.tok2ind = SharedTable(self.tok2ind)
-        self.ind2tok = SharedTable(self.ind2tok)
         shared = {}
         shared['class'] = type(self)
         shared['opt'] = self.opt
@@ -334,3 +348,6 @@ def shutdown(self):
         """Save on shutdown if savepath is set."""
         if hasattr(self, 'save_path'):
             self.save(self.save_path)
+
+    def __str__(self):
+        return str(self.freq)
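Taken together, the dict.py changes drop the SharedTable maps in favor of plain dicts shared by reference, add an `__EOS__` token between `__NULL__` and `__UNK__`, and make `txt2vec` reject vector types other than np.ndarray, list, tuple, and set. A quick usage sketch, assuming a ParlAI checkout at this commit and NLTK's punkt tokenizer data installed locally:

    from parlai.core.dict import DictionaryAgent

    opt = {
        'dict_nulltoken': DictionaryAgent.default_null,  # '__NULL__' -> index 0
        'dict_eostoken': DictionaryAgent.default_eos,    # '__EOS__'  -> index 1
        'dict_unktoken': DictionaryAgent.default_unk,    # '__UNK__'  -> index 2
        'dict_max_ngram_size': DictionaryAgent.default_maxngram,
        'dict_language': DictionaryAgent.default_lang,
    }
    d = DictionaryAgent(opt)
    d.observe({'text': 'hello world'})
    d.act()                                 # folds observed text into the dict
    print(d.txt2vec('hello world'))         # e.g. [3, 4]; vec_type=list is the default
    # d.txt2vec('hello world', vec_type=dict) would now raise RuntimeError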