Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Logging includes %done and ETA via a util function #856

Merged
merged 8 commits into from
Jun 14, 2018
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions parlai/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def load_cands(path, lines_have_ids = False, cands_are_replies = False):
return cands




Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - added a few blank lines here :)

class Predictor(object):
"""Provides functionality for setting up a running version of a model and
requesting predictions from that model on live data.
Expand Down Expand Up @@ -175,6 +177,34 @@ def time(self):
return self.total


class TimeLogger():
def __init__(self):
self.timer = Timer()
self.tot_time = 0

def total_time(self):
return self.tot_time

def time(self):
return self.timer.time()

def log(self, done, total, report={}):
self.tot_time += self.timer.time()
self.timer.reset()
log = {}
log['total'] = done
if total > 0:
log['%done'] = done / total
if log["%done"] > 0:
log['time_left'] = str(int(self.tot_time / log['%done'] - self.tot_time)) + 's'
z = '%.2f' % ( 100*log['%done'])
log['%done'] = str(z) + '%'
for k, v in report.items():
if k not in log:
log[k] = v
text = str(int(self.tot_time)) + "s elapsed: " + str(log)
return text, log

class AttrDict(dict):
"""Helper class to have a dict-like object with dot access.

Expand Down
6 changes: 6 additions & 0 deletions parlai/core/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,9 @@ def num_examples(self):
def num_episodes(self):
return self.world.num_episodes()

def get_total_exs(self):
return self.world.get_total_exs()

def getID(self):
return self.world.getID()

Expand Down Expand Up @@ -833,6 +836,9 @@ def num_examples(self):
def num_episodes(self):
return self.inner_world.num_episodes()

def get_total_exs(self):
return self.inner_world.get_total_exs()

def get_total_epochs(self):
"""Return total amount of epochs on which the world has trained."""
if self.max_exs is None:
Expand Down
17 changes: 15 additions & 2 deletions parlai/scripts/build_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from parlai.core.dict import DictionaryAgent
from parlai.core.params import ParlaiParser, str2class
from parlai.core.worlds import create_task
from parlai.core.utils import TimeLogger
import copy
import os

import sys

def setup_args(parser=None):
if parser is None:
Expand All @@ -22,6 +23,7 @@ def setup_args(parser=None):
help='Include validation set in dictionary building for task.')
dict_loop.add_argument('--dict-include-test', default=False, type='bool',
help='Include test set in dictionary building for task.')
dict_loop.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
partial, _ = parser.parse_known_args(nohelp=True)
if vars(partial).get('dict_class'):
str2class(vars(partial).get('dict_class')).add_cmdline_args(parser)
Expand All @@ -37,7 +39,6 @@ def build_dict(opt, skip_if_built=False):
print('Tried to build dictionary but `--dict-file` is not set. Set ' +
'this param so the dictionary can be saved.')
return
print('[ setting up dictionary. ]')

if skip_if_built and os.path.isfile(opt['dict_file']):
# Dictionary already built, skip all loading or setup
Expand Down Expand Up @@ -78,13 +79,25 @@ def build_dict(opt, skip_if_built=False):
ordered_opt['datatype'] = dt
world_dict = create_task(ordered_opt, dictionary)
# pass examples to dictionary
print('[ running dictionary over data.. ]')
log_every_n_secs = opt.get('log_every_n_secs', -1)
if log_every_n_secs <= 0:
log_every_n_secs = float('inf')
log_time = TimeLogger()
while not world_dict.epoch_done():
cnt += 1
if cnt > opt['dict_maxexs'] and opt['dict_maxexs'] > 0:
print('Processed {} exs, moving on.'.format(opt['dict_maxexs']))
# don't wait too long...
break
world_dict.parley()
if log_time.time() > log_every_n_secs:
sys.stdout.write('\r')
text, _log = log_time.log(cnt, max(opt.get('dict_maxexs',0),
world_dict.num_examples()))
sys.stdout.write(text)
sys.stdout.flush()

dictionary.save(opt['dict_file'], sort=True)
print('[ dictionary built with {} tokens ]'.format(len(dictionary)))
return dictionary
Expand Down
19 changes: 6 additions & 13 deletions parlai/scripts/detect_offensive_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from parlai.core.agents import create_agent
from parlai.core.worlds import create_task
from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent
from parlai.core.utils import Timer, OffensiveLanguageDetector
from parlai.core.utils import OffensiveLanguageDetector, TimeLogger

import random

Expand Down Expand Up @@ -50,8 +50,7 @@ def detect(opt, printargs=None, print_parser=None):
log_every_n_secs = opt.get('log_every_n_secs', -1)
if log_every_n_secs <= 0:
log_every_n_secs = float('inf')
log_time = Timer()
tot_time = 0
log_time = TimeLogger()

# Show some example dialogs:
cnt = 0
Expand All @@ -71,17 +70,11 @@ def detect(opt, printargs=None, print_parser=None):
print(world.display() + "\n~~")
cnt += 1
if log_time.time() > log_every_n_secs:
tot_time += log_time.time()
report = world.report()
log = {'total': report['total']}
log['done'] = report['total'] / world.num_examples()
if log['done'] > 0:
log['eta'] = int(tot_time / log['done'] - tot_time)
z = '%.2f' % ( 100*log['done'])
log['done'] = str(z) + '%'
log['offenses'] = cnt
print(str(int(tot_time)) + "s elapsed: " + str(log))
log_time.reset()
log = { 'offenses': cnt }
text, log = log_time.log(report['total'], world.num_examples(), log)
print(text)

if world.epoch_done():
print("EPOCH DONE")
print(str(cnt) + " offensive messages found out of " +
Expand Down
11 changes: 5 additions & 6 deletions parlai/scripts/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from parlai.core.params import ParlaiParser
from parlai.core.agents import create_agent
from parlai.core.worlds import create_task
from parlai.core.utils import Timer
from parlai.core.utils import TimeLogger

import random
import os
Expand Down Expand Up @@ -63,8 +63,7 @@ def eval_model(opt, printargs=None, print_parser=None):
log_every_n_secs = opt.get('log_every_n_secs', -1)
if log_every_n_secs <= 0:
log_every_n_secs = float('inf')
log_time = Timer()
tot_time = 0
log_time = TimeLogger()

# Show some example dialogs:
cnt = 0
Expand All @@ -74,9 +73,9 @@ def eval_model(opt, printargs=None, print_parser=None):
if opt['display_examples']:
print(world.display() + "\n~~")
if log_time.time() > log_every_n_secs:
tot_time += log_time.time()
print(str(int(tot_time)) + "s elapsed: " + str(world.report()))
log_time.reset()
report = world.report()
text, report = log_time.log(report['total'], world.num_examples(), report)
print(text)
if opt['num_examples'] > 0 and cnt >= opt['num_examples']:
break
if world.epoch_done():
Expand Down
2 changes: 1 addition & 1 deletion parlai/scripts/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def log(self):

# time elapsed
logs.append('time:{}s'.format(math.floor(self.train_time.time())))
logs.append('parleys:{}'.format(self.parleys))
logs.append('total_exs:{}'.format(self.world.get_total_exs()))

if 'time_left' in train_report:
logs.append('time_left:{}s'.format(
Expand Down
4 changes: 2 additions & 2 deletions projects/personachat/kvmemnn/kvmemnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,10 +710,10 @@ def valid(obs):
if ys is None:
# only build candidates in eval mode.
for o in observations:
if 'label_candidates' in o:
if 'label_candidates' in o and o['label_candidates'] is not None:
cs = []
ct = []
for c in o['label_candidates'] and o['label_candidates'] is not None:
for c in o['label_candidates']:
cs.append(Variable(torch.LongTensor(self.parse(c)).unsqueeze(0)))
ct.append(c)
cands.append(cs)
Expand Down