updates to the training loop / logging (#172)
alexholdenmiller authored Jun 28, 2017
1 parent 2a5b01d commit fdadbdd
Showing 5 changed files with 97 additions and 42 deletions.
108 changes: 71 additions & 37 deletions examples/train_model.py
@@ -33,32 +33,38 @@
 import math
 import os

-def run_eval(agent, opt, datatype, still_training=False):
+def run_eval(agent, opt, datatype, still_training=False, max_exs=-1):
     ''' Eval on validation/test data. '''
     print('[ running eval: ' + datatype + ' ]')
     opt['datatype'] = datatype
     if opt.get('evaltask'):
         opt['task'] = opt['evaltask']

     valid_world = create_task(opt, agent)
-    for i in range(len(valid_world)):
+    first_run = True
+    for _ in valid_world:
         valid_world.parley()
-        if i == 1 and opt['display_examples']:
+        if first_run and opt['display_examples']:
+            first_run = False
             print(valid_world.display() + '\n~~')
             print(valid_world.report())
-        if valid_world.epoch_done():
+        if valid_world.epoch_done() or (max_exs > 0 and
+                                        valid_world.report()['total'] > max_exs):
+            # need to check the report for total exs done, since sometimes
+            # when batching some of the batch is empty (for multi-ex episodes)
             break
     valid_world.shutdown()
     valid_report = valid_world.report()

     metrics = datatype + ':' + str(valid_report)
     print(metrics)
     if still_training:
         return valid_report
-    else:
-        if opt['model_file']:
-            # Write out metrics
-            f = open(opt['model_file'] + '.' + datatype, 'a+')
-            f.write(metrics + '\n')
-            f.close()
+    elif opt['model_file']:
+        # Write out metrics
+        f = open(opt['model_file'] + '.' + datatype, 'a+')
+        f.write(metrics + '\n')
+        f.close()

 def main():
     # Get command line arguments
@@ -69,13 +75,17 @@ def main():
                               'one used for training if not set)'))
     train.add_argument('-d', '--display-examples',
                        type='bool', default=False)
-    train.add_argument('-e', '--num-epochs', type=int, default=1)
+    train.add_argument('-e', '--num-epochs', type=int, default=-1)
     train.add_argument('-ttim', '--max-train-time',
-                       type=float, default=float('inf'))
+                       type=float, default=-1)
     train.add_argument('-ltim', '--log-every-n-secs',
-                       type=float, default=1)
+                       type=float, default=2)
     train.add_argument('-vtim', '--validation-every-n-secs',
-                       type=float, default=0)
+                       type=float, default=-1)
+    train.add_argument('-vme', '--validation-max-exs',
+                       type=int, default=-1,
+                       help='max examples to use during validation (default ' +
+                            '-1 uses all)')
     train.add_argument('-vp', '--validation-patience',
                        type=int, default=5,
                        help=('number of iterations of validation where result '
@@ -97,36 +107,60 @@ def main():
     validate_time = Timer()
     log_time = Timer()
     print('[ training... ]')
-    parleys = 0
-    num_parleys = opt['num_epochs'] * int(len(world) / opt['batchsize'])
+    total_exs = 0
+    max_exs = opt['num_epochs'] * len(world)
     best_accuracy = 0
     impatience = 0
     saved = False
-    for i in range(num_parleys):
+    while True:
         world.parley()
-        parleys = parleys + 1
-        if train_time.time() > opt['max_train_time']:
-            print('[ max_train_time elapsed: ' + str(train_time.time()) + ' ]')
+        if opt['num_epochs'] > 0 and total_exs >= max_exs:
+            print('[ num_epochs completed: {} ]'.format(opt['num_epochs']))
+            break
+        if opt['max_train_time'] > 0 and train_time.time() > opt['max_train_time']:
+            print('[ max_train_time elapsed: {} ]'.format(train_time.time()))
             break
-        if log_time.time() > opt['log_every_n_secs']:
+        if opt['log_every_n_secs'] > 0 and log_time.time() > opt['log_every_n_secs']:
             if opt['display_examples']:
                 print(world.display() + '\n~~')
-            parleys_per_sec = train_time.time() / parleys
-            time_left = (num_parleys - parleys) * parleys_per_sec
-            log = ('[ time:' + str(math.floor(train_time.time()))
-                   + 's parleys:' + str(parleys)
-                   + ' time_left:'
-                   + str(math.floor(time_left)) + 's ]')
+
+            logs = []
+            # time elapsed
+            logs.append('time:{}s'.format(math.floor(train_time.time())))
+
+            # get report and update total examples seen so far
             if hasattr(agent, 'report'):
-                log = log + str(agent.report())
+                train_report = agent.report()
+                agent.reset_metrics()
             else:
-                log = log + str(world.report())
-                # TODO: world.reset_metrics()
+                train_report = world.report()
+                world.reset_metrics()
+            total_exs += train_report['total']
+            logs.append('total_exs:{}'.format(total_exs))
+
+            # check if we should log amount of time remaining
+            time_left = None
+            if opt['num_epochs'] > 0:
+                exs_per_sec = train_time.time() / total_exs
+                time_left = (max_exs - total_exs) * exs_per_sec
+            if opt['max_train_time'] > 0:
+                other_time_left = opt['max_train_time'] - train_time.time()
+                if time_left is not None:
+                    time_left = min(time_left, other_time_left)
+                else:
+                    time_left = other_time_left
+            if time_left is not None:
+                logs.append('time_left:{}s'.format(math.floor(time_left)))
+
+            # join log string and add full metrics report to end of log
+            log = '[ {} ] {}'.format(' '.join(logs), train_report)
+
             print(log)
             log_time.reset()
-        if (opt['validation_every_n_secs'] and
-                validate_time.time() > opt['validation_every_n_secs']):
-            valid_report = run_eval(agent, opt, 'valid', True)
+
+        if (opt['validation_every_n_secs'] > 0 and
+                validate_time.time() > opt['validation_every_n_secs']):
+            valid_report = run_eval(agent, opt, 'valid', True, opt['validation_max_exs'])
             if valid_report['accuracy'] > best_accuracy:
                 best_accuracy = valid_report['accuracy']
                 impatience = 0
@@ -139,11 +173,11 @@ def main():
                     break
             else:
                 impatience += 1
-                print('[ did not beat best accuracy: ' + str(best_accuracy) +
-                      ' impatience: ' + str(impatience) + ' ]')
+                print('[ did not beat best accuracy: {} impatience: {} ]'.format(
+                    round(best_accuracy, 4), impatience))
             validate_time.reset()
-        if impatience >= opt['validation_patience']:
-            print('[ ran out of patience! stopping. ]')
+        if opt['validation_patience'] > 0 and impatience >= opt['validation_patience']:
+            print('[ ran out of patience! stopping training. ]')
             break
     world.shutdown()
     if not saved:
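
The remaining-time estimate above takes the tighter of two bounds: extrapolating from examples processed so far toward max_exs, and the remaining max_train_time budget. A self-contained sketch of that logic (the function and variable names here are ours, not ParlAI's):

def estimate_time_left(elapsed, total_exs, max_exs=-1, max_train_time=-1):
    """Return the tighter of the epoch-based and wall-clock estimates of
    remaining training time, or None if neither limit is set."""
    time_left = None
    if max_exs > 0 and total_exs > 0:
        secs_per_ex = elapsed / total_exs
        time_left = (max_exs - total_exs) * secs_per_ex
    if max_train_time > 0:
        budget_left = max_train_time - elapsed
        time_left = budget_left if time_left is None else min(time_left, budget_left)
    return time_left

# e.g. 128 of 512 examples done after 16s, with a 100s wall-clock budget:
print(estimate_time_left(16.0, 128, max_exs=512, max_train_time=100))  # 48.0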
10 changes: 10 additions & 0 deletions parlai/core/agents.py
@@ -76,6 +76,9 @@ def getID(self):
     def reset(self):
         self.observation = None

+    def reset_metrics(self):
+        pass
+
     def share(self):
         """If applicable, share any parameters needed to create a shared version
         of this agent.
@@ -133,7 +136,10 @@ def report(self):

     def reset(self):
         super().reset()
+        self.reset_metrics()
         self.epochDone = False

+    def reset_metrics(self):
+        self.metrics.clear()
+
     def share(self):
@@ -242,6 +248,10 @@ def reset(self):
         for t in self.tasks:
             t.reset()

+    def reset_metrics(self):
+        for t in self.tasks:
+            t.reset_metrics()
+
     def share(self):
         shared = {}
         shared['class'] = type(self)
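
The base-class change above establishes a small contract: every agent can be asked to reset_metrics(), with a no-op default. A hypothetical agent sketch showing what the new training-loop logging expects — report() returning a dict that includes a 'total' example count, and reset_metrics() clearing it between log intervals (the class and field names are illustrative, not part of the commit):

from parlai.core.agents import Agent

class CountingAgent(Agent):
    """Hypothetical agent that counts how many examples it has answered."""
    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        self.id = 'CountingAgent'
        self.exs_seen = 0

    def act(self):
        # a real agent would compute a reply from self.observation
        self.exs_seen += 1
        return {'id': self.id, 'text': 'hello'}

    def report(self):
        # train_model.py reads 'total' from this dict to track progress
        return {'total': self.exs_seen}

    def reset_metrics(self):
        self.exs_seen = 0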
6 changes: 3 additions & 3 deletions parlai/core/metrics.py
@@ -159,11 +159,11 @@ def report(self):
         m = {}
         m['total'] = self.metrics['cnt']
         if self.metrics['cnt'] > 0:
-            m['accuracy'] = self.metrics['correct'] / self.metrics['cnt']
-            m['f1'] = self.metrics['f1'] / self.metrics['cnt']
+            m['accuracy'] = round(self.metrics['correct'] / self.metrics['cnt'], 4)
+            m['f1'] = round(self.metrics['f1'] / self.metrics['cnt'], 4)
             m['hits@k'] = {}
             for k in self.eval_pr:
-                m['hits@k'][k] = self.metrics['hits@' + str(k)] / self.metrics['cnt']
+                m['hits@k'][k] = round(self.metrics['hits@' + str(k)] / self.metrics['cnt'], 4)
         return m

     def clear(self):
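
The only behavioral change in metrics.py is presentational: ratio metrics are rounded to four decimal places before reporting, which keeps the training-log lines compact. The effect in isolation:

correct, cnt = 1, 3
print(correct / cnt)            # 0.3333333333333333
print(round(correct / cnt, 4))  # 0.3333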
11 changes: 11 additions & 0 deletions parlai/core/worlds.py
@@ -202,6 +202,10 @@ def reset(self):
         for a in self.agents:
             a.reset()

+    def reset_metrics(self):
+        for a in self.agents:
+            a.reset_metrics()
+
     def synchronize(self):
         """Can be used to synchronize processes."""
         pass
@@ -480,6 +484,10 @@ def reset(self):
         for w in self.worlds:
             w.reset()

+    def reset_metrics(self):
+        for w in self.worlds:
+            w.reset_metrics()
+

 def override_opts_in_shared(table, overrides):
     """Looks recursively for ``opt`` dictionaries within shared dict and overrides
@@ -621,6 +629,9 @@ def reset(self):
         for w in self.worlds:
             w.reset()

+    def reset_metrics(self):
+        self.worlds[0].reset_metrics()
+

 class HogwildProcess(Process):
     """Process child used for ``HogwildWorld``.
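
Note that the hogwild world resets only worlds[0]. A plausible reading (our inference; the diff does not say) is that the hogwild children share a single underlying metrics object, so clearing it through any one of them clears it for all. A toy illustration of that shared-reference pattern:

# if every child world holds a reference to the same metrics dict,
# one clear() empties it everywhere (assumed rationale, not ParlAI code)
shared_metrics = {'cnt': 5, 'correct': 3}
world_a = {'metrics': shared_metrics}
world_b = {'metrics': shared_metrics}
world_a['metrics'].clear()
print(world_b['metrics'])  # {}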
4 changes: 2 additions & 2 deletions parlai/tasks/task_list.py
@@ -118,14 +118,14 @@
     {
         "id": "personalized-dialog-full",
         "display_name": "Personalized Dialog Full Set",
-        "task": "personalized-dialog:full",
+        "task": "personalized_dialog:full",
         "tags": [ "all", "Goal", "Personalization" ],
         "description": "Simulated dataset of restaurant booking focused on personalization based on user profiles. From Joshi et al. '17. Link: https://arxiv.org/abs/1706.07503"
     },
     {
         "id": "personalized-dialog-small",
         "display_name": "Personalized Dialog Small Set",
-        "task": "personalized-dialog:small",
+        "task": "personalized_dialog:small",
         "tags": [ "all", "Goal", "Personalization" ],
         "description": "Simulated dataset of restaurant booking focused on personalization based on user profiles. From Joshi et al. '17. Link: https://arxiv.org/abs/1706.07503"
     },
